From 4371ef7806068b4575f7e9ba407381ff99d3e054 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 11:53:39 -0700 Subject: [PATCH 1/2] test: cover zero3 autocast mixed dtype gather Signed-off-by: Masahiro Tanaka --- .../zero/test_zero_autocast_mixed_dtype.py | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 tests/unit/runtime/zero/test_zero_autocast_mixed_dtype.py diff --git a/tests/unit/runtime/zero/test_zero_autocast_mixed_dtype.py b/tests/unit/runtime/zero/test_zero_autocast_mixed_dtype.py new file mode 100644 index 000000000000..a7238883f856 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_autocast_mixed_dtype.py @@ -0,0 +1,129 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.torch_autocast import get_comm_dtype, has_comm_dtype +from deepspeed.runtime.zero.partition_parameters import get_allgather_dtype +from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad +from unit.common import DistributedTest +from unit.util import bf16_required_version_check + + +def _safe_module_name(): + return f"{MixedDtypeAdapterModule.__module__}.{MixedDtypeAdapterModule.__name__}" + + +def _zero3_bf16_autocast_config(): + return { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "bf16": { + "enabled": True, + "bf16_master_weights_and_grads": True, + "bf16_optimizer_states": True, + }, + "zero_optimization": { + "stage": 3, + "stage3_param_persistence_threshold": 0, + "stage3_module_granularity_threshold": 0, + "stage3_use_all_reduce_for_fetch_params": False, + }, + "torch_autocast": { + "enabled": True, + "dtype": str(torch.bfloat16), + "lower_precision_safe_modules": [_safe_module_name()], + }, + } + + +class MixedDtypeAdapterModule(torch.nn.Module): + + def __init__(self, hidden_dim): + super().__init__() + self.hidden_dim = hidden_dim + self.base_weight = torch.nn.Parameter(torch.randn(hidden_dim, hidden_dim) * 0.01, requires_grad=False) + + def attach_fp32_adapter(self, rank): + device = get_accelerator().current_device_name() + self.adapter_a = torch.nn.Parameter(torch.randn(rank, self.hidden_dim, device=device, dtype=torch.float32) * + 0.01) + self.adapter_b = torch.nn.Parameter(torch.randn(self.hidden_dim, rank, device=device, dtype=torch.float32) * + 0.01) + + assert hasattr(self.base_weight, "convert_to_zero_parameters") + self.base_weight.convert_to_zero_parameters([self.adapter_a, self.adapter_b]) + + def forward(self, x, target): + base = torch.nn.functional.linear(x, self.base_weight) + adapter_hidden = torch.nn.functional.linear(x, self.adapter_a) + adapter = torch.nn.functional.linear(adapter_hidden, self.adapter_b) / self.adapter_a.shape[0] + output = base + adapter + return torch.nn.functional.mse_loss(output.float(), target.float()) + + +def _assert_mixed_partition_dtypes(model): + assert model.base_weight.dtype == torch.bfloat16 + assert model.base_weight.ds_tensor.dtype == torch.bfloat16 + + for adapter_param in [model.adapter_a, model.adapter_b]: + assert adapter_param.dtype == torch.float32 + assert adapter_param.ds_tensor.dtype == torch.float32 + + +def _assert_autocast_comm_dtype(model): + for param in [model.base_weight, model.adapter_a, model.adapter_b]: + assert has_comm_dtype(param) + assert get_comm_dtype(param) == torch.bfloat16 + assert get_allgather_dtype(param, param.ds_tensor) == torch.bfloat16 + + +class TestZero3AutocastMixedDtype(DistributedTest): + world_size = 2 + + def test_fp32_adapter_with_bf16_base_params(self): + if not bf16_required_version_check(): + pytest.skip("BF16 ZeRO-3 autocast test requires BF16 accelerator support.") + + hidden_dim = 8 + config = _zero3_bf16_autocast_config() + + with deepspeed.zero.Init(config_dict_or_path=config): + model = MixedDtypeAdapterModule(hidden_dim) + model.attach_fp32_adapter(rank=4) + _assert_mixed_partition_dtypes(model) + + trainable_params = [p for p in model.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW(trainable_params, lr=0.1) + engine, _, _, _ = deepspeed.initialize(config=config, + model=model, + model_parameters=trainable_params, + optimizer=optimizer) + try: + _assert_mixed_partition_dtypes(engine.module) + _assert_autocast_comm_dtype(engine.module) + + adapter_a_before = safe_get_full_fp32_param(engine.module.adapter_a).detach().clone() + device = engine.device + x = torch.randn(2, hidden_dim, device=device, dtype=torch.float32) + target = torch.randn(2, hidden_dim, device=device, dtype=torch.float32) + + loss = engine(x, target) + engine.backward(loss) + + adapter_a_grad = safe_get_full_grad(engine.module.adapter_a) + assert adapter_a_grad is not None + assert torch.count_nonzero(adapter_a_grad).item() > 0 + + engine.step() + + adapter_a_after = safe_get_full_fp32_param(engine.module.adapter_a) + assert not torch.equal(adapter_a_before, adapter_a_after) + _assert_mixed_partition_dtypes(engine.module) + finally: + engine.destroy() From 527f830b1726946777d5532849f0e41fdd6c9105 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 1 Jul 2026 11:54:06 -0700 Subject: [PATCH 2/2] fix: allow zero3 autocast mixed dtype gather Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/zero/partition_parameters.py | 8 +++++--- tests/unit/runtime/zero/test_zero_autocast_mixed_dtype.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 9cdff411237c..3f67b60624aa 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1233,9 +1233,11 @@ def all_gather(param_list=None, async_op=False, hierarchy=0): return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy) def _all_gather_dtype(params, world_size, rank_in_group, ds_process_group, allgather_dtype): - # make sure all params have the same dtype - dtype = params[0].dtype # we assume len(params) > 0 - assert all(p.dtype == dtype for p in params), "all params must have the same dtype" + comm_dtypes = [get_allgather_dtype(p, p.ds_tensor) for p in params] + assert all( + dtype == allgather_dtype + for dtype in comm_dtypes), ("all params in a coalesced bucket must have the same communication dtype; " + f"expected {allgather_dtype}, got {comm_dtypes}") partition_sz = sum(p.ds_tensor.ds_numel for p in params) diff --git a/tests/unit/runtime/zero/test_zero_autocast_mixed_dtype.py b/tests/unit/runtime/zero/test_zero_autocast_mixed_dtype.py index a7238883f856..aa2981288928 100644 --- a/tests/unit/runtime/zero/test_zero_autocast_mixed_dtype.py +++ b/tests/unit/runtime/zero/test_zero_autocast_mixed_dtype.py @@ -51,10 +51,10 @@ def __init__(self, hidden_dim): def attach_fp32_adapter(self, rank): device = get_accelerator().current_device_name() - self.adapter_a = torch.nn.Parameter(torch.randn(rank, self.hidden_dim, device=device, dtype=torch.float32) * - 0.01) - self.adapter_b = torch.nn.Parameter(torch.randn(self.hidden_dim, rank, device=device, dtype=torch.float32) * - 0.01) + self.adapter_a = torch.nn.Parameter( + torch.randn(rank, self.hidden_dim, device=device, dtype=torch.float32) * 0.01) + self.adapter_b = torch.nn.Parameter( + torch.randn(self.hidden_dim, rank, device=device, dtype=torch.float32) * 0.01) assert hasattr(self.base_weight, "convert_to_zero_parameters") self.base_weight.convert_to_zero_parameters([self.adapter_a, self.adapter_b])