The distributed greedy sampling kernel (topk → all_gather → tensor_apis.topk → dynamic indexing ids_all[0, traced_device_idx]) returns incorrect token IDs for most
vocab_per_device tensor shapes when TP > 1.
The bug is shape-dependent and model-independent:
- Only vocab_per_device = 18992 and 37984 produce correct results (these happen to be Qwen3's native sizes)
- All other tested sizes (1024, 4096, 32768, 62080) return wrong values
- TP=1 (no all_gather) works correctly for all sizes
- Confirmed on both Qwen3-30B-A3B and Qwen3.5-35B-A3B with identical sampling code
- RMSNorm formula (w*x vs (1+w)*x) has no effect
The kernel always returns a fixed incorrect index, suggesting the compiled dynamic gather instruction resolves to a wrong memory offset for non-working shapes.
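For reference, the arithmetic the kernel is expected to implement can be simulated single-process. This is a hedged sketch, not the kernel itself: `tp_degree`, the numpy simulation of `all_gather`, and the seeded random logits are all illustrative assumptions; the real pipeline runs topk → all_gather → topk on device.

```python
import numpy as np

# Single-process simulation of distributed greedy sampling (assumed
# semantics). The vocab dimension is sharded across TP ranks; each rank
# takes a local top-1, the candidates are gathered, and a second top-1
# picks the winner. The global token id is rank * vocab_per_device +
# local index. All names here are illustrative, not the kernel's API.

tp_degree = 4
vocab_per_device = 1024  # one of the failing shapes
rng = np.random.default_rng(0)

# Full logits row, sharded along the vocab dimension.
logits = rng.standard_normal(tp_degree * vocab_per_device).astype(np.float32)
shards = logits.reshape(tp_degree, vocab_per_device)

# Step 1: per-rank local top-1 (value, local index).
local_vals = shards.max(axis=1)    # shape [tp_degree]
local_ids = shards.argmax(axis=1)  # shape [tp_degree]

# Step 2: all_gather of the per-rank candidates (already materialized
# here as one entry per rank).

# Step 3: second top-1 over gathered values selects the winning rank;
# dynamic indexing recovers that rank's local index.
winner = int(local_vals.argmax())
token_id = winner * vocab_per_device + int(local_ids[winner])
```

Whatever shape `vocab_per_device` takes, `token_id` must equal the single-device `argmax` over the full logits row; the buggy kernel violates this for most shapes.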
Minimal repro (random tensors, no model needed):
Download repro_minimal_topk_bug.py to the nkipy folder on a Trn2 instance, then:
source .venv/bin/activate
cp repro_minimal_topk_bug.py examples/models/qwen3_5/
uv run torchrun --nproc_per_node 4 examples/models/qwen3_5/repro_minimal_topk_bug.py
Workaround: compute logits on device, read them back to CPU, and do the argmax host-side via torch.distributed.all_gather + torch.argmax.
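The workaround reduces to an argmax over the gathered per-rank logits on the host. Below is a hedged single-process sketch: the real code would use torch.distributed.all_gather on CPU tensors, simulated here by a plain Python list of per-rank numpy arrays; `tp_degree` and the random logits are illustrative assumptions.

```python
import numpy as np

# Simulated CPU workaround: per-rank logits read back from device
# (stand-ins for the torch.distributed.all_gather output), then a
# host-side argmax over the concatenated vocab. No device-side topk or
# dynamic indexing is involved, which sidesteps the kernel bug.

tp_degree = 4
vocab_per_device = 62080  # one of the failing shapes
rng = np.random.default_rng(1)

gathered = [rng.standard_normal(vocab_per_device).astype(np.float32)
            for _ in range(tp_degree)]

full = np.concatenate(gathered)  # [tp_degree * vocab_per_device]
token_id = int(np.argmax(full))  # global token id, computed on CPU
```

Since the concatenation order matches the vocab sharding order, the host-side argmax directly yields the global token id with no rank/offset bookkeeping.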