Commit 925f428

hgt312 and claude authored
add explicit CC parameters to DeviceKernel.compile_and_load (#45)
* feat: add explicit cc_enabled/rank_id/world_size to DeviceKernel.compile_and_load

  Support MPMD workloads and non-torch-distributed runtimes by allowing callers to
  pass CC parameters explicitly. When cc_enabled is set, every rank traces and
  compiles independently (no rank-0 broadcast or barrier). Build directories are
  namespaced by rank to avoid concurrent write collisions.

  - cc_enabled=None (default): auto-detect from torch.distributed (SPMD)
  - cc_enabled=True: explicit CC with per-rank compilation (MPMD)
  - cc_enabled=False: disable CC even in distributed settings

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: separate compilation strategy (is_spmd) from CC parameters

  Address review feedback by introducing an is_spmd flag to control the
  compilation strategy (rank-0 broadcast vs. every rank), keeping
  cc_enabled/rank_id/world_size for load-time CC only. Simplifies the
  resolution logic with resolved_* locals.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: address review — MPMD build dir, SPMD barrier, CC validation

  - MPMD build dir now auto-namespaces by dist.get_rank() when rank_id is not
    explicitly provided (fixes concurrent write conflict)
  - SPMD barrier fires unconditionally when distributed, not only when
    cc_enabled is None (fixes filesystem visibility race)
  - Validate rank_id/world_size are provided when cc_enabled=True without
    torch.distributed (raises ValueError instead of passing None)
  - Add tests: non-rank-0 SPMD worker, MPMD auto-namespace, validation
  - Update docs for build dir auto-detection

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a0a2c8d commit 925f428

4 files changed

Lines changed: 507 additions & 10 deletions

File tree

docs/index.md

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ tutorials/index
 
 user_guide/indexing_slicing_reference
 user_guide/tracing_architecture
+user_guide/distributed_execution
 ```
 
 ```{toctree}

docs/user_guide/distributed_execution.md

Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
# Distributed Execution

NKIPy supports multi-device execution with collective communication (CC)
through `DeviceKernel.compile_and_load`. This guide covers the three
execution patterns and when to use each.

## Execution Patterns

### 1. SPMD (default)

When `torch.distributed` is initialized and `is_spmd=True` (the default),
rank 0 traces and compiles the kernel, then broadcasts the NEFF path to all
workers. All ranks load the same NEFF with CC enabled.

```python
import torch.distributed as dist

dist.init_process_group(...)

kernel = DeviceKernel.compile_and_load(my_kernel, input_a, input_b)
```

Use this when every rank runs the **same kernel** with the **same input shapes**.

### 2. MPMD (`is_spmd=False`)

Set `is_spmd=False` so every rank traces and compiles independently. This is
required when different ranks run different kernels or different input shapes.

```python
# With torch.distributed (CC auto-detected)
kernel = DeviceKernel.compile_and_load(
    my_kernel, input_a, input_b,
    is_spmd=False,
)

# Without torch.distributed (explicit CC)
kernel = DeviceKernel.compile_and_load(
    my_kernel, input_a, input_b,
    is_spmd=False,
    cc_enabled=True,
    rank_id=my_rank,
    world_size=total_workers,
)
```

### 3. No CC (single device or explicit opt-out)

Without `torch.distributed` and without explicit CC parameters, the kernel
loads for single-device execution. You can also pass `cc_enabled=False` to
explicitly disable CC even when `torch.distributed` is active.

```python
# Single device (no torch.distributed)
kernel = DeviceKernel.compile_and_load(my_kernel, input_a)

# Opt out of CC in a distributed setting
kernel = DeviceKernel.compile_and_load(my_kernel, input_a, cc_enabled=False)
```

## Parameter Reference

| Parameter    | Controls           | Values                                         |
|--------------|--------------------|------------------------------------------------|
| `is_spmd`    | Compilation        | `True` = rank-0 broadcast, `False` = all ranks |
| `cc_enabled` | CC at load time    | `None` = auto, `True` = on, `False` = off      |
| `rank_id`    | Rank for CC load   | `None` = auto from dist, or explicit `int`     |
| `world_size` | World size for CC  | `None` = auto from dist, or explicit `int`     |

## Comparison

| Setting                 | SPMD (default)          | MPMD                 | No CC         |
|-------------------------|-------------------------|----------------------|---------------|
| `is_spmd`               | `True`                  | `False`              | Either        |
| `cc_enabled`            | `None` (auto)           | `None`/`True`        | `False`/`None`|
| `torch.distributed`     | Required                | Optional             | N/A           |
| Compilation             | Rank 0 only + broadcast | Every rank           | Every rank    |
| Barrier                 | Yes                     | No                   | No            |
| Use case                | Same kernel, all ranks  | Per-rank kernels     | Single device |

## Build Directory Isolation

In MPMD mode (`is_spmd=False`), the build directory is automatically
namespaced by rank (e.g. `build_dir/rank_0/`, `build_dir/rank_1/`) to
prevent concurrent writes when different ranks produce the same content hash.
The rank is taken from the explicit `rank_id` parameter, or auto-detected
from `torch.distributed` when available.
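
As an illustration (not part of the committed doc above), passing an explicit `build_dir` in MPMD mode keeps each rank's artifacts under its own `rank_<id>` subdirectory; the path and the `my_kernel`/`input_a` names below are placeholders:

```python
# Hypothetical MPMD call with an explicit build directory.
# Per the namespacing above, rank 1's output lands under /tmp/nkipy_build/rank_1/.
kernel = DeviceKernel.compile_and_load(
    my_kernel, input_a,
    is_spmd=False,
    cc_enabled=True,
    rank_id=1,
    world_size=4,
    build_dir="/tmp/nkipy_build",
)
```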

## Caching

Compiled NEFFs are cached in memory by a content hash of the HLO and compiler
arguments. The cache key is the same regardless of CC mode, so a kernel
compiled once can be reused across calls. Pass `use_cached_if_exists=False` to
force recompilation.
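
For example (again illustrative, not part of the committed doc), forcing a rebuild of an already-cached kernel looks like:

```python
# Recompile even if a cached NEFF exists for this content hash.
kernel = DeviceKernel.compile_and_load(
    my_kernel, input_a,
    use_cached_if_exists=False,
)
```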

nkipy/src/nkipy/runtime/device_kernel.py

Lines changed: 64 additions & 10 deletions

@@ -78,13 +78,25 @@ def compile_and_load(
         use_cached_if_exists=True,
         build_dir=None,
         target=CompilationTarget.DEFAULT,
+        is_spmd=True,
+        cc_enabled=None,
+        rank_id=None,
+        world_size=None,
         **kwargs,
     ):
         """Compile and load a kernel, returning a DeviceKernel instance.
 
-        In distributed mode, only the lead worker (rank 0) traces and compiles.
-        The resulting paths are broadcast to all workers, which then load the
-        NEFF collectively.
+        Compilation strategy is controlled by ``is_spmd``:
+
+        * **True (default)** – rank 0 traces/compiles and broadcasts the NEFF
+          path to all workers. Requires ``torch.distributed``.
+        * **False** – every rank traces and compiles independently (MPMD).
+          Required when each rank runs a different kernel or uses different
+          input shapes. Works with or without ``torch.distributed``.
+
+        Collective-communication at load time is controlled separately by
+        ``cc_enabled``, ``rank_id``, and ``world_size``. When left as
+        ``None`` these are auto-detected from ``torch.distributed``.
 
         Args:
             kernel: The kernel function to compile
@@ -93,6 +105,12 @@ def compile_and_load(
             use_cached_if_exists: If True, use cached neff if it exists.
             build_dir: Overriding the build directory for the kernel
             target: Compilation target for the kernel
+            is_spmd: If True, rank 0 compiles and broadcasts (SPMD).
+                If False, every rank compiles independently (MPMD).
+            cc_enabled: Enable collective communication for this kernel.
+                Auto-detected from torch.distributed when None.
+            rank_id: Worker rank for CC. Auto-detected when None.
+            world_size: Total workers for CC. Auto-detected when None.
             *args, **kwargs: Arguments for specialization (numpy array or DeviceTensor)
 
         Returns:
@@ -103,7 +121,23 @@ def compile_and_load(
 
         distributed = _is_distributed()
 
-        if distributed:
+        # In MPMD mode, namespace build dir by rank to avoid concurrent writes
+        # when different ranks produce the same content hash.
+        if not is_spmd:
+            effective_rank = rank_id if rank_id is not None else (
+                dist.get_rank() if distributed else None
+            )
+            if effective_rank is not None:
+                compile_build_dir = os.path.join(
+                    build_dir or _get_build_dir(), f"rank_{effective_rank}"
+                )
+            else:
+                compile_build_dir = build_dir
+        else:
+            compile_build_dir = build_dir
+
+        # --- 1. Compilation ---
+        if is_spmd and distributed:
             if dist.get_rank() == 0:
                 neff_path, cache_key = cls._trace_and_compile(
                     kernel,
@@ -112,7 +146,7 @@ def compile_and_load(
                     kwargs,
                     additional_compiler_args=additional_compiler_args,
                     use_cached_if_exists=use_cached_if_exists,
-                    build_dir=build_dir,
+                    build_dir=compile_build_dir,
                     target=target,
                 )
                 dist.broadcast_object_list([neff_path, cache_key], src=0)
@@ -128,7 +162,7 @@ def compile_and_load(
                 kwargs,
                 additional_compiler_args=additional_compiler_args,
                 use_cached_if_exists=use_cached_if_exists,
-                build_dir=build_dir,
+                build_dir=compile_build_dir,
                 target=target,
             )
 
@@ -137,15 +171,35 @@ def compile_and_load(
             logger.info(f"Using loaded kernel: {name} (cache_key={cache_key})")
             return _LOADED_KERNELS[cache_key]
 
-        # Load the compiled NEFF
-        if distributed:
+        # --- 2. Resolve CC parameters for loading ---
+        resolved_cc = cc_enabled if cc_enabled is not None else distributed
+        resolved_rank = (
+            rank_id if rank_id is not None
+            else (dist.get_rank() if distributed else None)
+        )
+        resolved_world = (
+            world_size if world_size is not None
+            else (dist.get_world_size() if distributed else None)
+        )
+
+        if resolved_cc and (resolved_rank is None or resolved_world is None):
+            raise ValueError(
+                "rank_id and world_size are required when cc_enabled=True "
+                "and torch.distributed is not available for auto-detection"
+            )
+
+        # Barrier only needed in SPMD mode (rank 0 compiled for everyone)
+        if is_spmd and distributed:
             dist.barrier()
+
+        # --- 3. Load the compiled NEFF ---
+        if resolved_cc:
             device_kernel = cls.load_from_neff(
                 neff_path,
                 name=name,
                 cc_enabled=True,
-                rank_id=dist.get_rank(),
-                world_size=dist.get_world_size(),
+                rank_id=resolved_rank,
+                world_size=resolved_world,
             )
         else:
             device_kernel = cls.load_from_neff(neff_path, name=name)
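
For context on how the new validation behaves, here is a minimal usage sketch of the MPMD-without-`torch.distributed` path; the environment variable names and the `my_kernel`/`x` names are placeholders, and rank/world-size plumbing is whatever the surrounding launcher provides:

```python
import os

import numpy as np

# Placeholder rank/world-size plumbing supplied by an external launcher.
rank = int(os.environ["MY_RANK"])
world_size = int(os.environ["MY_WORLD_SIZE"])

x = np.ones((128, 128), dtype=np.float32)

# Every process traces and compiles its own kernel (is_spmd=False) and loads
# it with CC enabled. Omitting rank_id/world_size here would raise ValueError,
# because there is no torch.distributed process group to auto-detect them from.
kernel = DeviceKernel.compile_and_load(
    my_kernel, x,
    is_spmd=False,
    cc_enabled=True,
    rank_id=rank,
    world_size=world_size,
)
```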
