diff --git a/docs/.nav.yml b/docs/.nav.yml index e9ebaf0..eb9c16d 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -20,6 +20,8 @@ nav: - General: - contributing/* - contributing/operator-doc-template.md + - Distributed: + - distributed/* - Design Documents: - design/* - Benchmarking: diff --git a/docs/distributed/deterministic_allreduce.md b/docs/distributed/deterministic_allreduce.md new file mode 100644 index 0000000..86f3d59 --- /dev/null +++ b/docs/distributed/deterministic_allreduce.md @@ -0,0 +1,100 @@ +# Deterministic All-Reduce + +RL-Kernel provides a small all-reduce helper for distributed smoke tests and +future WS2 integration work. It has two modes: + +- `torch_all_reduce`: calls `torch.distributed.all_reduce`. +- `ordered_rank_reference`: gathers all rank tensors, accumulates them on process-group rank 0 in process-group rank order, then broadcasts the result. + +The helper reduces the input tensor in place and returns it. + +## Contract + +Results are expected to be stable only when the world size, process-group rank +order, inputs, dtype, operation, backend, and environment are unchanged. + +`op="mean"` performs a sum and divides by world size at a fixed point. Integer +tensors are rejected for `mean`. + +## Ordered-Rank Reference + +`ordered_rank_reference` is a reference path, not a high-performance transport. +It uses `all_gather` and `broadcast`, so the active backend must support those +collectives for the tensor device. The operation order is: + +1. make each rank input contiguous; +2. gather tensors in process-group rank order; +3. accumulate on process-group rank 0 in that order; +4. optionally accumulate floating-point inputs in FP32; +5. divide once for `op="mean"`; +6. broadcast from process-group rank 0. + +This mode is meant for small tensors in tests, debug runs, and reference +comparisons. + +## Torch All-Reduce + +`torch_all_reduce` is a thin wrapper around `torch.distributed.all_reduce`. For +NCCL runs, callers may set best-effort ring settings before process-group +initialization: + +```python +from rl_engine.distributed import configure_deterministic_nccl_env + +configure_deterministic_nccl_env(overwrite=True) +``` + +The helper writes: + +```bash +NCCL_ALGO=Ring +NCCL_PROTO=Simple +NCCL_MIN_NCHANNELS=1 +NCCL_MAX_NCHANNELS=1 +``` + +These settings do not prove bitwise determinism. Validate on the target machine +before making a hardware-specific claim. + +## Behavior + +- `world_size == 1`: returns the input tensor unchanged. +- no initialized process group and `WORLD_SIZE <= 1`: returns the input tensor unchanged. +- no initialized process group and `WORLD_SIZE > 1`: raises `RuntimeError`. +- `async_op=True`: raises `NotImplementedError`. + +## Smoke Tests + +Unit and CPU/Gloo smoke checks: + +```bash +PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/distributed/test_deterministic_allreduce.py -q +``` + +Manual NCCL all-reduce smoke: + +```bash +CUDA_VISIBLE_DEVICES=0,1 \ +NCCL_ALGO=Ring \ +NCCL_PROTO=Simple \ +NCCL_MIN_NCHANNELS=1 \ +NCCL_MAX_NCHANNELS=1 \ +torchrun --standalone --nproc_per_node=2 \ + tests/distributed/test_deterministic_allreduce.py \ + --backend nccl --mode torch_all_reduce --dtype fp32 --device cuda +``` + +DP gradient smoke compares a fixed DP=1 full-batch gradient with DP=N local +gradients reduced by this helper: + +```bash +torchrun --standalone --nproc_per_node=2 \ + tests/distributed/test_dp_gradient_determinism.py \ + --backend gloo --mode ordered_rank_reference --dtype fp32 --device cpu +``` + +## Limitations + +- NVLS / NVLink-Sharp is not implemented or claimed here. +- Multi-node and RDMA behavior are not validated here. +- DeepSpeed gradient synchronization is not controlled by this helper yet. diff --git a/docs/distributed/deterministic_allreduce_audit.md b/docs/distributed/deterministic_allreduce_audit.md new file mode 100644 index 0000000..8ca9edc --- /dev/null +++ b/docs/distributed/deterministic_allreduce_audit.md @@ -0,0 +1,46 @@ +# Deterministic All-Reduce Audit + +This audit records the distributed communication points relevant to +[RL-Align/RL-Kernel#112](https://github.com/RL-Align/RL-Kernel/issues/112). + +## Search + +```bash +rg -n "all_reduce|allreduce|reduce_scatter|all_gather|DistributedDataParallel|FSDP|deepspeed|gradient" \ + rl_engine csrc tests examples benchmarks scripts docs .github +rg -n "torch\.distributed|distributed|dist\.|process_group|ProcessGroup|nccl|NCCL|reduce|all_reduce|reduce_scatter|all_gather|gradient" \ + rl_engine csrc tests examples benchmarks scripts docs .github +``` + +## Summary + +No direct `torch.distributed` all-reduce, reduce-scatter, all-gather, DDP, or +FSDP call sites were found in RL-Kernel source code. The current DP-gradient +communication risk is indirect: `DeepSpeedTrainingWorker` delegates backward and +optimizer behavior to the optional DeepSpeed engine. + +CUDA IPC uses of `torch.multiprocessing.reductions.reduce_tensor` are not +collective reductions. They serialize CUDA IPC handles for same-node weight +handoff. + +## Inventory + +| Location | Kind | In scope for #112 | Handling | +| --- | --- | --- | --- | +| `rl_engine/executors/deepspeed_trainer.py` `DeepSpeedTrainingWorker.train` | Backward / optimizer delegation to DeepSpeed | Yes, indirectly | Do not claim control over DeepSpeed communication order until a tested integration point exists. | +| `rl_engine/executors/deepspeed_trainer.py` `deepspeed.initialize(...)` | Optional distributed runtime setup | Yes, indirectly | Keep missing-DeepSpeed behavior explicit. Any future integration must document the DeepSpeed hook used for gradient reduction. | +| `tests/test_deepspeed_training_worker.py` fake engine tests | Unit tests for worker delegation | Adjacent | These tests prove delegation only; they do not validate distributed gradient ordering. | +| `rl_engine/executors/bridge.py` CUDA IPC `reduce_tensor` use | CUDA IPC handle serialization | No | Keep out of all-reduce scope. | +| `rl_engine/executors/bridge.py` multi-node/RDMA/NCCL transport blockers | Unsupported weight transport guards | Adjacent | Preserve explicit blockers until a tested transport exists. | +| `rl_engine/utils/logger.py` `info_on_rank` | Rank-filtered logging | No | No numeric reduction behavior. | + +## Entry Point + +New distributed code should route through `rl_engine.distributed` so the +all-reduce contract and fallback/reference behavior stay testable in one place. + +## Not Covered + +- NVLS / NVLink-Sharp. +- Multi-node or RDMA collectives. +- DeepSpeed internal gradient synchronization order. diff --git a/rl_engine/distributed/__init__.py b/rl_engine/distributed/__init__.py new file mode 100644 index 0000000..45d29e1 --- /dev/null +++ b/rl_engine/distributed/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from rl_engine.distributed.deterministic_allreduce import ( + DETERMINISTIC_NCCL_ENV, + DeterministicAllReduceConfig, + configure_deterministic_nccl_env, + deterministic_all_reduce, +) + +__all__ = [ + "DETERMINISTIC_NCCL_ENV", + "DeterministicAllReduceConfig", + "configure_deterministic_nccl_env", + "deterministic_all_reduce", +] diff --git a/rl_engine/distributed/deterministic_allreduce.py b/rl_engine/distributed/deterministic_allreduce.py new file mode 100644 index 0000000..6e2de84 --- /dev/null +++ b/rl_engine/distributed/deterministic_allreduce.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import os +import warnings +from dataclasses import dataclass +from typing import Literal, Optional + +import torch +import torch.distributed as dist + +DETERMINISTIC_NCCL_ENV = { + "NCCL_ALGO": "Ring", + "NCCL_PROTO": "Simple", + "NCCL_MIN_NCHANNELS": "1", + "NCCL_MAX_NCHANNELS": "1", +} + + +@dataclass(frozen=True) +class DeterministicAllReduceConfig: + """Options for :func:`deterministic_all_reduce`.""" + + mode: Literal["torch_all_reduce", "ordered_rank_reference"] = "torch_all_reduce" + op: Literal["sum", "mean"] = "sum" + force_fp32_accumulation: bool = True + async_op: bool = False + group: Optional[dist.ProcessGroup] = None + + +def configure_deterministic_nccl_env(*, overwrite: bool = False) -> dict[str, Optional[str]]: + """Set best-effort NCCL ring settings before process-group init.""" + + if dist.is_available() and dist.is_initialized(): + warnings.warn( + "NCCL environment was configured after torch.distributed initialization", + RuntimeWarning, + stacklevel=2, + ) + + previous: dict[str, Optional[str]] = {} + for key, value in DETERMINISTIC_NCCL_ENV.items(): + previous[key] = os.environ.get(key) + if overwrite or key not in os.environ: + os.environ[key] = value + continue + if os.environ[key] != value: + warnings.warn( + f"{key} is {os.environ[key]!r}; expected {value!r}", + RuntimeWarning, + stacklevel=2, + ) + return previous + + +def deterministic_all_reduce( + tensor: torch.Tensor, + config: Optional[DeterministicAllReduceConfig] = None, +) -> torch.Tensor: + """Reduce ``tensor`` in place and return it.""" + + cfg = config or DeterministicAllReduceConfig() + _validate(tensor, cfg) + + if cfg.async_op: + raise NotImplementedError("async deterministic all-reduce is not implemented") + if not dist.is_available(): + raise RuntimeError("torch.distributed is unavailable") + if not dist.is_initialized(): + if int(os.environ.get("WORLD_SIZE", "1")) > 1: + raise RuntimeError("torch.distributed is not initialized") + return tensor + + world_size = dist.get_world_size(group=cfg.group) + if world_size == 1: + return tensor + + if cfg.mode == "torch_all_reduce": + return _torch_all_reduce(tensor, cfg, world_size) + return _ordered_rank_reference(tensor, cfg, world_size) + + +def _validate(tensor: torch.Tensor, cfg: DeterministicAllReduceConfig) -> None: + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"tensor must be a torch.Tensor, got {type(tensor)!r}") + if cfg.mode not in {"torch_all_reduce", "ordered_rank_reference"}: + raise ValueError(f"unsupported all-reduce mode: {cfg.mode!r}") + if cfg.op not in {"sum", "mean"}: + raise ValueError(f"unsupported reduction op: {cfg.op!r}") + if cfg.op == "mean" and not (tensor.is_floating_point() or tensor.is_complex()): + raise TypeError("op='mean' requires a floating-point or complex tensor") + + +def _torch_all_reduce( + tensor: torch.Tensor, + cfg: DeterministicAllReduceConfig, + world_size: int, +) -> torch.Tensor: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=cfg.group, async_op=False) + if cfg.op == "mean": + tensor.div_(world_size) + return tensor + + +def _ordered_rank_reference( + tensor: torch.Tensor, + cfg: DeterministicAllReduceConfig, + world_size: int, +) -> torch.Tensor: + send = tensor.detach().contiguous() + gathered = [torch.empty_like(send) for _ in range(world_size)] + dist.all_gather(gathered, send, group=cfg.group) + + result = torch.empty_like(send) + if dist.get_rank(group=cfg.group) == 0: + dtype = _accumulation_dtype(send, cfg.force_fp32_accumulation) + reduced = gathered[0].to(dtype=dtype) + for item in gathered[1:]: + reduced.add_(item.to(dtype=dtype)) + if cfg.op == "mean": + reduced.div_(world_size) + result.copy_(reduced.to(dtype=send.dtype)) + + dist.broadcast(result, src=_group_root_global_rank(cfg.group), group=cfg.group) + tensor.copy_(result.view_as(tensor)) + return tensor + + +def _group_root_global_rank(group: Optional[dist.ProcessGroup]) -> int: + if group is None: + return 0 + try: + return int(dist.get_global_rank(group, 0)) + except AttributeError as exc: + raise RuntimeError( + "custom process groups require torch.distributed.get_global_rank" + ) from exc + + +def _accumulation_dtype(tensor: torch.Tensor, force_fp32: bool) -> torch.dtype: + if not force_fp32 or not tensor.is_floating_point(): + return tensor.dtype + if tensor.dtype == torch.float64: + return torch.float64 + return torch.float32 diff --git a/tests/distributed/test_deterministic_allreduce.py b/tests/distributed/test_deterministic_allreduce.py new file mode 100644 index 0000000..9a4c71b --- /dev/null +++ b/tests/distributed/test_deterministic_allreduce.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.distributed as dist + +from rl_engine.distributed import ( + DETERMINISTIC_NCCL_ENV, + DeterministicAllReduceConfig, + configure_deterministic_nccl_env, + deterministic_all_reduce, +) + + +def test_configure_deterministic_nccl_env_preserves_existing_value(monkeypatch): + monkeypatch.setenv("NCCL_ALGO", "Tree") + with pytest.warns(RuntimeWarning, match="NCCL_ALGO"): + previous = configure_deterministic_nccl_env() + + assert previous["NCCL_ALGO"] == "Tree" + assert os.environ["NCCL_ALGO"] == "Tree" + for key, value in DETERMINISTIC_NCCL_ENV.items(): + if key != "NCCL_ALGO": + assert os.environ[key] == value + + +def test_configure_deterministic_nccl_env_can_overwrite(monkeypatch): + monkeypatch.setenv("NCCL_ALGO", "Tree") + configure_deterministic_nccl_env(overwrite=True) + + assert os.environ["NCCL_ALGO"] == "Ring" + + +def test_single_process_without_process_group_is_noop(): + tensor = torch.tensor([1.0, 2.0, 3.0]) + + reduced = deterministic_all_reduce( + tensor, + DeterministicAllReduceConfig(mode="ordered_rank_reference", op="mean"), + ) + + assert reduced is tensor + assert torch.equal(tensor, torch.tensor([1.0, 2.0, 3.0])) + + +def test_ordered_rank_reference_gloo_smoke_runs_under_torchrun(): + if not (dist.is_available() and dist.is_gloo_available()): + pytest.skip("Gloo is unavailable") + + repo = Path(__file__).resolve().parents[2] + env = os.environ.copy() + env["PYTHONPATH"] = f"{repo}{os.pathsep}{env.get('PYTHONPATH', '')}" + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node=2", + str(Path(__file__).resolve()), + "--backend", + "gloo", + "--mode", + "ordered_rank_reference", + "--dtype", + "fp32", + "--device", + "cpu", + "--iterations", + "2", + ] + completed = subprocess.run( + cmd, + cwd=repo, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + timeout=120, + check=False, + ) + + assert completed.returncode == 0, completed.stdout + assert '"status": "pass"' in completed.stdout + + +def test_ordered_rank_reference_reverse_group_runs_under_torchrun(): + if not (dist.is_available() and dist.is_gloo_available()): + pytest.skip("Gloo is unavailable") + + repo = Path(__file__).resolve().parents[2] + env = os.environ.copy() + env["PYTHONPATH"] = f"{repo}{os.pathsep}{env.get('PYTHONPATH', '')}" + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node=2", + str(Path(__file__).resolve()), + "--backend", + "gloo", + "--mode", + "ordered_rank_reference", + "--dtype", + "fp32", + "--device", + "cpu", + "--iterations", + "2", + "--reverse-group", + ] + completed = subprocess.run( + cmd, + cwd=repo, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + timeout=120, + check=False, + ) + + assert completed.returncode == 0, completed.stdout + assert '"status": "pass"' in completed.stdout + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="all-reduce smoke test") + parser.add_argument("--backend", choices=("gloo", "nccl"), default="gloo") + parser.add_argument( + "--mode", + choices=("ordered_rank_reference", "torch_all_reduce"), + default="ordered_rank_reference", + ) + parser.add_argument("--op", choices=("sum", "mean"), default="sum") + parser.add_argument("--dtype", choices=("fp32", "fp16", "bf16"), default="fp32") + parser.add_argument("--device", choices=("auto", "cpu", "cuda"), default="auto") + parser.add_argument("--numel", type=int, default=257) + parser.add_argument("--iterations", type=int, default=3) + parser.add_argument("--configure-nccl-env", action="store_true") + parser.add_argument("--reverse-group", action="store_true") + parser.add_argument("--rtol", type=float, default=None) + parser.add_argument("--atol", type=float, default=None) + return parser.parse_args() + + +def _dtype(name: str) -> torch.dtype: + return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[name] + + +def _device(args: argparse.Namespace) -> torch.device: + if args.device == "cpu" or args.backend == "gloo": + return torch.device("cpu") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but unavailable") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + return torch.device("cuda", local_rank) + + +def _make_input(rank: int, dtype: torch.dtype, device: torch.device, numel: int) -> torch.Tensor: + base = torch.arange(numel, dtype=torch.float32, device=device) + values = ((base % 17) - 8.0) / 17.0 + return (values + (rank + 1) * 0.03125).to(dtype=dtype) + + +def _group_rank_order(group: dist.ProcessGroup | None, device: torch.device) -> list[int]: + rank = torch.tensor([dist.get_rank()], dtype=torch.int64, device=device) + gathered = [torch.empty_like(rank) for _ in range(dist.get_world_size(group=group))] + dist.all_gather(gathered, rank, group=group) + return [int(item.item()) for item in gathered] + + +def _expected_reduce( + rank_order: list[int], + dtype: torch.dtype, + device: torch.device, + numel: int, + op: str, +) -> torch.Tensor: + acc_dtype = torch.float32 if dtype != torch.float64 else torch.float64 + reduced = _make_input(rank_order[0], dtype, device, numel).to(dtype=acc_dtype) + for rank in rank_order[1:]: + reduced.add_(_make_input(rank, dtype, device, numel).to(dtype=acc_dtype)) + if op == "mean": + reduced.div_(len(rank_order)) + return reduced.to(dtype=dtype) + + +def _tolerances(dtype: torch.dtype, args: argparse.Namespace) -> tuple[float, float]: + if args.atol is not None and args.rtol is not None: + return args.atol, args.rtol + if dtype == torch.float32: + return (0.0, 0.0) if args.mode == "ordered_rank_reference" else (1.0e-6, 0.0) + if dtype == torch.bfloat16: + return 8.0e-3, 8.0e-3 + return 2.0e-3, 2.0e-3 + + +def _diff_stats(actual: torch.Tensor, expected: torch.Tensor) -> dict[str, Any]: + actual_f32 = actual.detach().to(torch.float32).cpu() + expected_f32 = expected.detach().to(torch.float32).cpu() + diff = (actual_f32 - expected_f32).abs() + rel = diff / expected_f32.abs().clamp_min(1.0e-12) + return { + "bitwise_equal": bool(torch.equal(actual.detach().cpu(), expected.detach().cpu())), + "max_abs_diff": float(diff.max().item()), + "max_rel_diff": float(rel.max().item()), + "mismatch_count": int((diff != 0).sum().item()), + } + + +def _assert_close(actual: torch.Tensor, expected: torch.Tensor, atol: float, rtol: float) -> None: + if not torch.allclose(actual, expected, atol=atol, rtol=rtol): + stats = _diff_stats(actual, expected) + raise AssertionError(f"all-reduce mismatch: {stats}") + + +def _run_distributed_smoke(args: argparse.Namespace) -> None: + if args.configure_nccl_env or (args.backend == "nccl" and args.mode == "torch_all_reduce"): + configure_deterministic_nccl_env() + + device = _device(args) + dist.init_process_group(backend=args.backend) + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + group = None + if args.reverse_group: + group = dist.new_group(ranks=list(reversed(range(world_size)))) + + dtype = _dtype(args.dtype) + atol, rtol = _tolerances(dtype, args) + rank_order = _group_rank_order(group, device) + expected = _expected_reduce(rank_order, dtype, device, args.numel, args.op) + previous: torch.Tensor | None = None + stats: dict[str, Any] = {} + + for _ in range(args.iterations): + candidate = _make_input(rank, dtype, device, args.numel) + reference = candidate.clone() + deterministic_all_reduce( + candidate, + DeterministicAllReduceConfig(mode=args.mode, op=args.op, group=group), + ) + deterministic_all_reduce( + reference, + DeterministicAllReduceConfig( + mode="ordered_rank_reference", + op=args.op, + group=group, + ), + ) + _assert_close(candidate, expected, atol, rtol) + _assert_close(reference, expected, atol, rtol) + if previous is not None: + _assert_close(candidate, previous, atol, rtol) + previous = candidate.clone() + stats = _diff_stats(candidate, expected) + + if rank == 0: + print( + json.dumps( + { + "status": "pass", + "backend": args.backend, + "mode": args.mode, + "op": args.op, + "dtype": args.dtype, + "device": str(device), + "world_size": world_size, + "iterations": args.iterations, + **stats, + }, + sort_keys=True, + ) + ) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + _run_distributed_smoke(_parse_args()) diff --git a/tests/distributed/test_dp_gradient_determinism.py b/tests/distributed/test_dp_gradient_determinism.py new file mode 100644 index 0000000..744e161 --- /dev/null +++ b/tests/distributed/test_dp_gradient_determinism.py @@ -0,0 +1,351 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.distributed as dist +import torch.nn.functional as F + +from rl_engine.distributed import ( + DeterministicAllReduceConfig, + configure_deterministic_nccl_env, + deterministic_all_reduce, +) + + +class TinyGradientModel(torch.nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Linear(input_dim, hidden_dim), + torch.nn.ReLU(), + torch.nn.Linear(hidden_dim, output_dim), + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.net(inputs) + + +@dataclass(frozen=True) +class GradientStats: + bitwise_equal: bool + max_abs_diff: float + max_rel_diff: float + mismatch_count: int + + +def test_fixed_batch_is_reproducible(): + first = _fixed_batch(global_batch_size=8, input_dim=3, output_dim=2) + second = _fixed_batch(global_batch_size=8, input_dim=3, output_dim=2) + + assert torch.equal(first[0], second[0]) + assert torch.equal(first[1], second[1]) + + +def test_dp_gradient_gloo_smoke_runs_under_torchrun(): + if not (dist.is_available() and dist.is_gloo_available()): + pytest.skip("Gloo is unavailable") + + repo = Path(__file__).resolve().parents[2] + env = os.environ.copy() + env["PYTHONPATH"] = f"{repo}{os.pathsep}{env.get('PYTHONPATH', '')}" + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node=2", + str(Path(__file__).resolve()), + "--backend", + "gloo", + "--mode", + "ordered_rank_reference", + "--dtype", + "fp32", + "--device", + "cpu", + ] + completed = subprocess.run( + cmd, + cwd=repo, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + timeout=120, + check=False, + ) + + assert completed.returncode == 0, completed.stdout + assert '"status": "pass"' in completed.stdout + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="DP gradient determinism smoke test") + parser.add_argument("--backend", choices=("gloo", "nccl"), default="gloo") + parser.add_argument( + "--mode", + choices=("ordered_rank_reference", "torch_all_reduce"), + default="ordered_rank_reference", + ) + parser.add_argument("--dtype", choices=("fp32", "fp16", "bf16"), default="fp32") + parser.add_argument("--device", choices=("auto", "cpu", "cuda"), default="auto") + parser.add_argument("--global-batch-size", type=int, default=16) + parser.add_argument("--input-dim", type=int, default=7) + parser.add_argument("--hidden-dim", type=int, default=13) + parser.add_argument("--output-dim", type=int, default=5) + parser.add_argument("--seed", type=int, default=2026) + parser.add_argument("--configure-nccl-env", action="store_true") + parser.add_argument("--rtol", type=float, default=None) + parser.add_argument("--atol", type=float, default=None) + return parser.parse_args() + + +def _dtype(name: str) -> torch.dtype: + return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[name] + + +def _device(args: argparse.Namespace) -> torch.device: + if args.device == "cpu" or args.backend == "gloo": + return torch.device("cpu") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA device requested but torch.cuda.is_available() is false") + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + torch.cuda.set_device(local_rank) + return torch.device("cuda", local_rank) + + +def _set_deterministic_controls(seed: int) -> None: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.use_deterministic_algorithms(True, warn_only=True) + + +def _fixed_batch( + *, + global_batch_size: int, + input_dim: int, + output_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + inputs = torch.linspace( + -1.0, + 1.0, + steps=global_batch_size * input_dim, + dtype=torch.float32, + ).reshape(global_batch_size, input_dim) + targets = torch.cos( + torch.linspace( + -0.7, + 0.9, + steps=global_batch_size * output_dim, + dtype=torch.float32, + ) + ).reshape(global_batch_size, output_dim) + return inputs, targets + + +def _make_model( + *, + input_dim: int, + hidden_dim: int, + output_dim: int, + seed: int, + dtype: torch.dtype, + device: torch.device, +) -> TinyGradientModel: + torch.manual_seed(seed) + model = TinyGradientModel(input_dim, hidden_dim, output_dim) + return model.to(device=device, dtype=dtype) + + +def _compute_gradients( + model: torch.nn.Module, + inputs: torch.Tensor, + targets: torch.Tensor, + *, + dtype: torch.dtype, + device: torch.device, +) -> dict[str, torch.Tensor]: + model.zero_grad(set_to_none=True) + batch_inputs = inputs.to(device=device, dtype=dtype) + batch_targets = targets.to(device=device, dtype=dtype) + predictions = model(batch_inputs) + loss = F.mse_loss(predictions.float(), batch_targets.float(), reduction="mean") + loss.backward() + return { + name: parameter.grad.detach().clone() + for name, parameter in model.named_parameters() + if parameter.grad is not None + } + + +def _reduce_gradients(model: torch.nn.Module, mode: str) -> dict[str, torch.Tensor]: + reduced: dict[str, torch.Tensor] = {} + for name, parameter in model.named_parameters(): + if parameter.grad is None: + continue + deterministic_all_reduce( + parameter.grad, + DeterministicAllReduceConfig(mode=mode, op="mean"), + ) + reduced[name] = parameter.grad.detach().clone() + return reduced + + +def _stats(actual: torch.Tensor, expected: torch.Tensor) -> GradientStats: + actual_f32 = actual.detach().to(torch.float32).cpu() + expected_f32 = expected.detach().to(torch.float32).cpu() + diff = (actual_f32 - expected_f32).abs() + rel = diff / expected_f32.abs().clamp_min(1.0e-12) + return GradientStats( + bitwise_equal=bool(torch.equal(actual.detach().cpu(), expected.detach().cpu())), + max_abs_diff=float(diff.max().item()), + max_rel_diff=float(rel.max().item()), + mismatch_count=int((diff != 0).sum().item()), + ) + + +def _tolerances(dtype: torch.dtype, args: argparse.Namespace) -> tuple[float, float]: + if args.atol is not None and args.rtol is not None: + return args.atol, args.rtol + if dtype == torch.float32: + return 1.0e-5, 1.0e-5 + if dtype == torch.bfloat16: + return 2.0e-2, 2.0e-2 + return 5.0e-3, 5.0e-3 + + +def _compare_gradients( + actual: dict[str, torch.Tensor], + expected: dict[str, torch.Tensor], + *, + atol: float, + rtol: float, +) -> tuple[GradientStats, list[dict[str, Any]]]: + if set(actual) != set(expected): + raise AssertionError( + f"gradient key mismatch: actual={sorted(actual)}, expected={sorted(expected)}" + ) + + global_stats = GradientStats(True, 0.0, 0.0, 0) + parameters: list[dict[str, Any]] = [] + for name in sorted(actual): + param_stats = _stats(actual[name], expected[name]) + parameters.append({"name": name, **param_stats.__dict__}) + global_stats = GradientStats( + bitwise_equal=global_stats.bitwise_equal and param_stats.bitwise_equal, + max_abs_diff=max(global_stats.max_abs_diff, param_stats.max_abs_diff), + max_rel_diff=max(global_stats.max_rel_diff, param_stats.max_rel_diff), + mismatch_count=global_stats.mismatch_count + param_stats.mismatch_count, + ) + if not torch.allclose(actual[name], expected[name], atol=atol, rtol=rtol): + raise AssertionError( + f"gradient mismatch for {name}: " + f"max_abs_diff={param_stats.max_abs_diff} " + f"max_rel_diff={param_stats.max_rel_diff} " + f"mismatch_count={param_stats.mismatch_count}" + ) + return global_stats, parameters + + +def _run_distributed_smoke(args: argparse.Namespace) -> None: + if args.configure_nccl_env or (args.backend == "nccl" and args.mode == "torch_all_reduce"): + configure_deterministic_nccl_env() + device = _device(args) + _set_deterministic_controls(args.seed) + dist.init_process_group(backend=args.backend) + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + if args.global_batch_size % world_size != 0: + raise ValueError("global batch size must be divisible by world size") + + dtype = _dtype(args.dtype) + atol, rtol = _tolerances(dtype, args) + inputs, targets = _fixed_batch( + global_batch_size=args.global_batch_size, + input_dim=args.input_dim, + output_dim=args.output_dim, + ) + + baseline_model = _make_model( + input_dim=args.input_dim, + hidden_dim=args.hidden_dim, + output_dim=args.output_dim, + seed=args.seed, + dtype=dtype, + device=device, + ) + baseline_grads = _compute_gradients( + baseline_model, + inputs, + targets, + dtype=dtype, + device=device, + ) + + local_batch_size = args.global_batch_size // world_size + start = rank * local_batch_size + end = start + local_batch_size + dp_model = _make_model( + input_dim=args.input_dim, + hidden_dim=args.hidden_dim, + output_dim=args.output_dim, + seed=args.seed, + dtype=dtype, + device=device, + ) + _compute_gradients( + dp_model, + inputs[start:end], + targets[start:end], + dtype=dtype, + device=device, + ) + reduced_grads = _reduce_gradients(dp_model, args.mode) + global_stats, parameter_stats = _compare_gradients( + reduced_grads, + baseline_grads, + atol=atol, + rtol=rtol, + ) + + if rank == 0: + print( + json.dumps( + { + "status": "pass", + "backend": args.backend, + "mode": args.mode, + "dtype": args.dtype, + "device": str(device), + "world_size": world_size, + "global_batch_size": args.global_batch_size, + "atol": atol, + "rtol": rtol, + **global_stats.__dict__, + "parameters": parameter_stats, + }, + sort_keys=True, + ) + ) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + _run_distributed_smoke(_parse_args())