diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index cf69b62f7d4f..766b97e253b7 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -252,6 +252,7 @@ class ZeROOptimizer(DeepSpeedOptimizer): def __init__(self): self._backward_hook_state = BackwardHookStateManager() + self.retain_graph_on_current_backward = False # Delegate backward hook state management to the manager. # These properties provide backward compatibility with code that accesses @@ -419,10 +420,14 @@ def backward(self, loss, **kwargs): scaled_loss = self.backward_prologue(loss) retain_graph = kwargs.pop('retain_graph', False) + self.retain_graph_on_current_backward = retain_graph self.enter_backward() - scaled_loss.backward(retain_graph=retain_graph) - self.backward_epilogue() - self.exit_backward() + try: + scaled_loss.backward(retain_graph=retain_graph) + self.backward_epilogue() + finally: + self.exit_backward() + self.retain_graph_on_current_backward = False def register_grad_acc_post_hook(self, hook): """Register a callback to run when all gradient hooks have fired.""" diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 38770a33afdf..d9d71d5dbd75 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2754,6 +2754,7 @@ def _backward_epilogue(self): if not bf16_optimizer: self.optimizer.backward_epilogue() self.optimizer.exit_backward() + self.optimizer.retain_graph_on_current_backward = False if self.is_deepcompile_active(): deepcompile_backward_epilogue() @@ -2915,7 +2916,7 @@ def _flush_coalesced_reduction_zero3(self, optimizer): optimizer.reduce_ready_partitions_and_remove_grads(param) optimizer.independent_gradient_partition_epilogue() - def scale(self, loss): + def scale(self, loss, retain_graph=False): r"""Apply loss scaler for manual backward pass. Use this method when calling loss.backward() directly instead of engine.backward(). 
@@ -2933,6 +2934,8 @@ def scale(self, loss): Arguments: loss: Scalar loss tensor to be scaled + retain_graph: bool, default: False + Set to True when the subsequent loss.backward() is called with retain_graph=True Returns: Scaled loss tensor ready for .backward() call @@ -2958,6 +2961,8 @@ def scale(self, loss): # Apply loss scaler based on optimizer type scaled_loss = loss if isinstance(self.optimizer, ZeROOptimizer): + # Record retain_graph for the upcoming manual backward; the loss is scaled exactly once below. + self.optimizer.retain_graph_on_current_backward = retain_graph scaled_loss = self.optimizer.scale_if_loss(scaled_loss) elif self.torch_autocast_z0_gradscaler: scaled_loss = self.torch_autocast_z0_gradscaler.scale(scaled_loss) @@ -2998,6 +3003,7 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True): # TODO: handle these scaling with direct calls to loss.backward() if isinstance(self.optimizer, ZeROOptimizer): loss = self.optimizer.scale_if_loss(loss) + self.optimizer.retain_graph_on_current_backward = retain_graph elif self.torch_autocast_z0_gradscaler: loss = self.torch_autocast_z0_gradscaler.scale(loss) @@ -3015,6 +3021,8 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True): self._backward_epilogue() self._running_engine_backward = False + if isinstance(self.optimizer, ZeROOptimizer): + self.optimizer.retain_graph_on_current_backward = False return gas_scaled_loss diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 161b3e27e440..31e29a775e4c 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -136,6 +136,7 @@ def __init__( zero_quantized_nontrainable_weights=False, zero_module_granularity_threshold=0, log_trace_cache_warnings=False, + retain_graph_checker=None, ): see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=False) @@ -153,6 +154,7 @@ def __init__( self.zero_quantized_weights = zero_quantized_weights self.zero_quantized_nontrainable_weights = 
zero_quantized_nontrainable_weights self.log_trace_cache_warnings = log_trace_cache_warnings + self.retain_graph_checker = retain_graph_checker if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none: self.offload_device = offload_param_config.device @@ -563,7 +565,11 @@ def post_sub_module_backward_function(self, sub_module): for param in params_to_fetch: param.data = param.data.t() if len(param.ds_shape) != 1 else param.data - self.get_param_coordinator().release_sub_module(sub_module, forward=False) + # Keep gathered params alive when the current backward retains the graph, + # so a second backward over the same forward can reuse valid saved tensors. + retain_graph_backward = bool(self.retain_graph_checker()) if self.retain_graph_checker is not None else False + if not retain_graph_backward: + self.get_param_coordinator().release_sub_module(sub_module, forward=False) see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release", diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index e0856570162c..d7c03fe949cf 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -285,6 +285,7 @@ def __init__( zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, zero_module_granularity_threshold=zero_module_granularity_threshold, log_trace_cache_warnings=log_trace_cache_warnings, + retain_graph_checker=lambda: self.retain_graph_on_current_backward, ) self.persistent_parameters = self.parameter_offload.persistent_parameters @@ -571,6 +572,7 @@ def initialize_ds_offload( zero_quantized_nontrainable_weights, zero_module_granularity_threshold, log_trace_cache_warnings, + retain_graph_checker=None, ): return DeepSpeedZeRoOffload(module=module, timers=timers, @@ -589,7 +591,8 @@ def initialize_ds_offload( zero_quantized_weights=zero_quantized_weights, 
zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights, zero_module_granularity_threshold=zero_module_granularity_threshold, - log_trace_cache_warnings=log_trace_cache_warnings) + log_trace_cache_warnings=log_trace_cache_warnings, + retain_graph_checker=retain_graph_checker) def _get_param_partition_group(self, param): return getattr(param, "ds_process_group", self.dp_process_group) diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py index b094c22c2fb5..55f2bcb946ad 100644 --- a/tests/unit/v1/zero/test_zero_user_backward.py +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -539,6 +539,128 @@ def test_separate_loss_function(self, zero_stage): model_engine.destroy() + def test_two_losses_separate_backward_gas1(self, zero_stage): + """Regression test for https://github.com/deepspeedai/DeepSpeed/issues/7352 + + A single forward followed by two separate backward passes, with + zero_grad() in between, must produce independent gradients for each + loss when gradient_accumulation_steps == 1. Previously the first + backward was treated as the accumulation boundary, which froze loss1's + gradients so that zero_grad() had no effect and loss2's gradients were + doubled (grad2 == grad1 + grad2). Each DeepSpeed gradient set is + compared against an equivalent PyTorch DDP baseline. + """ + hidden_dim = 4 + batch_size = 2 + + # Default gradient_accumulation_steps=1, which is the failing case. + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines(model_class=SimpleOutputModel, + zero_stage=zero_stage, + hidden_dim=hidden_dim) + + loss_fn = torch.nn.CrossEntropyLoss() + + # Two different targets so loss1 and loss2 yield distinct gradients; + # this makes accidental accumulation (grad1 + grad2) detectable. 
+ torch.manual_seed(456) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + y1 = torch.randint(0, hidden_dim, (batch_size, ), device=device) + y2 = torch.randint(0, hidden_dim, (batch_size, ), device=device) + + # DDP baseline: separate backward for each loss with zero_grad in between. + output_ddp = model_ddp(x) + loss1_ddp = loss_fn(output_ddp, y1) + loss2_ddp = loss_fn(output_ddp, y2) + + optimizer_ddp.zero_grad() + loss1_ddp.backward(retain_graph=True) + grads1_ddp = collect_ddp_gradients(model_ddp) + + optimizer_ddp.zero_grad() + loss2_ddp.backward() + grads2_ddp = collect_ddp_gradients(model_ddp) + + # DeepSpeed: identical sequence. + output_ds = model_engine(x) + loss1_ds = loss_fn(output_ds, y1) + loss2_ds = loss_fn(output_ds, y2) + + model_engine.zero_grad() + model_engine.backward(loss1_ds, retain_graph=True) + grads1_ds = collect_gradients_safe(model_engine) + + model_engine.zero_grad() + model_engine.backward(loss2_ds) + grads2_ds = collect_gradients_safe(model_engine) + + # The second backward must NOT accumulate loss1's gradients on top of + # loss2; both gradient sets must match their DDP counterparts. + compare_gradients(grads1_ddp, grads1_ds, "loss1") + compare_gradients(grads2_ddp, grads2_ds, "loss2") + + model_engine.destroy() + + def test_two_losses_separate_manual_backward_gas1(self, zero_stage): + """Regression test for https://github.com/deepspeedai/DeepSpeed/issues/7352 + + Same scenario as test_two_losses_separate_backward_gas1, but using the + torch-style manual path engine.scale(loss, retain_graph=True).backward( + retain_graph=True) instead of engine.backward(loss, retain_graph=True). + The manual path bypasses engine.backward(), so retain_graph must be + propagated through scale(). For ZeRO-3 this defers parameter release so + the retained graph's saved tensors stay valid for the second backward + over the same forward. 
+ """ + hidden_dim = 4 + batch_size = 2 + + # Default gradient_accumulation_steps=1, which is the failing case. + model_ddp, optimizer_ddp, model_engine, device, dtype = setup_models_and_engines(model_class=SimpleOutputModel, + zero_stage=zero_stage, + hidden_dim=hidden_dim) + + loss_fn = torch.nn.CrossEntropyLoss() + + # Two different targets so loss1 and loss2 yield distinct gradients; + # this makes accidental accumulation (grad1 + grad2) detectable. + torch.manual_seed(456) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype) + y1 = torch.randint(0, hidden_dim, (batch_size, ), device=device) + y2 = torch.randint(0, hidden_dim, (batch_size, ), device=device) + + # DDP baseline: separate backward for each loss with zero_grad in between. + output_ddp = model_ddp(x) + loss1_ddp = loss_fn(output_ddp, y1) + loss2_ddp = loss_fn(output_ddp, y2) + + optimizer_ddp.zero_grad() + loss1_ddp.backward(retain_graph=True) + grads1_ddp = collect_ddp_gradients(model_ddp) + + optimizer_ddp.zero_grad() + loss2_ddp.backward() + grads2_ddp = collect_ddp_gradients(model_ddp) + + # DeepSpeed: identical sequence via the manual scale().backward() path. + output_ds = model_engine(x) + loss1_ds = loss_fn(output_ds, y1) + loss2_ds = loss_fn(output_ds, y2) + + model_engine.zero_grad() + model_engine.scale(loss1_ds, retain_graph=True).backward(retain_graph=True) + grads1_ds = collect_gradients_safe(model_engine) + + model_engine.zero_grad() + model_engine.scale(loss2_ds).backward() + grads2_ds = collect_gradients_safe(model_engine) + + # The second backward must NOT accumulate loss1's gradients on top of + # loss2; both gradient sets must match their DDP counterparts. + compare_gradients(grads1_ddp, grads1_ds, "loss1") + compare_gradients(grads2_ddp, grads2_ds, "loss2") + + model_engine.destroy() + class LeafModuleModel(torch.nn.Module): """Model with ModuleList that uses all parameters - for testing leaf module compatibility"""