
Custom NKI kernel causes OOB scatter/gather error when compiled inside NxDI model graph (SDK 2.29) #1306

@jimburtoft

Description


A custom NKI kernel that passes all standalone tests and works in torch_neuronx.trace() fails with an out-of-bound scatter/gather error when compiled as part of an NxDI (neuronx-distributed-inference) model graph. Through a systematic binary search, we have narrowed the issue to kernel complexity: a minimal kernel (DMA-only) and a medium kernel (element-wise + 3 matmuls) both pass, but the full kernel (~20 matmuls + 6-round Neumann loop) fails.

Environment

  • SDK Version: Neuron SDK 2.29 (neuronx-cc 2.24.5133.0, torch-neuronx 2.9.0.2.13, NKI 0.3.0)
  • Instance: trn2.3xlarge (LNC=2, TP=4)
  • NxDI Version: 0.9.17334
  • DLAMI: Deep Learning AMI Neuron (Ubuntu 24.04) 20260410

Error Message

RuntimeError: failed to run scatter/gather (indirect memory copy via vector DGE), due to out-of-bound access

Reproduction Summary

| Configuration | Result |
| --- | --- |
| NKI kernel standalone (S=128, 256, 512, 1024) | PASS |
| NKI kernel in torch_neuronx.trace() (BH=32, S=1024) | PASS |
| NKI kernel in torch_neuronx.trace(), 40-layer model with dual NKI kernels | PASS |
| NKI kernel in NxDI model (4 decoder layers, 3 DeltaNet + 1 GQA) | FAIL (OOB) |
| NKI kernel in NxDI model (40 decoder layers) | FAIL (OOB) |
| Same model, same layers, PyTorch chunk_forward(64) instead of NKI kernel | PASS |
| NKI kernel in NxDI model, vector_dynamic_offsets flag removed | FAIL (OOB) |

Binary Search Results

We performed a systematic binary search by reducing kernel complexity while keeping the same NxDI integration path:

| Kernel Variant | nc_matmul calls | SBUF buffers | Neumann loop | Result in NxDI |
| --- | --- | --- | --- | --- |
| Passthrough (DMA load → DMA store only) | 0 | ~4 | No | PASS (634ms TTFT) |
| Phase 1+2 (elem-wise + 3 matmuls) | 3 | ~20 | No | PASS (633ms TTFT) |
| Full kernel (all phases, nl.sequential_range(6)) | ~20 | ~40+ | Yes (6 rounds) | FAIL (OOB) |
| Full kernel (all phases, nl.static_range(6)) | ~20 | ~40+ | Yes (unrolled) | FAIL (OOB) |

Key findings from binary search:

  1. The issue is NOT caused by nl.sequential_range — changing to nl.static_range(6) (compile-time unrolled) still OOBs
  2. The issue is NOT caused by nl.shared_hbm vs nl.hbm: nl.hbm is not allowed for kernel outputs ("must be in shared_hbm address space"), so shared_hbm is the only valid choice here
  3. The issue appears to be triggered by kernel complexity — specifically the combination of ~20 nc_matmul calls with ~40+ SBUF buffer allocations in a single kernel invocation, when compiled inside an NxDI model graph
  4. The exact same kernel works standalone and with torch_neuronx.trace() even in 40-layer models

What Was Tried

  1. Fused kernel (all chunks processed in one NKI call with nl.sequential_range) — OOB
  2. Per-chunk kernel (no sequential_range, each call gets full 128x128 tiles) — OOB
  3. Fused kernel with .unbind() pre-slicing — OOB
  4. Per-chunk kernel on just 1 of 30 DeltaNet layers — OOB
  5. Per-chunk kernel with shard_on_block MoE disabled (PyTorch MoE fallback) — OOB
  6. Removed --internal-enable-dge-levels vector_dynamic_offsets compiler flag — OOB persists
  7. Standalone torch_neuronx.trace() with 40-layer model + dual NKI kernels — PASSES
  8. Passthrough kernel (DMA only, same 10 inputs, 2 outputs) — PASSES in NxDI
  9. Phase 1+2 kernel (elem-wise + 3 matmuls, same inputs/outputs) — PASSES in NxDI
  10. Full kernel with nl.static_range(6) instead of nl.sequential_range(6) — OOB persists

Key Evidence

  • The issue ONLY manifests when using NxDI's model.compile() pipeline. Identical NKI kernels work fine with raw torch_neuronx.trace(), even in 40-layer models with multiple NKI kernel types.
  • The boundary is kernel complexity, not any specific NKI API: 3 matmuls pass, ~20 matmuls fail.
  • Loop type (sequential_range vs static_range) is irrelevant — both fail.
  • Removing the --internal-enable-dge-levels vector_dynamic_offsets flag does NOT resolve the issue.
  • The model also uses NxDI's internal bwmm_shard_on_block NKI kernel for MoE. Even disabling this (falling back to PyTorch MoE) does not resolve the OOB.

Kernel Description

The kernel implements a chunked DeltaNet gated delta rule computation:

  • Input shapes: Q, K, V, beta, g_cumsum, g_last, state_in all (128, 128) float32
  • Also takes lower_mask, identity, lower_mask_diag (128, 128) float32
  • Uses @nki.jit, nisa.nc_matmul, nisa.dma_copy, nisa.tensor_scalar, nisa.tensor_tensor, nisa.activation
  • No neuronxcc.nki usage, no deprecated APIs, pure nki.* namespace
  • Neumann power-doubling for intra-chunk correction (6 rounds)
  • ~20 nc_matmul calls, ~40+ SBUF buffer allocations per invocation
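To make the "Neumann power-doubling" phase concrete, here is a minimal NumPy reference of one common power-doubling formulation (this is an illustrative sketch, not the NKI kernel itself; the function name and the 64-element chunk size are our assumptions). For a strictly lower-triangular A, the matrix is nilpotent, so the Neumann series for (I - A)^-1 is finite, and r doubling rounds accumulate all powers up to A^(2^r - 1):

```python
import numpy as np

# Illustrative sketch of Neumann power-doubling for (I - A)^-1, where A is
# strictly lower triangular (nilpotent: A^n = 0 for an n x n matrix).
def neumann_inverse(A, rounds):
    n = A.shape[0]
    T = np.eye(n) + A   # partial sum I + A (first two series terms)
    P = A               # current power of A
    for _ in range(rounds - 1):
        P = P @ P       # square the power: A^2, A^4, A^8, ...
        T = T + P @ T   # doubles the number of accumulated series terms
    return T

rng = np.random.default_rng(0)
n = 64  # assumed intra-chunk size
A = np.tril(rng.standard_normal((n, n)) * 0.1, k=-1)  # strictly lower triangular
T = neumann_inverse(A, rounds=6)   # 6 rounds cover powers up to A^63
exact = np.linalg.inv(np.eye(n) - A)
print(np.allclose(T, exact))  # True: exact for n = 64 since A^64 = 0
```

Six rounds are exactly enough for a 64-step nilpotent recurrence (2^6 = 64 terms), which is consistent with the chunk_size=64 used by the PyTorch fallback.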

Compiler Flags Used

--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer -O1 --auto-cast=none
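For reference, a sketch of how these flags are passed on the standalone path that passes (torch_neuronx.trace accepts a compiler_args list; model and example_inputs are placeholders, and the NxDI pipeline sets the same flags through its own compilation config):

```python
# Illustrative only: passing the compiler flags via torch_neuronx.trace.
# model / example_inputs are placeholders for the actual traced module.
import torch_neuronx

compiler_args = [
    "--enable-saturate-infinity",
    "--enable-mixed-precision-accumulation",
    "--model-type", "transformer",
    "-O1",
    "--auto-cast=none",
]
traced = torch_neuronx.trace(model, example_inputs, compiler_args=compiler_args)
```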

Workaround

Use PyTorch chunk_forward(chunk_size=64) instead of the NKI kernel for context encoding (CTE). This works correctly in NxDI and provides good performance (1,139ms TTFT). The NKI kernel would further reduce TTFT by an estimated 5-10%.
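The chunked-processing pattern that chunk_forward follows can be illustrated with a much simpler recurrence (this sketch is a plain causal linear attention, NOT the gated delta rule; function names and the d=8 head size are invented for illustration). Each chunk handles its intra-chunk terms with a masked matmul and carries an inter-chunk state across chunk boundaries:

```python
import numpy as np

# Naive per-token causal linear attention: o_t = sum_{s<=t} (q_t . k_s) v_s
def naive(Q, K, V):
    T, _ = Q.shape
    O = np.zeros_like(V)
    for t in range(T):
        for s in range(t + 1):
            O[t] += (Q[t] @ K[s]) * V[s]
    return O

# Equivalent chunked form: masked matmul inside each chunk, plus a carried
# state S = sum k_s v_s^T over all previous chunks.
def chunked(Q, K, V, C):
    T, d = Q.shape
    S = np.zeros((d, V.shape[1]))       # inter-chunk state
    O = np.zeros_like(V)
    mask = np.tril(np.ones((C, C)))     # causal mask within a chunk
    for i in range(0, T, C):
        q, k, v = Q[i:i+C], K[i:i+C], V[i:i+C]
        O[i:i+C] = (q @ k.T * mask) @ v + q @ S   # intra- + inter-chunk parts
        S += k.T @ v                              # update carried state
    return O

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((128, 8)) for _ in range(3))
print(np.allclose(naive(Q, K, V), chunked(Q, K, V, 64)))  # True
```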

Possible Root Cause

The compiler may have a limit on NKI kernel complexity (SBUF buffer count, nc_matmul count, or total instruction count) when the kernel is embedded in a large NxDI model graph. Below this threshold, the DGE address computation is correct. Above it, the DGE offsets overflow or are miscalculated, causing runtime OOB. The threshold is NOT hit when the same kernel is compiled standalone or via torch_neuronx.trace(), suggesting the NxDI graph context (HBM layout, buffer allocation) shifts the effective addresses enough to trigger the overflow.
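A toy illustration of the overflow half of this hypothesis (entirely speculative; the widths and values are invented): if an effective address were ever formed in 32-bit arithmetic, a large base placed by the surrounding NxDI graph plus a modest in-kernel offset would wrap negative, while the same kernel placed low in HBM by a standalone trace would stay in range:

```python
import numpy as np

# Speculative toy model of a 32-bit offset wrap; not how DGE actually
# computes addresses. Values are hypothetical.
with np.errstate(over="ignore"):
    base = np.int32(2**31 - 4096)   # hypothetical base near the int32 limit
    offset = np.int32(8192)         # hypothetical in-kernel offset
    wrapped = base + offset         # wraps negative under int32 arithmetic
print(int(wrapped))                 # a negative, out-of-bound "address"
```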

Model Context

This is the Qwen3.5-35B-A3B contrib model for NxDI (hybrid DeltaNet + GQA + MoE architecture). The NKI kernel is used for DeltaNet linear attention context encoding in 30 of the 40 decoder layers.
