
greedy_sampling kernel returns wrong token IDs with NkiPy on Trn2 due to topk + all_gather + dynamic indexing bug (shape-dependent) #52

@YantaoShen

Description

The distributed greedy sampling kernel (topk → all_gather → tensor_apis.topk → dynamic indexing ids_all[0, traced_device_idx]) returns incorrect token IDs for most vocab_per_device tensor shapes when TP > 1.

The bug is shape-dependent and model-independent:

  • Only vocab_per_device = 18992 and 37984 produce correct results (these happen to be Qwen3's native sizes)
  • All other tested sizes (1024, 4096, 32768, 62080) return wrong values
  • TP=1 (no all_gather) works correctly for all sizes
  • Confirmed on both Qwen3-30B-A3B and Qwen3.5-35B-A3B with identical sampling code
  • RMSNorm formula (w*x vs (1+w)*x) has no effect

The kernel always returns a fixed incorrect index, suggesting the compiled dynamic gather instruction resolves to a wrong memory offset for non-working shapes.
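For reference, the reduction the kernel is supposed to perform can be simulated in a single process with NumPy: each TP rank takes a local top-1 over its vocab shard, the per-rank (value, index) pairs are gathered, and the global token ID is the winning rank's offset plus its local index. This is an illustrative sketch of the intended semantics, not the NKI kernel code; all names here are made up.

```python
import numpy as np

def reference_greedy_sample(logits_per_device: np.ndarray) -> int:
    """Simulate the kernel's intended semantics.

    logits_per_device: [tp_degree, vocab_per_device] array, one row per
    simulated TP rank holding a contiguous vocab shard.
    """
    tp_degree, vocab_per_device = logits_per_device.shape
    # Per-shard top-1 (the local topk step on each device).
    local_ids = logits_per_device.argmax(axis=1)                    # [tp_degree]
    local_vals = logits_per_device[np.arange(tp_degree), local_ids]
    # The all_gather is a no-op here since all shards live in one array;
    # take the global top-1 over the gathered per-device maxima.
    winner = int(local_vals.argmax())
    # The dynamic-indexing step: pick the winning device's local id and
    # offset it into the global vocabulary.
    return winner * vocab_per_device + int(local_ids[winner])

rng = np.random.default_rng(0)
full = rng.standard_normal((4, 1024))   # TP=4, vocab_per_device=1024
# Row-major layout means the offset math matches a flat argmax exactly.
assert reference_greedy_sample(full) == int(full.reshape(-1).argmax())
```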

Minimal repro (random tensors, no model needed):
Download repro_minimal_topk_bug.py to the nkipy folder on a Trn2 instance, then run:

source .venv/bin/activate
cp repro_minimal_topk_bug.py  examples/models/qwen3_5/
uv run torchrun --nproc_per_node 4 examples/models/qwen3_5/repro_minimal_topk_bug.py

Workaround: Compute logits on device, read them back to the CPU, and do the argmax there via torch.distributed.all_gather + torch.argmax.
