diff --git a/contrib/models/Qwen3.5-27B/README.md b/contrib/models/Qwen3.5-27B/README.md new file mode 100644 index 00000000..b46c8ccd --- /dev/null +++ b/contrib/models/Qwen3.5-27B/README.md @@ -0,0 +1,299 @@ +# Contrib Model: Qwen3.5-27B + +NeuronX Distributed Inference implementation of Qwen3.5-27B, a 27B parameter dense model from Alibaba Cloud with a hybrid DeltaNet + GQA attention architecture. This is the first NxDI implementation of a model using linear recurrent attention (DeltaNet) with custom NKI kernels. + +## Model Family + +| Model | HuggingFace ID | Params | Instance | +|-------|----------------|--------|----------| +| **Qwen3.5-27B** | `Qwen/Qwen3.5-27B` | 27B | trn2.3xlarge (TP=4) | +| **Qwen3.5-27B-VL** | `Qwen/Qwen3.5-27B-VL` | 27B + ViT | trn2.3xlarge (TP=4) | + +**License:** Apache 2.0 + +## Architecture Details + +| Feature | Value | +|---------|-------| +| Layers | 64 (48 DeltaNet + 16 GQA) | +| Layer Pattern | [3 DeltaNet + 1 GQA] x 16 | +| Hidden Size | 5120 | +| GQA Attention | 24 heads, 4 KV heads, head_dim=256 | +| DeltaNet Attention | 48 value heads, 16 key heads, k_dim=v_dim=128 | +| Dense MLP | SwiGLU (gate_proj + up_proj: 5120 -> 17408, down_proj: 17408 -> 5120) | +| Position Encoding | Partial RoPE (25% of head_dim = 64 dims), mRoPE for VL | +| Vocabulary | 248,320 | +| Normalization | RMSNorm with +1 weight convention | +| Activation | SiLU gated MLP | + +### Unique Architecture Features + +- **Hybrid DeltaNet + GQA:** 48 of 64 layers use Gated DeltaNet (linear recurrent attention), 16 layers use standard GQA with KV cache. The pattern repeats every 4 layers: 3 DeltaNet + 1 GQA. +- **DeltaNet Linear Attention:** Uses the delta rule for recurrent state updates with gated decay. Per-step: `state *= exp(g); delta = (v - state^T @ k) * beta; state += outer(k, delta); output = state^T @ q`. Runs as a chunked algorithm for context encoding, per-token recurrence for token generation. 
+- **Custom NKI Kernels:** Three NKI kernels implement the DeltaNet forward pass on Neuron: a per-token recurrent kernel (TKG), a per-chunk kernel (legacy), and a fused single-kernel chunked forward (CTE). The fused kernel uses a Neumann series for intra-chunk correction with state persistence in SBUF across chunks. +- **GQA Output Gate:** Attention layers use a sigmoid output gate. `q_proj` is 2x sized and interleaved: `[head0_query | head0_gate | head1_query | ...]`. The gate is split during weight conversion and applied after attention. +- **Partial RoPE:** Only 25% of head_dim (64 of 256 dimensions) receives rotary embeddings. The remaining 192 dimensions are identity (no rotation). +- **+1 RMSNorm Convention:** HF weights use `output = norm(x) * (1 + weight)` where weight is initialized to zeros. Converted to standard `output = norm(x) * weight` during loading by adding 1.0 to all RMSNorm weights (except DeltaNet internal norms, which use standard convention). +- **Vision-Language Support:** Optional ViT encoder runs on CPU (HBM fully consumed by 27B text decoder). Vision embeddings are injected via a scatter mask at traced input positions. 
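
The per-step DeltaNet recurrence listed above can be written out as a few lines of PyTorch. This is an illustrative single-head, single-token reference sketch (names and shapes are ours, not the NKI kernels'):

```python
import torch

def delta_rule_step(state, q, k, v, g, beta):
    """One token of the gated delta rule for one head.

    state: (k_dim, v_dim) recurrent memory; q, k: (k_dim,) L2-normalized;
    v: (v_dim,); g: scalar log-decay (<= 0); beta: scalar write strength in (0, 1).
    """
    state = state * torch.exp(g)            # gated decay of the memory
    kv_mem = state.T @ k                    # what the memory currently stores for k
    delta = (v - kv_mem) * beta             # prediction error, scaled by beta
    state = state + torch.outer(k, delta)   # rank-1 write: state += outer(k, delta)
    out = state.T @ q                       # read out with the query
    return out, state

# With beta=1, g=0, unit-norm k, and a zero initial state, reading back
# with q == k recovers v exactly (the delta rule stores the association).
state = torch.zeros(128, 128)
k = torch.nn.functional.normalize(torch.randn(128), dim=-1)
v = torch.randn(128)
out, state = delta_rule_step(state, k, k, v, torch.tensor(0.0), torch.tensor(1.0))
```

Context encoding runs this same recurrence chunk-wise so most of the work becomes matmuls; token generation applies exactly this per-token step.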
+ +## Test Results + +### Unit Tests (CPU) + +| Test Module | Tests | Status | +|-------------|-------|--------| +| test_config.py | 26 | 26/26 PASS | +| test_weight_conversion.py | 16 | 16/16 PASS | +| **Total** | **42** | **42/42 PASS** | + +### Integration Test (27B, trn2.3xlarge, TP=4, SDK 2.29) + +| Test | Status | Notes | +|------|--------|-------| +| Model loads | PASS | Compiled + loaded with DeltaNet state aliasing | +| Model generates | PASS | Generates coherent multi-sentence text | +| Output coherence | PASS | 3+ words, no excessive repetition | +| Top token valid | PASS | First token decodable and semantically valid | +| Capital of France | PASS | Produces "Paris" as first token | +| TTFT performance | PASS | ~576 ms (128 input tokens, bs=1) | +| Throughput | PASS | ~18.9 tok/s (bs=1) | +| Multi-prompt generation | PASS | 4/4 prompts produce coherent output | + +**All 50 tests pass (42 unit + 8 integration) on SDK 2.29.** + +### Generation Output (27B, TP=4, seq_len=128, greedy top_k=1) + +**Prompt:** "The capital of France is" + +**Output:** Paris. It is the largest city in France and serves as the country's political, cultural, and economic center... + +**Status:** PASS -- coherent, factually correct, multi-sentence response. + +## Performance Benchmarks + +**SDK 2.29**, BF16, trn2.3xlarge (4 NeuronCores, LNC=2), seq_len=128, bs=1. 
+ +### Text-Only Benchmarks + +| Metric | Value | +|--------|-------| +| **TTFT (P50)** | 576 ms | +| **TPOT (P50)** | 53 ms | +| **Throughput** | 18.9 tok/s | +| Compilation time | ~13 min | +| Weight loading | ~31 s | +| HBM usage | 23.57 GB / 24 GB | + +### Vision-Language Benchmarks (VL pipeline) + +| Metric | Value | +|--------|-------| +| Vision encoder (CPU) | ~918 ms | +| Text generation (Neuron) | ~3.9 s (30 tokens) | +| End-to-end VL | ~4.8 s | + +### NKI Kernel Benchmarks (standalone, single NeuronCore) + +| Kernel | 128 tokens | 256 tokens | 512 tokens | +|--------|-----------|-----------|-----------| +| Fused chunked (CTE) | 335 us | 339 us | 487 us | +| Recurrent (TKG, S=1) | 183 us | - | - | + +### Key Observations + +- **HBM-limited at TP=4:** The 27B model consumes 23.57 GB of 24 GB HBM per NeuronCore pair. Context length limited to 128-512 tokens. Use trn2.12xlarge for longer contexts. +- **DeltaNet enables efficient TKG:** Token generation uses O(1) per-token recurrence instead of O(n) KV cache attention for 48/64 layers, keeping TPOT at 53ms. +- **Vision encoder on CPU:** The ViT runs on CPU because HBM is fully consumed by the text decoder. CPU vision adds ~918ms latency per image. +- **Fused NKI kernel 2.8-5.2% faster:** The fused chunked kernel provides modest TTFT improvement over the PyTorch baseline (larger gains at longer contexts). 
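
At bs=1 the throughput row is simply the reciprocal of TPOT, which makes for a quick sanity check when re-running the benchmarks (values taken from the table above):

```python
tpot_ms = 53                      # P50 time per output token, from the table
throughput = 1000.0 / tpot_ms     # steady-state decode rate at batch size 1
print(f"{throughput:.1f} tok/s")  # 18.9 tok/s, matching the reported throughput
```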
+ +## Usage + +### Text-Only (trn2.3xlarge, TP=4) + +```python +import json +import torch +from transformers import AutoTokenizer, GenerationConfig +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + +from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + +model_path = "/path/to/Qwen3.5-27B" +compiled_path = "/scratch/qwen35_traced/" + +neuron_config = NeuronConfig( + tp_degree=4, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + logical_nc_config=2, + enable_bucketing=False, + flash_decoding_enabled=False, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + save_sharded_checkpoint=True, +) + +# Read config.json directly (model_type 'qwen3_5' may not be +# registered in all transformers versions) +import os +with open(os.path.join(model_path, "config.json")) as f: + hf_config = json.load(f) +text_config = hf_config.get("text_config", hf_config) +config_dict = dict(text_config) +config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) +config_dict.setdefault("tie_word_embeddings", False) + +config = Qwen35InferenceConfig( + neuron_config=neuron_config, + **config_dict, +) + +model = NeuronQwen35ForCausalLM(model_path, config) +model.compile(compiled_path) + +# Reload from compiled artifacts +model = NeuronQwen35ForCausalLM(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right") +gen_config = GenerationConfig( + do_sample=True, top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, +) + +inputs = tokenizer("The capital of France is", return_tensors="pt") +gen_model = HuggingFaceGenerationAdapter(model) +outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + 
max_new_tokens=50, +) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` + +### Vision-Language (trn2.3xlarge, TP=4) + +The VL pipeline uses the text decoder on Neuron and the vision encoder on CPU: + +```python +from src.modeling_qwen35_vl import NeuronQwen35VLForCausalLM, Qwen35VLInferenceConfig + +vl_model = NeuronQwen35VLForCausalLM( + model_path="/path/to/Qwen3.5-27B", + config=vl_config, +) +vl_model.compile(compiled_path) +vl_model.load(compiled_path) + +# See test/integration/test_model.py for full VL usage example +``` + +### DeltaNet Kernel Selection + +The DeltaNet forward path can be controlled via environment variables: + +| Env Var | Forward Path | Use Case | +|---------|-------------|----------| +| `USE_NKI_FUSED=1` | Fused chunked NKI kernel | Best CTE performance (default for SDK 2.29) | +| `USE_NKI_CHUNKED=1` | Per-chunk NKI kernel | Legacy, superseded by fused | +| `USE_NKI=1` | Per-token NKI kernel | TKG (always used for token generation) | +| `DELTANET_SEQUENTIAL=1` | Sequential PyTorch | Debugging/reference | +| *(none)* | PyTorch chunked | Default fallback for CTE | + +## Caveats + +1. **HBM-limited at TP=4:** The 27B model consumes 23.57 GB of the 24 GB HBM per NeuronCore pair (LNC=2). Context length is limited to ~512 tokens. Batch size > 1 not possible. Use trn2.12xlarge (TP=16) for production workloads. + +2. **SDK 2.29+ required:** The NKI DeltaNet kernels require NKI 0.3.0 (SDK 2.29). No library modifications needed — runs on stock SDK 2.29 DLAMI (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`). + +3. **No mini model test:** Unlike DeepSeek-V3, a mini model cannot be provided because DeltaNet layers require NKI kernels that only execute on Neuron devices. Integration tests require a trn2 instance with the full 27B weights. + +4. **Vision encoder runs on CPU:** The ViT cannot be placed on Neuron because HBM is fully consumed by the text decoder. This adds ~918ms latency per image. 
Future optimization: quantize text decoder to free HBM, or use larger instance. + +5. **Compilation time:** First compilation takes ~13 minutes. Subsequent compilations with cached NEFFs take ~1 minute. + +6. **+1 RMSNorm convention:** Qwen3.5 uses `output = norm(x) * (1 + weight)` for most RMSNorm layers, but DeltaNet internal norms use standard `output = norm(x) * weight`. The weight conversion handles this automatically, but custom weight loading must be aware of both conventions. + +7. **Neumann series convergence:** The fused DeltaNet kernel uses a 6-round Neumann series for intra-chunk correction. This requires L2-normalized Q and K inputs. Unnormalized inputs will cause NaN divergence. + +## Maximum Sequence Length + +| seq_len | Status | Notes | +|---------|--------|-------| +| 128 | **PASS** | Default, all benchmarks | +| 512 | **PASS** | Compiles and runs, 4 DeltaNet chunks | +| 1024 | **FAIL** | Compiler/runtime OOM (HBM full at TP=4) | + +For seq_len > 512, use trn2.12xlarge or larger instance with TP > 4. + +## Compatibility Matrix + +| Instance | TP | LNC | Status | Notes | +|----------|-----|-----|--------|-------| +| trn2.3xlarge | 4 | 2 | **PASS** | Tested, HBM-limited | +| trn2.12xlarge | 16 | 2 | Expected PASS | Untested, recommended for production | + +### SDK Configuration + +| Component | Version | +|-----------|---------| +| NxDI | 0.9.17334 | +| neuronx-cc | 2.24.5133 | +| torch | 2.9.1 | +| transformers | 4.57.6 | +| NKI | 0.3.0 | +| NXDI venv | `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/` | + +## Testing + +### Unit Tests (CPU only, no device needed) + +```bash +cd contrib/models/Qwen3.5-27B/ +# On DLAMI: source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +pytest test/unit/ -v +``` + +Tests: config parsing (26), weight conversion (16) = **42 tests**. 
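
As an example of what the weight-conversion tests exercise, the dual +1 RMSNorm handling (caveat 6) can be checked in isolation. The helper below is a hypothetical sketch of the conversion rule, not the repo's actual function:

```python
import torch

def convert_rmsnorm_weight(hf_weight, is_deltanet_internal_norm):
    """Sketch of the dual-convention rule: most norms store norm(x) * (1 + w)
    in the HF checkpoint, so conversion adds 1.0; DeltaNet internal norms
    already use the standard norm(x) * w convention and pass through."""
    if is_deltanet_internal_norm:
        return hf_weight
    return hf_weight + 1.0

# HF initializes +1-convention weights to zeros -> converted weight is all ones
hf_w = torch.zeros(5120)
converted = convert_rmsnorm_weight(hf_w, is_deltanet_internal_norm=False)

# DeltaNet internal norm weights must be left untouched
dn_w = torch.full((128,), 0.5)
passthrough = convert_rmsnorm_weight(dn_w, is_deltanet_internal_norm=True)
```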
+ +### Integration Tests (needs trn2.3xlarge with 4 NeuronCores) + +```bash +cd contrib/models/Qwen3.5-27B/ + +QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-27B \ +QWEN35_COMPILED_PATH=/mnt/models/qwen35_traced \ +pytest test/integration/test_model.py --capture=tee-sys +``` + +Tests: model loads, generates, coherence, top-token valid, capital test, TTFT, throughput, multi-prompt = **8 tests**. + +## Key Porting Challenges + +1. **DeltaNet on Neuron:** No prior NxDI implementation of linear recurrent attention exists. Required writing three custom NKI kernels (recurrent, chunked, fused) with careful SBUF state management and Neumann series approximation for the chunked intra-chunk correction. + +2. **Hybrid state management:** DeltaNet layers maintain per-head (128, 128) recurrent state and (conv_dim, kernel_size-1) conv state, while GQA layers use standard KV cache. Both must be aliased as `input_output_aliases` in the XLA trace for HBM persistence across forward calls. + +3. **Interleaved q_proj:** The HF checkpoint stores Q and gate weights interleaved as `[head0_q | head0_gate | head1_q | head1_gate | ...]`. Must reshape to (num_heads, 2*head_dim, hidden), then split along dim=1. + +4. **Dual RMSNorm conventions:** 48 DeltaNet layers use standard `norm(x) * weight` while all 64 `input_layernorm` / `post_attention_layernorm` and 16 GQA Q/K norms use `norm(x) * (1 + weight)`. Weight conversion must selectively add 1.0 only to the correct subset. + +5. **DeltaNet conv1d state:** Each DeltaNet layer has a causal conv1d (kernel_size=4) that requires 3 previous timesteps. Conv state is stored as an nn.Parameter buffer and aliased for HBM persistence, similar to Mamba's conv state. + +6. **`aten.scatter.src` unsupported:** Neuron compiler does not support `aten.scatter.src`. DeltaNet state updates use `new_state + buffer * 0` pattern instead. 
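
The interleaved q_proj split (challenge 3 above) comes down to a view and a split. A sketch with the 27B attention dimensions — this mirrors the conversion described above but is not the repo's converter code:

```python
import torch

num_heads, head_dim, hidden = 24, 256, 5120  # GQA dims from the architecture table

# HF layout: per head, head_dim query rows followed by head_dim gate rows
q_proj_hf = torch.randn(num_heads * 2 * head_dim, hidden)

# Reshape to (num_heads, 2*head_dim, hidden), then split each head's block
w = q_proj_hf.view(num_heads, 2 * head_dim, hidden)
q_w, gate_w = w.split(head_dim, dim=1)  # each (num_heads, head_dim, hidden)

q_weight = q_w.reshape(num_heads * head_dim, hidden)        # standard q_proj
gate_weight = gate_w.reshape(num_heads * head_dim, hidden)  # sigmoid output gate
```

The gate half is the projection whose sigmoid multiplies the attention output, per the GQA output-gate description in the architecture section.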
+ +## Example Checkpoints + +- `Qwen/Qwen3.5-27B` (BF16, ~52 GB, 11 shards) +- `Qwen/Qwen3.5-27B-VL` (BF16, VL variant with ViT) + +## Maintainer + +AWS Neuron + +**Last Updated:** 2026-04-12 diff --git a/contrib/models/Qwen3.5-27B/src/__init__.py b/contrib/models/Qwen3.5-27B/src/__init__.py new file mode 100644 index 00000000..7e79aa03 --- /dev/null +++ b/contrib/models/Qwen3.5-27B/src/__init__.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from src.modeling_qwen35 import ( + NeuronGatedDeltaNet, + NeuronQwen35Attention, + NeuronQwen35DecoderLayer, + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35DecoderModelInstance, + Qwen35InferenceConfig, + Qwen35MLP, + Qwen35ModelWrapper, +) +from src.modeling_qwen35_vision import ( + NeuronQwen35VisionForImageEncoding, + NeuronQwen35VisionModel, +) +from src.modeling_qwen35_vl import ( + NeuronQwen35VLForCausalLM, + Qwen35VLInferenceConfig, +) + +__all__ = [ + # Text decoder + "NeuronGatedDeltaNet", + "NeuronQwen35Attention", + "NeuronQwen35DecoderLayer", + "NeuronQwen35ForCausalLM", + "NeuronQwen35Model", + "Qwen35DecoderModelInstance", + "Qwen35InferenceConfig", + "Qwen35MLP", + "Qwen35ModelWrapper", + # Vision encoder + "NeuronQwen35VisionForImageEncoding", + "NeuronQwen35VisionModel", + # Vision-language + "NeuronQwen35VLForCausalLM", + "Qwen35VLInferenceConfig", +] diff --git a/contrib/models/Qwen3.5-27B/src/modeling_qwen35.py b/contrib/models/Qwen3.5-27B/src/modeling_qwen35.py new file mode 100644 index 00000000..86e6fb4c --- /dev/null +++ b/contrib/models/Qwen3.5-27B/src/modeling_qwen35.py @@ -0,0 +1,2493 @@ +""" +NxDI contrib: Qwen3.5-27B (qwen3_5 -- dense model) + +Hybrid DeltaNet + Standard Attention + Dense MLP architecture. +Adapted from Qwen3.5-35B-A3B (MoE) -- MoE removed, dense MLP added. 
+ +48 of 64 layers use Gated DeltaNet (linear recurrent attention) +16 of 64 layers use standard GQA with KV cache + output gate +All 64 layers use a dense SwiGLU MLP (intermediate_size=17408) + +Architecture details: +- DeltaNet layers: separate in_proj_{qkv, z, a, b}, causal conv1d on QKV, gated delta rule +- Attention layers: q_proj doubled (Q + gate), partial RoPE (25% of head_dim), sigmoid output gate +- Dense MLP: standard SwiGLU (gate_proj, up_proj, down_proj) -- no MoE, no router, no experts +- KV cache: NxDI KVCacheManager for attention layers; DeltaNet layers store recurrent+conv + state as nn.Parameter buffers and return dummy KV tuples +""" + +import gc +import math +import logging +import os +import sys +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.utils import cpu_mode + +try: + from nki import jit as nki_jit # NKI 0.3.0+ (SDK 2.29) +except ImportError: + from torch_neuronx.xla_impl.ops import nki_jit # NKI 0.2.x (SDK 2.28) +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm + +from src.nki_kernels.nki_deltanet import deltanet_recurrent_fwd as _deltanet_nki_kernel +from src.nki_kernels.nki_deltanet import ( + deltanet_recurrent_fwd_state as _deltanet_nki_kernel_state, +) +from src.nki_kernels.nki_deltanet_chunked import ( + deltanet_chunk_step as _deltanet_nki_chunk_step, +) +from 
src.nki_kernels.nki_deltanet_fused import ( + deltanet_fused_chunked_fwd as _deltanet_fused_kernel, +) +from src.nki_kernels.nki_deltanet_fused import ( + _make_lower_mask, + _make_lower_mask_diag, + _make_identity, +) + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + NeuronConfig, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, + DecoderModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) + +logger = logging.getLogger(__name__) + +_flash_fwd_call = nki_jit()(attention_isa_kernel) + +# Option B: Direct nkilib flash attention for head_dim > 128 +USE_NKILIB_KERNEL = os.environ.get("USE_NKILIB_KERNEL", "0") == "1" + +_nkilib_flash_attn = None +if USE_NKILIB_KERNEL: + try: + import neuronxcc.nki as _nki + from neuronx_distributed_inference.modules.attention.attention_base import ( + peel_decorations as _peel_decorations, + get_platform_target as _get_platform_target, + ) + from neuronxcc.nki.compiler import ( + skip_middle_end_transformations as _skip_middle_end, + enable_stack_allocator as _enable_stack_allocator, + ) + + import importlib + + _fork_path = "/home/ubuntu/nki-library-fork/nkilib_src" + if os.path.isdir(_fork_path) and _fork_path not in sys.path: + sys.path.insert(0, _fork_path) + _to_remove = [k for k in sys.modules if k.startswith("nkilib")] + for k in _to_remove: + del sys.modules[k] + import nki.language as _stub_nl + import neuronxcc.nki.language as _real_nl + + for _attr in [ + "NKIObject", + "float8_e4m3fn", + "float8_e4m3fn_x4", + "float8_e5m2_x4", + "float4_e2m1fn_x4", + ]: + if not hasattr(_real_nl, _attr) and hasattr(_stub_nl, _attr): + 
setattr(_real_nl, _attr, getattr(_stub_nl, _attr)) + from nkilib.core.attention.attention_cte import ( + attention_cte as _attention_cte_raw, + _MAX_HEAD_DIM, + ) + + assert _MAX_HEAD_DIM == 256, ( + f"nkilib fork has _MAX_HEAD_DIM={_MAX_HEAD_DIM}, expected 256. " + f"System nkilib may have been loaded instead of fork." + ) + logger.info( + f"Loaded nkilib attention_cte from fork (_MAX_HEAD_DIM={_MAX_HEAD_DIM})" + ) + + _raw_fn = _peel_decorations(_attention_cte_raw) + _platform = _get_platform_target() + _nkilib_flash_attn = _nki.jit( + _raw_fn, + mode="torchxla", + platform_target=_platform, + show_compiler_tb=True, + debug_kernel=True, + ) + _nkilib_flash_attn = _skip_middle_end(_nkilib_flash_attn) + _nkilib_flash_attn = _enable_stack_allocator( + _nkilib_flash_attn, log_level=logging.INFO + ) + logger.info("Option B: nkilib flash attention loaded for head_dim > 128") + except Exception as e: + logger.warning(f"Option B: Failed to load nkilib flash attention: {e}") + import traceback as _tb + + _tb.print_exc() + _nkilib_flash_attn = None + +# Option A: Detect if patch_attn_kernel was imported +NKILIB_PATCH_ACTIVE = False +try: + from importlib import import_module as _import_module + + _attn_mod = _import_module("neuronxcc.nki._pre_prod_kernels.attn_fwd") + if hasattr(_attn_mod, "_original_attention_nki_kernel_adapter"): + NKILIB_PATCH_ACTIVE = True + logger.info("Option A detected: _pre_prod_kernels patched with nkilib kernel") +except Exception: + pass + + +# ============================================================ +# Newton-Raphson Refined RMSNorm +# ============================================================ +USE_NEWTON_RMSNORM = os.environ.get("USE_NEWTON_RMSNORM") == "1" +USE_PYTHON_RMSNORM = os.environ.get("USE_PYTHON_RMSNORM") == "1" + + +class NewtonRMSNorm(nn.Module): + """RMSNorm with Newton-Raphson refined rsqrt for improved numerical accuracy.""" + + def __init__(self, hidden_size=None, eps=1e-6): + super().__init__() + self.weight = None + if 
hidden_size is not None: + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.hidden_size = hidden_size + self.variance_epsilon = eps + + def forward(self, hidden_states): + original_dtype = hidden_states.dtype + x = hidden_states.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + y = torch.rsqrt(variance + self.variance_epsilon) + y = y * (3.0 - (variance + self.variance_epsilon) * y * y) * 0.5 + result = x * y + if self.weight is not None: + result = result * self.weight.float() + return result.to(original_dtype) + + +def get_rmsnorm_cls(): + if cpu_mode() or USE_PYTHON_RMSNORM: + return Qwen3MoeRMSNorm + return NewtonRMSNorm if USE_NEWTON_RMSNORM else CustomRMSNorm + + +def l2norm(x, dim=-1, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + + +# ============================================================ +# Gated DeltaNet Module (Linear Recurrent Attention) +# ============================================================ + + +class NeuronGatedDeltaNet(nn.Module): + """ + Gated DeltaNet linear attention for Neuron. + + Replaces standard attention for 48 of 64 layers in Qwen3.5-27B. + Uses a chunk-based linear recurrence instead of KV cache. 
+ + HF weight layout (27B dense -- scaled dimensions): + - in_proj_qkv.weight: (key_dim*2 + value_dim, hidden_size) = (10240, 5120) + - in_proj_z.weight: (value_dim, hidden_size) = (6144, 5120) + - in_proj_a.weight: (num_v_heads, hidden_size) = (48, 5120) + - in_proj_b.weight: (num_v_heads, hidden_size) = (48, 5120) + - conv1d.weight: (conv_dim, 1, conv_kernel_size) = (10240, 1, 4) + - A_log: (num_v_heads,) = (48,) + - dt_bias: (num_v_heads,) = (48,) + - norm.weight: (head_v_dim,) = (128,) + - out_proj.weight: (hidden_size, value_dim) = (5120, 6144) + """ + + def __init__(self, config, layer_idx: int): + super().__init__() + tc = config + + self.hidden_size = tc.hidden_size # 5120 + self.num_v_heads = tc.linear_num_value_heads # 48 + self.num_k_heads = tc.linear_num_key_heads # 16 + self.head_k_dim = tc.linear_key_head_dim # 128 + self.head_v_dim = tc.linear_value_head_dim # 128 + self.key_dim = self.head_k_dim * self.num_k_heads # 2048 + self.value_dim = self.head_v_dim * self.num_v_heads # 6144 + self.conv_kernel_size = tc.linear_conv_kernel_dim # 4 + self.layer_idx = layer_idx + self.rms_norm_eps = tc.rms_norm_eps + + # KV cache dummy shape info + self.head_dim = tc.head_dim # 256 + tp_degree = tc.neuron_config.tp_degree + raw_kv_heads = tc.num_key_value_heads + if raw_kv_heads < tp_degree: + replicated_kv_heads = tp_degree + else: + replicated_kv_heads = raw_kv_heads + self.kv_heads_per_rank = replicated_kv_heads // tp_degree + + # Conv1d on concatenated QKV (NOT Z) + self.conv_dim = self.key_dim * 2 + self.value_dim # 10240 + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # Input projections (nn.Linear — NOT sharded by NxDI TP, replicated on all ranks) + self.in_proj_qkv = nn.Linear( + self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False + ) + self.in_proj_z = nn.Linear(self.hidden_size, 
self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + # Decay parameters + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + self.A_log = nn.Parameter(torch.zeros(self.num_v_heads)) + + # Output norm and projection + self.norm = Qwen3MoeRMSNorm(self.head_v_dim, eps=self.rms_norm_eps) + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + # State buffers for CTE -> TKG carry-over + alloc_batch_size = getattr(config.neuron_config, "max_batch_size", 1) + self._phase_batch_size = getattr(config.neuron_config, "batch_size", 1) + self.recurrent_state_buffer = nn.Parameter( + torch.zeros( + alloc_batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=config.neuron_config.torch_dtype, + ), + requires_grad=False, + ) + self.conv_state_buffer = nn.Parameter( + torch.zeros( + alloc_batch_size, + self.conv_dim, + self.conv_kernel_size - 1, + dtype=config.neuron_config.torch_dtype, + ), + requires_grad=False, + ) + + def _recurrent_step(self, query, key, value, g, beta, recurrent_state): + """Single-step recurrent update for token generation.""" + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + q_t = query[:, :, 0] + k_t = key[:, :, 0] + v_t = value[:, :, 0] + g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, 0].unsqueeze(-1) + + new_state = recurrent_state * g_t + kv_mem = (new_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + new_state = new_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + output = (new_state * q_t.unsqueeze(-1)).sum(dim=-2) + + return output.unsqueeze(2), new_state + + def _nki_recurrent_forward(self, query, key, value, g, beta): + """Full-sequence recurrent forward using NKI kernel for context encoding.""" + query = l2norm(query, dim=-1) + key = 
l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + BH = B * H + query_flat = query.reshape(BH, S, k_dim).contiguous() + key_flat = key.reshape(BH, S, k_dim).contiguous() + value_flat = value.reshape(BH, S, v_dim).contiguous() + + g_flat = g.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + beta_flat = beta.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + + outputs = [] + states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_nki_kernel_state( + query_flat[bh], + key_flat[bh], + value_flat[bh], + g_flat[bh], + beta_flat[bh], + ) + outputs.append(out_bh) + states.append(state_bh) + + output = torch.stack(outputs, dim=0) + output = output.reshape(B, H, S, v_dim) + + final_state = torch.stack(states, dim=0) + final_state = final_state.reshape(B, H, k_dim, v_dim) + + return output, final_state + + def _nki_chunked_forward( + self, query, key, value, g, beta, output_final_state=False + ): + """Chunked NKI kernel forward for context encoding (prefill).""" + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + num_chunks = total_seq_len // chunk_size + g_reshaped = g.reshape(B, H, num_chunks, chunk_size) + g_cs = g_reshaped.cumsum(dim=-1) + g_last_per_chunk = g_cs[:, :, :, -1:] + g_last_expanded = g_last_per_chunk.expand(-1, -1, -1, chunk_size) + + query_chunks = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key_chunks = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value_chunks = value.reshape(B, H, 
num_chunks, chunk_size, v_dim) + + beta_chunks = ( + beta.reshape(B, H, num_chunks, chunk_size) + .unsqueeze(-1) + .expand(-1, -1, -1, -1, v_dim) + ) + gc_chunks = g_cs.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + gl_chunks = g_last_expanded.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + + BH = B * H + query_chunks = query_chunks.reshape( + BH, num_chunks, chunk_size, k_dim + ).contiguous() + key_chunks = key_chunks.reshape(BH, num_chunks, chunk_size, k_dim).contiguous() + value_chunks = value_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + beta_chunks = beta_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + gc_chunks = gc_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + gl_chunks = gl_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + + device = query.device + lower_mask = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=-1, + ) + identity_mat = torch.eye(chunk_size, dtype=torch.float32, device=device) + lower_mask_diag = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=0, + ) + + all_outputs = [] + all_states = [] + for bh in range(BH): + state = torch.zeros(k_dim, v_dim, dtype=torch.float32, device=device) + + head_chunks = [] + for c_idx in range(num_chunks): + q_chunk = query_chunks[bh, c_idx].contiguous() + k_chunk = key_chunks[bh, c_idx].contiguous() + v_chunk = value_chunks[bh, c_idx].contiguous() + beta_chunk = beta_chunks[bh, c_idx].contiguous() + gc_chunk = gc_chunks[bh, c_idx].contiguous() + gl_chunk = gl_chunks[bh, c_idx].contiguous() + + out_chunk, state = _deltanet_nki_chunk_step( + q_chunk, + k_chunk, + v_chunk, + beta_chunk, + gc_chunk, + gl_chunk, + state, + lower_mask, + identity_mat, + lower_mask_diag, + ) + head_chunks.append(out_chunk) + + head_output = torch.cat(head_chunks, dim=0) + all_outputs.append(head_output) + all_states.append(state) + + output = 
torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _fused_chunked_forward( + self, query, key, value, g, beta, output_final_state=False + ): + """Fused single-kernel chunked forward for CTE — SSD-style. + + Processes all chunks in a single NKI kernel call per (B,H) pair. + State persists in SBUF across chunks (no HBM round-trips). + Cumsum of g computed in-kernel via tensor_tensor_scan. + + This is the optimized version of _nki_chunked_forward with: + 1. Single kernel call per (B,H) instead of B*H*num_chunks + 2. State in SBUF across all chunks (biggest perf win) + 3. In-kernel cumsum (avoids PyTorch cumsum overhead) + 4. tensor_scalar for broadcasts (no explicit loops) + """ + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + # Pad sequence to multiple of chunk_size + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + BH = B * H + # Flatten to (BH, S, dim) for per-(b,h) kernel calls + query_flat = query.reshape(BH, total_seq_len, k_dim).contiguous() + key_flat = key.reshape(BH, total_seq_len, k_dim).contiguous() + value_flat = value.reshape(BH, total_seq_len, v_dim).contiguous() + + # g and beta: (BH, S) -> (BH, S, 1) for the kernel's (S, 1) input layout + g_flat = g.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + beta_flat = beta.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + + 
# Create constant mask tensors (shared across all B*H calls) + device = query.device + lower_mask = torch.tensor( + _make_lower_mask(), dtype=torch.float32, device=device + ) + identity_mat = torch.tensor( + _make_identity(), dtype=torch.float32, device=device + ) + lower_mask_diag = torch.tensor( + _make_lower_mask_diag(), dtype=torch.float32, device=device + ) + + all_outputs = [] + all_states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_fused_kernel( + query_flat[bh], # (S, 128) + key_flat[bh], # (S, 128) + value_flat[bh], # (S, 128) + g_flat[bh], # (S, 1) — RAW g, not cumsum + beta_flat[bh], # (S, 1) — sigmoid(b) + lower_mask, # (128, 128) + identity_mat, # (128, 128) + lower_mask_diag, # (128, 128) + ) + all_outputs.append(out_bh) + all_states.append(state_bh) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _sequential_forward(self, query, key, value, g, beta, output_final_state=False): + """Sequential full-sequence gated delta rule for CTE. + + Uses the same per-step recurrence as _recurrent_step but loops over the + full sequence. Avoids the slice-assignment loop in _chunk_forward that + may compile incorrectly on Neuron/XLA. 
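The per-step recurrence this method loops over can be sketched in isolation. A minimal NumPy illustration of one gated delta-rule step for a single head (function name and scalar `g`/`beta` arguments are illustrative, not the kernel implementation):

```python
import numpy as np

def gated_delta_step(state, q, k, v, g, beta):
    # state: (K, V) recurrent memory; q, k: (K,); v: (V,); g, beta: scalars.
    state = state * np.exp(g)            # gated decay of the memory
    kv_mem = state.T @ k                 # (V,) what the memory recalls for k
    delta = (v - kv_mem) * beta          # delta-rule correction toward v
    state = state + np.outer(k, delta)   # rank-1 memory update
    out = state.T @ q                    # (V,) readout for the query
    return out, state
```

With `g = 0` and `beta = 1`, one step on an empty state stores `v` exactly at key `k`, so querying with `q = k` (unit norm) reads `v` back.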
+ """ + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + state = query.new_zeros(B, H, k_dim, v_dim) + all_outputs = [] + for t in range(S): + q_t = query[:, :, t] # (B, H, K) + k_t = key[:, :, t] # (B, H, K) + v_t = value[:, :, t] # (B, H, V) + beta_t = beta[:, :, t].unsqueeze(-1) # (B, H, 1) + g_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1) # (B, H, 1, 1) + + # Gated delta rule + state = state * g_t + kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + delta = (v_t - kv_mem) * beta_t # (B, H, V) + state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) # (B, H, K, V) + + o_t = (state * q_t.unsqueeze(-1)).sum(dim=-2) # (B, H, V) + all_outputs.append(o_t.unsqueeze(2)) + + output = torch.cat(all_outputs, dim=2) # (B, H, S, V) + final_state = state if output_final_state else None + return output, final_state + + def _chunk_forward(self, query, key, value, g, beta, output_final_state=False): + """Chunk-based forward for context encoding (prefill).""" + chunk_size = 64 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + num_chunks = total_seq_len // chunk_size + query = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value = value.reshape(B, H, num_chunks, chunk_size, v_dim) + k_beta = k_beta.reshape(B, H, num_chunks, chunk_size, k_dim) + v_beta = v_beta.reshape(B, H, 
num_chunks, chunk_size, v_dim) + g = g.reshape(B, H, num_chunks, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().tril() + + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_recurrent_state = torch.zeros( + B, H, k_dim, v_dim, dtype=query.dtype, device=query.device + ) + core_attn_out = torch.zeros_like(value) + mask2 = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + for i in range(num_chunks): + q_i = query[:, :, i] + k_i = key[:, :, i] + v_i = value[:, :, i] + + attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_( + mask2, 0 + ) + + v_prime = k_cumdecay[:, :, i] @ last_recurrent_state + v_new = v_i - v_prime + + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn_i @ v_new + + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + ( + k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None] + ).transpose(-1, -2) + @ v_new + ) + + core_attn_out = core_attn_out.reshape(B, H, -1, v_dim) + core_attn_out = core_attn_out[:, :, :S] + + if not output_final_state: + last_recurrent_state = None + + return core_attn_out, last_recurrent_state + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + **kwargs, + ): + """Forward pass compatible with NxDI decoder layer interface.""" + batch_size, seq_len, _ = 
hidden_states.shape + + seq_ids = kwargs.get("seq_ids", None) + is_decode = past_key_value is not None + + # Padding mask for DeltaNet: [B, S, 1] with 1.0 for real tokens, 0.0 for padding. + # Passed from get_model_output where it's computed from input_ids != pad_token_id. + # Embeddings are already zeroed for padding tokens; this mask additionally + # zeros the decay gate so the recurrent state is preserved unchanged + # through padding positions (no spurious decay). + valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None + + # Project inputs + deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" + if deltanet_fp32: + hs_f32 = hidden_states.float() + qkv = F.linear(hs_f32, self.in_proj_qkv.weight.float()).to( + hidden_states.dtype + ) + z = F.linear(hs_f32, self.in_proj_z.weight.float()).to(hidden_states.dtype) + b = F.linear(hs_f32, self.in_proj_b.weight.float()).to(hidden_states.dtype) + a = F.linear(hs_f32, self.in_proj_a.weight.float()).to(hidden_states.dtype) + else: + qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states) + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + # Split QKV + query = qkv[..., : self.key_dim] + key = qkv[..., self.key_dim : self.key_dim * 2] + value = qkv[..., self.key_dim * 2 :] + + # Causal Conv1d on QKV + mixed = torch.cat([query, key, value], dim=-1) + mixed = mixed.transpose(1, 2) + + if is_decode: + if seq_ids is not None: + conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) + else: + conv_state = self.conv_state_buffer[:batch_size] + conv_input = torch.cat([conv_state, mixed], dim=-1) + + w = self.conv1d.weight.squeeze(1) + conv_out = torch.zeros_like(mixed) + for k in range(4): + conv_out = ( + conv_out + + w[:, k].unsqueeze(0).unsqueeze(-1) * conv_input[:, :, k : k + 1] + ) + mixed_post_conv = F.silu(conv_out) + + new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) + alloc_bs = self.conv_state_buffer.shape[0] + if seq_ids is 
not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_conv_state = ( + new_conv_state.to(self.conv_state_buffer.dtype) + + self.conv_state_buffer * 0 + ) + elif batch_size < alloc_bs: + pad_size = alloc_bs - batch_size + new_conv_state = torch.cat( + [ + new_conv_state, + self.conv_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + else: + mixed_post_conv = F.silu(self.conv1d(mixed)[:, :, :seq_len]) + + if valid_mask_1d is not None: + # valid_mask_1d is [B, S, 1]; count valid tokens per batch + num_valid = ( + valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long() + ) # [B, 1] + idx_base = num_valid - 3 + idx_base = idx_base.clamp(min=0) + offsets = torch.arange(3, device=mixed.device).unsqueeze(0) + gather_idx = idx_base + offsets # [B, 3] + gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1) + new_conv_state = torch.gather(mixed, 2, gather_idx) + else: + new_conv_state = mixed[:, :, -3:].contiguous() + + alloc_bs = self.conv_state_buffer.shape[0] + if seq_ids is not None: + # BS=1 optimization: scatter to index 0 = direct replacement + new_conv_state = ( + new_conv_state.to(self.conv_state_buffer.dtype) + + self.conv_state_buffer * 0 + ) + elif batch_size < alloc_bs: + pad_size = alloc_bs - batch_size + new_conv_state = torch.cat( + [ + new_conv_state, + torch.zeros( + pad_size, + self.conv_dim, + self.conv_kernel_size - 1, + dtype=new_conv_state.dtype, + device=new_conv_state.device, + ), + ], + dim=0, + ) + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + else: + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + + mixed_post_conv = mixed_post_conv.transpose(1, 2) + + # Zero out conv1d output for padding positions. + # Conv1d with kernel_size=4 leaks real token info into the first + # few padding positions. 
Zeroing here ensures Q, K, V are exactly + # zero for all padding positions so the recurrence is unaffected. + if valid_mask_1d is not None: + mixed_post_conv = ( + mixed_post_conv * valid_mask_1d + ) # [B, S, conv_dim] * [B, S, 1] + + query = mixed_post_conv[..., : self.key_dim] + key = mixed_post_conv[..., self.key_dim : self.key_dim * 2] + value = mixed_post_conv[..., self.key_dim * 2 :] + + # Reshape to heads + query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + + # Compute gating + beta = b.sigmoid() + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + + if valid_mask_1d is not None: + # Zero g for padding → alpha=exp(0)=1 → state preserved through padding + # Zero beta for padding → no state update from padding tokens + mask_2d = valid_mask_1d.squeeze(-1).float() # [B, S] + g = g * mask_2d.unsqueeze(-1) + beta = beta * mask_2d.unsqueeze(-1) + + # Expand K heads to match V heads (16 -> 48) using expand+reshape + if self.num_v_heads // self.num_k_heads > 1: + rep = self.num_v_heads // self.num_k_heads # 3 + query = ( + query.unsqueeze(3) + .expand(-1, -1, -1, rep, -1) + .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim) + ) + key = ( + key.unsqueeze(3) + .expand(-1, -1, -1, rep, -1) + .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim) + ) + + # Transpose to (B, H, S, dim) + query = query.transpose(1, 2).contiguous().float() + key = key.transpose(1, 2).contiguous().float() + value = value.transpose(1, 2).contiguous().float() + g = g.transpose(1, 2).contiguous().float() + beta = beta.transpose(1, 2).contiguous().float() + + if is_decode: + # TKG: single-step recurrent update + if seq_ids is not None: + recurrent_state = torch.index_select( + self.recurrent_state_buffer, 0, seq_ids + ).float() + else: + recurrent_state = 
self.recurrent_state_buffer[:batch_size].float() + + output, new_state = self._recurrent_step( + query, key, value, g, beta, recurrent_state + ) + new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + new_state_bf16, + self.recurrent_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 + else: + # CTE: fused, chunk, NKI, or sequential forward + use_nki_fused = os.environ.get("USE_NKI_FUSED") == "1" + use_nki_chunked = os.environ.get("USE_NKI_CHUNKED") == "1" + use_nki = os.environ.get("USE_NKI") == "1" + use_sequential = os.environ.get("DELTANET_SEQUENTIAL") == "1" + + if use_nki_fused: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_chunked: + output, final_state = self._nki_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki: + output, final_state = self._nki_recurrent_forward( + query, key, value, g, beta + ) + elif use_sequential: + output, final_state = self._sequential_forward( + query, key, value, g, beta, output_final_state=True + ) + else: + output, final_state = self._chunk_forward( + query, key, value, g, beta, output_final_state=True + ) + + if final_state is not None: + final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + elif batch_size < 
alloc_bs: + new_rec_state = torch.cat( + [ + final_state_bf16, + torch.zeros( + alloc_bs - batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=final_state_bf16.dtype, + device=final_state_bf16.device, + ), + ], + dim=0, + ) + new_rec_state = new_rec_state + self.recurrent_state_buffer * 0 + else: + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + else: + new_rec_state = self.recurrent_state_buffer * 1 + + # Output: norm, gate, project + output = output.to(hidden_states.dtype) + output = output.transpose(1, 2).contiguous() + output = output.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = self.norm(output) + z_gate = z.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = output * F.silu(z_gate) + output = output.reshape(batch_size, seq_len, self.value_dim) + output = self.out_proj(output) + + # Return dummy KV for KVCacheManager + dummy_k = torch.zeros( + batch_size, + self.kv_heads_per_rank, + seq_len, + self.head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + dummy_v = torch.zeros_like(dummy_k) + + return output, (dummy_k, dummy_v), new_rec_state, new_conv_state + + +# ============================================================ +# InferenceConfig (Dense -- no MoE) +# ============================================================ + + +class Qwen35InferenceConfig(InferenceConfig): + """Config for Qwen3.5-27B (dense) with hybrid DeltaNet + Attention.""" + + def __init__(self, *args, **kwargs): + # Set defaults BEFORE super().__init__() because it calls validate_config() + # which checks get_required_attributes(). These can be overridden by + # kwargs or load_config. 
+ + # Layer types for hybrid dispatch: [3 DeltaNet + 1 GQA] x 16 = 64 layers + if "layer_types" not in kwargs and not any( + hasattr(a, "layer_types") for a in args if hasattr(a, "__dict__") + ): + layer_types = [] + for _ in range(16): + layer_types.extend( + [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] + ) + kwargs.setdefault("layer_types", layer_types) + + # DeltaNet-specific config defaults + kwargs.setdefault("linear_num_value_heads", 48) + kwargs.setdefault("linear_num_key_heads", 16) + kwargs.setdefault("linear_key_head_dim", 128) + kwargs.setdefault("linear_value_head_dim", 128) + kwargs.setdefault("linear_conv_kernel_dim", 4) + + super().__init__(*args, **kwargs) + + # Attention output gate + self.attn_output_gate = getattr(self, "attn_output_gate", True) + + # Partial RoPE + self.partial_rotary_factor = getattr(self, "partial_rotary_factor", 0.25) + self.rope_dim = int(self.head_dim * self.partial_rotary_factor) # 64 + + # mRoPE (multimodal RoPE) for VL support + rope_params = getattr(self, "rope_parameters", {}) or {} + self.mrope_section = rope_params.get("mrope_section", [11, 11, 10]) + self.mrope_interleaved = rope_params.get("mrope_interleaved", True) + + # Standard HF config attributes expected by NxDI + if not hasattr(self, "output_attentions"): + self.output_attentions = False + if not hasattr(self, "output_hidden_states"): + self.output_hidden_states = False + + def get_required_attributes(self) -> List[str]: + return [ + "head_dim", + "hidden_act", + "hidden_size", + "intermediate_size", + "max_position_embeddings", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "rms_norm_eps", + "rope_theta", + "vocab_size", + # DeltaNet-specific + "linear_num_value_heads", + "linear_num_key_heads", + "linear_key_head_dim", + "linear_value_head_dim", + "linear_conv_kernel_dim", + "layer_types", + ] + + @classmethod + def get_neuron_config_cls(cls): + return NeuronConfig + + +# 
============================================================ +# Attention (standard GQA for 16 of 64 layers) +# With output gate: q_proj is 2x sized, split into (query, gate) +# With partial RoPE: only first rope_dim dimensions get rotary +# ============================================================ + + +class Qwen35MRoPEEmbedding(nn.Module): + """Multimodal Rotary Position Embedding (mRoPE) for Qwen3.5. + + Handles 3D position information (temporal, height, width) for VL models. + Position IDs have shape (3, batch_size, seq_len) for T/H/W dimensions. + For text-only (2D position_ids), broadcasts to 3D with identical positions. + """ + + def __init__(self, config): + super().__init__() + self.head_dim = config.head_dim # 256 + self.rope_dim = config.rope_dim # 64 + self.mrope_section = config.mrope_section # [11, 11, 10] + self.mrope_interleaved = getattr(config, "mrope_interleaved", True) + self.rope_theta = config.rope_theta + + # Validate mrope_section sums to rope_dim // 2 = 32 + assert sum(self.mrope_section) == self.rope_dim // 2, ( + f"mrope_section {self.mrope_section} sums to {sum(self.mrope_section)}, " + f"expected {self.rope_dim // 2}" + ) + + def forward(self, x, position_ids_3d): + """Compute cos/sin from 3D position IDs. 
+
+        Args:
+            x: hidden_states (for device/dtype inference)
+            position_ids_3d: (3, batch_size, seq_len) -- T, H, W positions
+
+        Returns:
+            cos: (batch_size, seq_len, rope_dim)
+            sin: (batch_size, seq_len, rope_dim)
+        """
+        device = x.device
+        dtype = torch.float32
+
+        sections = self.mrope_section  # [11, 11, 10]
+        cos_parts = []
+        sin_parts = []
+
+        # Each axis consumes its own contiguous slice of the global RoPE
+        # frequency table, so track the running exponent offset across
+        # sections (the HF convention: one inv_freq of length rope_dim // 2
+        # split into mrope_section chunks).
+        freq_offset = 0
+        for axis_idx, section_size in enumerate(sections):
+            pos = position_ids_3d[axis_idx].float()  # (batch, seq_len)
+
+            dim_pairs = section_size  # number of (cos, sin) pairs for this axis
+            freqs = 1.0 / (
+                self.rope_theta
+                ** (
+                    torch.arange(
+                        freq_offset,
+                        freq_offset + dim_pairs * 2,
+                        2,
+                        dtype=dtype,
+                        device=device,
+                    )
+                    / self.rope_dim
+                )
+            )  # (dim_pairs,)
+            freq_offset += dim_pairs * 2
+
+            # freqs: (dim_pairs,), pos: (B, S) -> angles: (B, S, dim_pairs)
+            angles = pos.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0)
+
+            cos_parts.append(angles.cos())
+            sin_parts.append(angles.sin())
+
+        # Concatenate: (B, S, 32)
+        cos = torch.cat(cos_parts, dim=-1)
+        sin = torch.cat(sin_parts, dim=-1)
+
+        if self.mrope_interleaved:
+            # Interleave to (B, S, 64): [c0, c0, c1, c1, ...] for rotate_half
+            cos = cos.repeat_interleave(2, dim=-1)
+            sin = sin.repeat_interleave(2, dim=-1)
+        else:
+            cos = torch.cat([cos, cos], dim=-1)
+            sin = torch.cat([sin, sin], dim=-1)
+
+        return cos, sin
+
+
+class NeuronQwen35Attention(NeuronAttentionBase):
+    """Standard GQA attention for Qwen3.5 with output gate and partial RoPE.
+
+    24 Q heads, 4 KV heads (6:1 GQA), head_dim=256 for 27B dense.
+    q_proj is doubled (query + gate), split at load time.
+    Only first rope_dim=64 of head_dim=256 gets rotary encoding.
+
+    Uses NeuronAttentionBase infrastructure for QKV projection, KV cache,
+    RoPE, and attention computation. Overrides forward() to insert the
+    sigmoid output gate between attention output and o_proj.
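The partial-RoPE split described above can be shown with a small NumPy sketch (shape-level only; the half-split `rotate_half` pairing is an assumption matching the usual non-interleaved convention, and the function name is illustrative):

```python
import numpy as np

def partial_rope(x, cos, sin, rope_dim):
    # x: (S, head_dim); cos, sin: (S, rope_dim).
    # Rotate only the first rope_dim dims; pass the rest through unchanged.
    x_rot, x_pass = x[:, :rope_dim], x[:, rope_dim:]
    x1, x2 = np.split(x_rot, 2, axis=-1)
    rotated = np.concatenate([-x2, x1], axis=-1)  # rotate_half
    x_rot = x_rot * cos + rotated * sin
    return np.concatenate([x_rot, x_pass], axis=-1)
```

For head_dim=256 and rope_dim=64, the trailing 192 dimensions are returned untouched regardless of position.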
+ """ + + def __init__(self, config): + # Partial RoPE: create mRoPE embedding with rope_dim (64) + self.rope_dim = config.rope_dim # 64 = head_dim * partial_rotary_factor + + # Create QK norm modules (will be passed to base class) + rms_norm_eps = config.rms_norm_eps + q_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + k_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + + # Partial RoPE: use standard RotaryEmbedding. + # For VL with 3D mRoPE positions, cos/sin are pre-computed externally in + # get_model_output() using Qwen35MRoPEEmbedding and passed as cos_cache/sin_cache. + rotary_emb = RotaryEmbedding( + self.rope_dim, # Only 64 dims get rotary embedding + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rms_norm_eps=rms_norm_eps, + use_qk_norm=False, + q_layernorm=q_ln, + k_layernorm=k_ln, + ) + + # Separate mRoPE module for VL 3D position_ids + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + # Output gate projection: hidden_size -> num_heads * head_dim + # Populated from the second half of q_proj during state dict conversion. + self.output_gate_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * config.head_dim, + bias=False, + gather_output=False, + ) + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + """Partial RoPE: only apply rotary embedding to first rope_dim dimensions. 
+ + Q shape: (B, H, S, head_dim) where head_dim=256 + cos/sin shape: (B, S, rope_dim) where rope_dim=64 (from RotaryEmbedding(dim=64)) + + Split Q/K along last dim into: + q_rope (first 64 dims) -- apply RoPE + q_pass (remaining 192 dims) -- pass through unchanged + """ + from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + ) + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + # Split into rope and pass-through portions + Q_orig_dtype = Q.dtype + q_rope = Q[..., : self.rope_dim] # (B, H, S, 64) + q_pass = Q[..., self.rope_dim :] # (B, H, S, 192) + k_rope = K[..., : self.rope_dim] + k_pass = K[..., self.rope_dim :] + + # Apply RoPE only to the rope portion + q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos_cache, sin_cache) + + # Concatenate back (ensure bf16 is maintained) + Q = torch.cat([q_rope, q_pass], dim=-1).to(Q_orig_dtype) + K = torch.cat([k_rope, k_pass], dim=-1).to(Q_orig_dtype) + + return Q, K, cos_cache, sin_cache + + def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask=None): + """Prefill path with NKI flash attention for head_dim=256.""" + head_dim = Q.shape[-1] + + # Option B: nkilib flash attention for head_dim > 128 + if _nkilib_flash_attn is not None: + q_contig = Q.contiguous() + k_contig = K.contiguous() + v_contig = V.contiguous() + scale = 1.0 / math.sqrt(head_dim) + result = _nkilib_flash_attn( + q_contig, k_contig, v_contig, scale=scale, use_causal_mask=True + ) + return result, None + + # Option A: kernel patched globally + if NKILIB_PATCH_ACTIVE: + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + # Fallback: softmax path + if head_dim > 128: + # GQA: expand K/V heads to match Q heads + num_q_heads = Q.shape[1] + num_kv_heads = K.shape[1] + if num_q_heads != num_kv_heads: + kv_rep = num_q_heads // num_kv_heads + K = ( + K.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + 
.reshape(bsz, num_q_heads, q_len, head_dim) + ) + V = ( + V.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + attn_weights = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(head_dim) + if attention_mask is not None: + if attention_mask.dtype == torch.bool: + attn_weights = attn_weights.masked_fill(~attention_mask, -65504.0) + else: + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + Q.dtype + ) + return torch.matmul(attn_weights, V), None + + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + def forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + adapter_ids=None, + active_mask=None, + **kwargs, + ): + """Forward with output gate applied BEFORE o_proj. + + Override NeuronAttentionBase.forward() to insert the sigmoid gate + between the attention output and o_proj, matching the HF reference: + gate = sigmoid(gate_proj(pre_attn_hidden)) + attn_output = attn_output * gate + attn_output = o_proj(attn_output) + """ + bsz, q_len, _ = hidden_states.shape + + # Use standard 2D position_ids for prep_qkv_tensors. 
+ rope_pos_ids = position_ids + + # Compute gate from input hidden states (before QKV projection) + gate = self.output_gate_proj(hidden_states) # (B, S, num_heads * head_dim) + + # Standard QKV prep (projections, QK norm, RoPE) + Q, K, V, cos_cache, sin_cache, _residual = self.prep_qkv_tensors( + rope_pos_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + ) + + if past_key_value is None: + # Context encoding (prefill) + attn_output, _flash_strategy = self.perform_prefill( + Q, K, V, q_len, bsz, attention_mask + ) + else: + # Token generation (decode) + tkg_mask = attention_mask + if tkg_mask is not None and tkg_mask.ndim == 2: + tkg_mask = tkg_mask.unsqueeze(1).unsqueeze(2) # (B, S) -> (B, 1, 1, S) + attn_output = self.compute_for_token_gen( + Q, K, V, position_ids, past_key_value, tkg_mask, active_mask + ) + + # attn_output is (B, H, S, head_dim) -- transpose to (B, S, H*head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # Apply sigmoid output gate BEFORE o_proj (matching HF reference) + attn_output = attn_output * torch.sigmoid(gate) + + # Apply o_proj + attn_output = self.get_o_proj()(attn_output, adapter_ids=adapter_ids) + + # Ensure K, V are in model dtype (bf16) for KV cache update + # (prevents mixed-precision dynamic-update-slice in neuronx-cc) + K = K.to(self.torch_dtype) + V = V.to(self.torch_dtype) + past_key_value = (K, V) + return attn_output, past_key_value, cos_cache, sin_cache + + +# ============================================================ +# Dense MLP (replaces MoE) +# ============================================================ + + +class Qwen35MLP(nn.Module): + """Dense SwiGLU MLP for Qwen3.5-27B. 
+ + gate_proj: hidden_size -> intermediate_size (5120 -> 17408) + up_proj: hidden_size -> intermediate_size (5120 -> 17408) + down_proj: intermediate_size -> hidden_size (17408 -> 5120) + + output = down_proj(silu(gate_proj(x)) * up_proj(x)) + """ + + def __init__(self, config): + super().__init__() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + ) + + def forward(self, hidden_states): + gate = self.gate_proj(hidden_states) + up = self.up_proj(hidden_states) + hidden_states = F.silu(gate) * up + hidden_states = self.down_proj(hidden_states) + return hidden_states + + +# ============================================================ +# Decoder Layer (hybrid dispatch -- DeltaNet or GQA + Dense MLP) +# ============================================================ + + +class NeuronQwen35DecoderLayer(nn.Module): + """Hybrid decoder layer: dispatches to DeltaNet or standard attention. + Uses dense MLP for all layers (no MoE). 
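The dense SwiGLU MLP applied in every layer (Qwen35MLP above) reduces to `down_proj(silu(gate_proj(x)) * up_proj(x))`; a minimal NumPy sketch with illustrative weight names, ignoring tensor parallelism:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_mlp(x, w_gate, w_up, w_down):
    # x: (B, hidden); w_gate, w_up: (inter, hidden); w_down: (hidden, inter).
    # output = down_proj(silu(gate_proj(x)) * up_proj(x))
    return (silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T
```

Since `silu(0) = 0`, a zero gate projection zeroes the whole block's output, which is why the gate path, not the up path, controls activation sparsity.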
+ """ + + def __init__(self, config: Qwen35InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + self.layer_idx = layer_idx + self.config = config + + # Attention (DeltaNet or standard GQA) + if self.layer_type == "linear_attention": + self.linear_attn = NeuronGatedDeltaNet(config, layer_idx) + else: + self.self_attn = NeuronQwen35Attention(config=config) + + # Dense MLP (all layers) + self.mlp = Qwen35MLP(config) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + padding_mask=None, + cos_cache=None, + sin_cache=None, + **kwargs, + ): + residual = hidden_states + + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + # DeltaNet path + attn_out, dummy_kv, new_rec_state, new_conv_state = self.linear_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + **kwargs, + ) + hidden_states = residual + attn_out + present_key_value = dummy_kv + deltanet_states = (new_rec_state, new_conv_state) + else: + deltanet_states = None + # Standard attention path + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Dense MLP FFN + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + 
+ hidden_states = ModuleMarkerEndWrapper()(hidden_states) + outputs = ( + hidden_states, + present_key_value, + cos_cache, + sin_cache, + None, + deltanet_states, + ) + return outputs + + +# ============================================================ +# Model +# ============================================================ + + +class NeuronQwen35Model(NeuronBaseModel): + def setup_attr_for_model(self, config: Qwen35InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: Qwen35InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = nn.ModuleList( + [ + NeuronQwen35DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=False if self.on_device_sampling else True, + bias=False, + ) + + # mRoPE embedding for VL + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + @property + def _deltanet_state_params(self): + """Return DeltaNet state nn.Parameters in alias order.""" + params = [] + for layer in self.layers: + if hasattr(layer, "linear_attn"): + params.append(layer.linear_attn.recurrent_state_buffer) + params.append(layer.linear_attn.conv_state_buffer) + return params + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, 
vision_mask): + """Scatter vision embeddings into text input embeddings at image token positions.""" + _, max_positions, embedding_dim = inputs_embeds.shape + h_new = inputs_embeds.clone() + vision_flat = vision_embeddings.view(-1, embedding_dim) + positions_flat = vision_mask.view(-1) + h_new.view(-1, embedding_dim).index_put_( + (positions_flat,), vision_flat, accumulate=False + ) + return h_new + + def get_model_output( + self, + input_ids=None, + seq_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + active_mask=None, + inputs_embeds=None, + prev_hidden=None, + adapter_ids=None, + rotary_position_ids=None, + update_cache=False, + is_for_context_encoding=False, + vision_embeddings=None, + vision_mask=None, + local_attn_mask=None, + windowed_context_encoding_window_idx=-1, + padding_mask=None, + **kwargs, + ): + """Override to collect DeltaNet state tensors from decoder layers.""" + batch_size, seq_length = input_ids.shape[:2] + if self.config.neuron_config.layer_boundary_markers: + input_ids = ModuleMarkerStartWrapper()(input_ids) + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][1].shape[2] + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # CRITICAL: Zero out embeddings for padding tokens so DeltaNet recurrence + # is not polluted. DeltaNet has no attention mask -- it processes all + # sequence positions through a linear recurrence. Padding tokens have + # real embedding vectors which corrupt the recurrence state. + # The mask is [B, S, 1] float with 1.0 for real tokens, 0.0 for padding. 
+ deltanet_padding_mask = ( + (input_ids != self.padding_idx).unsqueeze(-1).to(inputs_embeds.dtype) + ) + if is_for_context_encoding: + inputs_embeds = inputs_embeds * deltanet_padding_mask + + # Vision embedding injection + if (vision_embeddings is not None) and (vision_mask is not None): + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to( + self.config.neuron_config.torch_dtype + ) + if is_for_context_encoding: + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + hidden_states = inputs_embeds + + # Get KV cache for TKG + cache_size = self.n_positions + if not is_for_context_encoding: + if self.kv_mgr is not None: + past_key_values = self.kv_mgr.get_cache( + seq_ids=seq_ids, + seq_len=cache_size, + is_for_context_encoding=is_for_context_encoding, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + # Decoder layers + next_decoder_cache = () + deltanet_state_tensors = [] + cos_cache = None + sin_cache = None + + # Convert 2D attention_mask to 4D causal mask for CTE + if ( + attention_mask is not None + and attention_mask.ndim == 2 + and is_for_context_encoding + ): + causal = torch.ones( + (seq_length, seq_length), + dtype=torch.bool, + device=attention_mask.device, + ).tril() + padding_4d = attention_mask[:, None, None, :].to(torch.bool) + attention_mask = (causal[None, None, :, :] & padding_4d).to( + attention_mask.dtype + ) + + # Pre-compute mRoPE cos/sin + if rotary_position_ids is not None and rotary_position_ids.ndim == 3: + cos_cache, 
sin_cache = self.mrope_emb(inputs_embeds, rotary_position_ids) + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + active_mask=active_mask, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rotary_position_ids=rotary_position_ids, + kv_mgr=self.kv_mgr, + get_kv_per_layer=False, + update_kv_per_layer=False, + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_len=cache_size, + residual=None, + local_mask=local_attn_mask, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + padding_mask=padding_mask, + deltanet_padding_mask=deltanet_padding_mask, + **kwargs, + ) + + hidden_states = layer_outputs[0] + kv = layer_outputs[1] + next_decoder_cache += (kv,) + cos_cache, sin_cache = layer_outputs[2:4] + + # Collect DeltaNet state tensors + deltanet_states = layer_outputs[5] if len(layer_outputs) > 5 else None + if deltanet_states is not None: + deltanet_state_tensors.append(deltanet_states[0]) + deltanet_state_tensors.append(deltanet_states[1]) + + # Update KV cache + if update_cache: + next_decoder_cache = self.kv_mgr.update_cache( + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + new_key_values=next_decoder_cache, + seq_len=cache_size, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + self._deltanet_updated_states = deltanet_state_tensors + + return (hidden_states, next_decoder_cache) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden=None, + adapter_ids=None, + accepted_indices=None, + current_length=None, + medusa_mask=None, + scatter_index=None, 
+ slot_mapping=None, + active_block_table=None, + num_queries=None, + computed_context_lens=None, + tile_q_indices=None, + tile_block_tables=None, + tile_masks=None, + inputs_embeds=None, + kv_cache=None, + active_mask=None, + rotary_position_id=None, + vision_embeddings=None, + vision_mask=None, + ): + """Override base forward to append DeltaNet state tensors to output.""" + prev_hidden = self.set_none_if_empty(prev_hidden) + adapter_ids = self.set_none_if_empty(adapter_ids) + accepted_indices = self.set_none_if_empty(accepted_indices) + current_length = self.set_none_if_empty(current_length) + medusa_mask = self.set_none_if_empty(medusa_mask) + scatter_index = self.set_none_if_empty(scatter_index) + slot_mapping = self.set_none_if_empty(slot_mapping) + active_block_table = self.set_none_if_empty(active_block_table) + num_queries = self.set_none_if_empty(num_queries) + computed_context_lens = self.set_none_if_empty(computed_context_lens) + tile_q_indices = self.set_none_if_empty(tile_q_indices) + tile_block_tables = self.set_none_if_empty(tile_block_tables) + tile_masks = self.set_none_if_empty(tile_masks) + inputs_embeds = self.set_none_if_empty(inputs_embeds) + kv_cache = self.set_none_if_empty(kv_cache) + active_mask = self.set_none_if_empty(active_mask) + rotary_position_id = self.set_none_if_empty(rotary_position_id) + vision_embeddings = self.set_none_if_empty(vision_embeddings) + vision_mask = self.set_none_if_empty(vision_mask) + + is_for_context_encoding = position_ids.shape[-1] != 1 and not ( + hasattr(self.neuron_config, "speculation_length") + and position_ids.shape[-1] == self.neuron_config.speculation_length + ) + + seq_ids = seq_ids.to(torch.int32) + attn_mask = attention_mask + + hidden_states, updated_kv_cache = self.get_model_output( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attn_mask, + position_ids=position_ids, + active_mask=active_mask, + inputs_embeds=inputs_embeds, + adapter_ids=adapter_ids, + 
rotary_position_ids=rotary_position_id, + update_cache=True, + is_for_context_encoding=is_for_context_encoding, + padding_mask=None, + active_block_table=active_block_table, + scatter_index=slot_mapping + if getattr(self, "is_block_kv_layout", False) + else scatter_index, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + + batch_size = input_ids.shape[0] + if not getattr(self, "sliced_hidden", False): + if not is_for_context_encoding: + pass + else: + index = torch.max(position_ids, dim=1, keepdim=True).indices + index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + + logits = self.lm_head(hidden_states) + logits = logits.float() + + if hasattr(self.lm_head, "pad_size"): + if self.lm_head.gather_output: + rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32) + world_size = 1 + else: + from neuronx_distributed.parallel_layers import parallel_state + + rank_id = self.rank_util.get_rank() + world_size = torch.distributed.get_world_size( + group=self.lm_head.tensor_parallel_group + ) + from neuronx_distributed_inference.models.model_base import ( + mask_padded_logits, + ) + + logits = mask_padded_logits( + logits, rank_id, world_size, pad_size=self.lm_head.pad_size + ) + + if self.on_device_sampling: + res = self._sample_on_device( + logits, sampling_params, False, is_for_context_encoding + ) + else: + res = logits + + outputs = [res] + if self.neuron_config.output_logits: + outputs += [logits] + outputs += updated_kv_cache + + # Append DeltaNet state tensors (for input_output_aliases) + if hasattr(self, "_deltanet_updated_states"): + outputs += self._deltanet_updated_states + + return outputs + + +# ============================================================ +# State Dict Converter (Dense -- no MoE weight handling) +# ============================================================ + + +def convert_qwen35_hf_to_neuron_state_dict(neuron_state_dict, config): 
+ """Convert HF Qwen3.5-27B weights to NxDI format. + + Weight mappings per layer type: + + DeltaNet layers (linear_attention): + HF: layers.X.linear_attn.{in_proj_qkv, in_proj_z, in_proj_a, in_proj_b, + conv1d, A_log, dt_bias, norm, out_proj} + NxDI: same names (no remapping needed) + + Full attention layers: + HF: layers.X.self_attn.q_proj.weight: (12288, 5120) -- doubled for gate + NxDI: layers.X.self_attn.Wqkv.weight (fused Q+K+V, gate separated) + layers.X.self_attn.output_gate_proj.weight (gate part) + HF: layers.X.self_attn.{k_proj, v_proj, o_proj, q_norm, k_norm} + NxDI: layers.X.self_attn.{..., q_layernorm, k_layernorm} + + Dense MLP (all layers): + HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight (same names) + """ + # Add rank_util + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # CRITICAL: Convert (1+weight) RMSNorm weights to standard RMSNorm weights. + # Qwen3.5 uses RMSNorm with `output = norm(x) * (1 + weight)` where weight + # is initialized to zeros. Standard NxDI RMSNorm uses `output = norm(x) * weight` + # where weight is initialized to ones. To convert: new_weight = old_weight + 1.0 + norm_keys_to_convert = [] + for l in range(config.num_hidden_layers): + norm_keys_to_convert.append(f"layers.{l}.input_layernorm.weight") + norm_keys_to_convert.append(f"layers.{l}.post_attention_layernorm.weight") + if config.layer_types[l] == "full_attention": + norm_keys_to_convert.append(f"layers.{l}.self_attn.q_norm.weight") + norm_keys_to_convert.append(f"layers.{l}.self_attn.k_norm.weight") + norm_keys_to_convert.append("norm.weight") + + for nk in norm_keys_to_convert: + if nk in neuron_state_dict: + old_val = neuron_state_dict[nk] + neuron_state_dict[nk] = old_val.float() + 1.0 + if "layers.0." 
in nk or nk == "norm.weight": + logger.debug( + f"[NORM FIX] {nk}: mean {old_val.float().mean():.4f} -> {neuron_state_dict[nk].mean():.4f}" + ) + else: + if "layers.0." in nk or nk == "norm.weight": + logger.warning(f"[NORM FIX] key not found: {nk}") + + for l in range(config.num_hidden_layers): + layer_type = config.layer_types[l] + + # === Attention layers === + if layer_type == "full_attention": + neuron_state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm + q_norm_key = f"layers.{l}.self_attn.q_norm.weight" + k_norm_key = f"layers.{l}.self_attn.k_norm.weight" + if q_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.q_layernorm.weight"] = ( + neuron_state_dict.pop(q_norm_key).detach().clone() + ) + if k_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.k_layernorm.weight"] = ( + neuron_state_dict.pop(k_norm_key).detach().clone() + ) + + # q_proj is doubled: (12288, 5120) = (num_heads * head_dim * 2, hidden) + # INTERLEAVED: [head0_query(256) | head0_gate(256) | head1_query(256) | ...] 
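The interleaved query/gate layout described above can be checked on a toy weight; the head count and dimensions below are scaled down from the real 24 heads x 256 dims purely for illustration:

```python
import torch

num_heads, head_dim, hidden = 2, 4, 8

# Toy doubled q_proj: rows 0..head_dim-1 of each 2*head_dim block are the
# query slice, rows head_dim..2*head_dim-1 are the gate slice.
w = torch.arange(num_heads * head_dim * 2 * hidden, dtype=torch.float32)
w = w.reshape(num_heads * head_dim * 2, hidden)

# Same split as the conversion code: reshape per head, slice, flatten back.
w3 = w.reshape(num_heads, head_dim * 2, hidden)
query_w = w3[:, :head_dim, :].reshape(num_heads * head_dim, hidden)
gate_w = w3[:, head_dim:, :].reshape(num_heads * head_dim, hidden)
```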
+ q_proj_key = f"layers.{l}.self_attn.q_proj.weight" + if q_proj_key in neuron_state_dict: + q_proj_w = neuron_state_dict.pop(q_proj_key) + num_heads = config.num_attention_heads # 24 + head_dim = config.head_dim # 256 + q_proj_w = q_proj_w.reshape(num_heads, head_dim * 2, config.hidden_size) + query_w = q_proj_w[:, :head_dim, :] # (24, 256, 5120) + gate_w = q_proj_w[:, head_dim:, :] # (24, 256, 5120) + query_w = query_w.reshape( + num_heads * head_dim, config.hidden_size + ) # (6144, 5120) + gate_w = gate_w.reshape( + num_heads * head_dim, config.hidden_size + ) # (6144, 5120) + + neuron_state_dict[q_proj_key] = query_w + neuron_state_dict[f"layers.{l}.self_attn.output_gate_proj.weight"] = ( + gate_w + ) + + # Fuse QKV + if config.neuron_config.fused_qkv: + q_key = f"layers.{l}.self_attn.q_proj.weight" + k_key = f"layers.{l}.self_attn.k_proj.weight" + v_key = f"layers.{l}.self_attn.v_proj.weight" + if q_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( + [ + neuron_state_dict[q_key], + neuron_state_dict[k_key], + neuron_state_dict[v_key], + ] + ) + del neuron_state_dict[q_key] + del neuron_state_dict[k_key] + del neuron_state_dict[v_key] + + # Dense MLP: no weight conversion needed -- HF and NxDI use same names + # HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + # NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + + gc.collect() + + return neuron_state_dict + + +# ============================================================ +# Custom ModelWrapper and DecoderModelInstance for DeltaNet state aliasing +# ============================================================ + + +class Qwen35DecoderModelInstance(DecoderModelInstance): + """Custom DecoderModelInstance that adds DeltaNet state buffers to input_output_aliases.""" + + def get(self, bucket_rank, **kwargs): + """Override to add DeltaNet state aliases after KV cache aliases.""" + module, input_output_aliases = super().get(bucket_rank, **kwargs) + + 
num_output_from_trace = 1 if not self.neuron_config.output_logits else 2 + + if module.kv_mgr is not None: + num_kv = len(module.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start_idx = num_output_from_trace + num_kv + + if hasattr(module, "_deltanet_state_params"): + for i, param in enumerate(module._deltanet_state_params): + input_output_aliases[param] = state_start_idx + i + + return module, input_output_aliases + + +class Qwen35ModelWrapper(ModelWrapper): + """Custom ModelWrapper for VL support with mRoPE and vision inputs.""" + + def get_model_instance(self): + return Qwen35DecoderModelInstance( + model_cls=self.model_cls, + config=self.config, + **self.model_init_kwargs, + ) + + def input_generator(self): + """Generate inputs including mrope_position_ids, vision_embeddings, and vision_mask.""" + base_inputs = super().input_generator() + extended_inputs = [] + + for bucket_inputs in base_inputs: + input_ids = bucket_inputs[0] + batch_size = input_ids.shape[0] + n_active_tokens = input_ids.shape[1] + + is_cte = n_active_tokens > 1 + + if is_cte: + mrope_position_ids = ( + torch.arange(0, n_active_tokens, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + vision_embeddings = torch.zeros( + (batch_size, n_active_tokens, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, n_active_tokens, 1), + fill_value=n_active_tokens - 1, + dtype=torch.int32, + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + vision_embeddings = torch.zeros( + (0,), dtype=self.config.neuron_config.torch_dtype + ) + vision_mask = torch.zeros((0,), dtype=torch.int32) + + padded = list(bucket_inputs) + while len(padded) < 21: + padded.append(torch.zeros((0,), dtype=torch.int32)) + padded.append(mrope_position_ids) # position 21 + padded.append(vision_embeddings) # position 22 + padded.append(vision_mask) # position 23 + + 
extended_inputs.append(tuple(padded)) + + return extended_inputs + + def pad_inputs(self, *args, pad_type="first_fit"): + """Override to pad mrope_position_ids and vision inputs to bucket size.""" + orig_mrope = args[21] if len(args) >= 22 else None + orig_vis_emb = args[22] if len(args) >= 23 else None + orig_vis_mask = args[23] if len(args) >= 24 else None + + padded_args = super().pad_inputs(*args, pad_type=pad_type) + + if len(padded_args) >= 24 and orig_mrope is not None: + padded_seq_len = padded_args[0].shape[1] + batch_size = padded_args[0].shape[0] + is_cte = padded_seq_len > 1 + + if is_cte: + current_mrope = orig_mrope + current_vis_emb = orig_vis_emb + current_vis_mask = orig_vis_mask + + if ( + current_mrope.ndim == 3 + and current_mrope.shape[-1] != padded_seq_len + ): + orig_len = current_mrope.shape[-1] + pad_size = padded_seq_len - orig_len + last_pos = current_mrope[:, :, -1:] + pad_offsets = torch.arange( + 1, pad_size + 1, dtype=current_mrope.dtype + ) + pad_offsets = ( + pad_offsets.unsqueeze(0).unsqueeze(0).expand(3, batch_size, -1) + ) + mrope_pad = last_pos + pad_offsets + mrope_position_ids = torch.cat([current_mrope, mrope_pad], dim=-1) + elif current_mrope.ndim == 3: + mrope_position_ids = current_mrope + else: + mrope_position_ids = ( + torch.arange(0, padded_seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + if ( + current_vis_emb is not None + and current_vis_emb.ndim == 3 + and current_vis_emb.shape[1] < padded_seq_len + ): + pad_emb = torch.zeros( + ( + batch_size, + padded_seq_len - current_vis_emb.shape[1], + current_vis_emb.shape[2], + ), + dtype=current_vis_emb.dtype, + ) + vision_embeddings = torch.cat([current_vis_emb, pad_emb], dim=1) + elif current_vis_emb is not None and current_vis_emb.ndim == 3: + vision_embeddings = current_vis_emb[:, :padded_seq_len] + else: + vision_embeddings = torch.zeros( + (batch_size, padded_seq_len, self.config.hidden_size), + 
dtype=self.config.neuron_config.torch_dtype, + ) + + if ( + current_vis_mask is not None + and current_vis_mask.ndim == 3 + and current_vis_mask.shape[1] < padded_seq_len + ): + pad_mask = torch.full( + (batch_size, padded_seq_len - current_vis_mask.shape[1], 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + vision_mask = torch.cat([current_vis_mask, pad_mask], dim=1) + elif current_vis_mask is not None and current_vis_mask.ndim == 3: + vision_mask = current_vis_mask[:, :padded_seq_len] + else: + vision_mask = torch.full( + (batch_size, padded_seq_len, 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + + padded_args = ( + *padded_args[:21], + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + + padded_args = list(padded_args) + padded_args[23] = padded_args[23].clamp(max=padded_seq_len - 1) + padded_args = tuple(padded_args) + + return padded_args + + +# ============================================================ +# Top-Level Model +# ============================================================ + + +class NeuronQwen35ForCausalLM(NeuronBaseForCausalLM): + _model_cls = NeuronQwen35Model + + def get_model_wrapper_cls(self): + """Return custom ModelWrapper with DeltaNet state aliasing.""" + return Qwen35ModelWrapper + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load HF model weights. + + The model is a VL model (Qwen3_5ForConditionalGeneration) but we + only need the text backbone. 
+ """ + from transformers import AutoModelForCausalLM + + kwargs.setdefault("trust_remote_code", True) + return AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + + @classmethod + def get_config_cls(cls): + return Qwen35InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, config): + """Strip VL wrapper prefix and convert to NxDI format.""" + new_sd = {} + for k, v in state_dict.items(): + if k.startswith("language_model."): + new_k = k.replace("language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.language_model."): + new_k = k.replace("model.language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.visual") or k.startswith("visual"): + continue # Skip vision encoder + elif k.startswith("model."): + new_sd[k.replace("model.", "", 1)] = v + elif k.startswith("mtp."): + continue # Skip MTP + elif k.startswith("lm_head."): + new_sd[k] = v + else: + new_sd[k] = v + + return convert_qwen35_hf_to_neuron_state_dict(new_sd, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def _copy_past_key_values(self, outputs): + """Override to also copy DeltaNet state buffers on CPU.""" + super()._copy_past_key_values(outputs) + + num_output_from_trace = 1 + if ( + self.neuron_config.output_logits + and self.neuron_config.on_device_sampling_config + ): + num_output_from_trace = 2 + + if ( + hasattr(self, "token_generation_model") + and self.token_generation_model is not None + ): + tkg_model = self.token_generation_model.model + cte_model = self.context_encoding_model.model + else: + return + + if tkg_model.kv_mgr is not None: + num_kv = len(tkg_model.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start = num_output_from_trace + num_kv + + tkg_params = getattr(tkg_model, "_deltanet_state_params", []) + 
cte_params = getattr(cte_model, "_deltanet_state_params", []) + + if len(tkg_params) > 0 and state_start + len(tkg_params) <= len(outputs): + for i, (tkg_param, cte_param) in enumerate(zip(tkg_params, cte_params)): + new_state = outputs[state_start + i] + tkg_param.data = new_state + cte_param.data = new_state + + def get_required_kwargs(self): + """Return extra kwargs for HF generation loop.""" + return ["llava_args"] + + def _get_model_outputs( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + medusa_args, + llava_args, + slot_mapping=None, + block_table=None, + full_context_lens=None, + computed_context_lens=None, + tf_args=None, + ): + """Override to pass all 24 positional args explicitly.""" + is_prefill = self._is_prefill(position_ids) + + seq_len = input_ids.shape[1] + batch_size = input_ids.shape[0] + + if llava_args and len(llava_args) >= 2: + vision_embeddings = llava_args[0] + vision_mask = llava_args[1] + if len(llava_args) >= 3: + mrope_position_ids = llava_args[2] + else: + mrope_position_ids = None + elif is_prefill: + vision_embeddings = torch.zeros( + (batch_size, seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, seq_len, 1), + fill_value=seq_len - 1, + dtype=torch.int32, + ) + mrope_position_ids = None + else: + vision_embeddings = torch.zeros((0,), dtype=torch.float32) + vision_mask = torch.zeros((0,), dtype=torch.int32) + mrope_position_ids = None + + if is_prefill: + if mrope_position_ids is None: + mrope_position_ids = ( + torch.arange(0, seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + + empties = [torch.empty(0) for _ in range(14)] + + if self._is_prefill(position_ids): + ctx_bs = self.context_encoding_model.neuron_config.batch_size + output_logits = [] + + for cb in 
range(0, batch_size, ctx_bs): + cb_end = min(cb + ctx_bs, batch_size) + actual_chunk = cb_end - cb + + chunk_input_ids = input_ids[cb:cb_end] + chunk_attn_mask = attention_mask[cb:cb_end] + chunk_pos_ids = position_ids[cb:cb_end] + chunk_seq_ids = seq_ids[cb:cb_end] + chunk_sampling = sampling_params[cb:cb_end] + chunk_prev_hidden = ( + prev_hidden[cb:cb_end] + if prev_hidden is not None + and hasattr(prev_hidden, "ndim") + and prev_hidden.ndim > 0 + and prev_hidden.shape[0] > 0 + else prev_hidden + ) + chunk_adapter_ids = ( + adapter_ids[cb:cb_end] + if adapter_ids is not None + and hasattr(adapter_ids, "ndim") + and adapter_ids.ndim > 0 + and adapter_ids.shape[0] > 0 + else adapter_ids + ) + + if mrope_position_ids.ndim == 3: + chunk_mrope = mrope_position_ids[:, cb:cb_end, :] + else: + chunk_mrope = mrope_position_ids + + if vision_embeddings.ndim == 3: + chunk_vis_emb = vision_embeddings[cb:cb_end] + chunk_vis_mask = vision_mask[cb:cb_end] + else: + chunk_vis_emb = vision_embeddings + chunk_vis_mask = vision_mask + + if actual_chunk < ctx_bs: + pad_n = ctx_bs - actual_chunk + chunk_input_ids = torch.cat( + [chunk_input_ids, chunk_input_ids[:1].expand(pad_n, -1)], dim=0 + ) + chunk_attn_mask = torch.cat( + [chunk_attn_mask, chunk_attn_mask[:1].expand(pad_n, -1)], dim=0 + ) + chunk_pos_ids = torch.cat( + [chunk_pos_ids, chunk_pos_ids[:1].expand(pad_n, -1)], dim=0 + ) + pad_seq = torch.arange( + batch_size, batch_size + pad_n, dtype=chunk_seq_ids.dtype + ) + chunk_seq_ids = torch.cat([chunk_seq_ids, pad_seq], dim=0) + chunk_sampling = torch.cat( + [chunk_sampling, chunk_sampling[:1].expand(pad_n, -1)], dim=0 + ) + if ( + chunk_prev_hidden is not None + and hasattr(chunk_prev_hidden, "ndim") + and chunk_prev_hidden.ndim > 0 + and chunk_prev_hidden.shape[0] > 0 + ): + chunk_prev_hidden = torch.cat( + [ + chunk_prev_hidden, + chunk_prev_hidden[:1].expand(pad_n, -1), + ], + dim=0, + ) + if ( + chunk_adapter_ids is not None + and hasattr(chunk_adapter_ids, "ndim") + 
and chunk_adapter_ids.ndim > 0 + and chunk_adapter_ids.shape[0] > 0 + ): + chunk_adapter_ids = torch.cat( + [ + chunk_adapter_ids, + chunk_adapter_ids[:1].expand(pad_n, -1), + ], + dim=0, + ) + if chunk_mrope.ndim == 3: + chunk_mrope = torch.cat( + [chunk_mrope, chunk_mrope[:, :1, :].expand(-1, pad_n, -1)], + dim=1, + ) + if chunk_vis_emb.ndim == 3: + chunk_vis_emb = torch.cat( + [ + chunk_vis_emb, + torch.zeros( + (pad_n,) + chunk_vis_emb.shape[1:], + dtype=chunk_vis_emb.dtype, + ), + ], + dim=0, + ) + chunk_vis_mask = torch.cat( + [ + chunk_vis_mask, + torch.full( + (pad_n,) + chunk_vis_mask.shape[1:], + fill_value=seq_len - 1, + dtype=chunk_vis_mask.dtype, + ), + ], + dim=0, + ) + + chunk_out = self.context_encoding_model( + chunk_input_ids, + chunk_attn_mask, + chunk_pos_ids, + chunk_seq_ids, + chunk_sampling, + chunk_prev_hidden, + chunk_adapter_ids, + *empties, + chunk_mrope, + chunk_vis_emb, + chunk_vis_mask, + ) + if actual_chunk < ctx_bs: + chunk_out = chunk_out[:actual_chunk] + output_logits.append(chunk_out) + + outputs = ( + torch.cat(output_logits, dim=0) + if len(output_logits) > 1 + else output_logits[0] + ) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + outputs = self.token_generation_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + *empties, + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + def get_compiler_args(self): + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + f"--model-type transformer {optimization_level} " + "--auto-cast=none " + ) + return compiler_args diff --git a/contrib/models/Qwen3.5-27B/src/modeling_qwen35_vision.py 
b/contrib/models/Qwen3.5-27B/src/modeling_qwen35_vision.py new file mode 100644 index 00000000..08557cc8 --- /dev/null +++ b/contrib/models/Qwen3.5-27B/src/modeling_qwen35_vision.py @@ -0,0 +1,818 @@ +""" +Qwen3.5-27B (Dense) Vision Encoder for NeuronX Distributed Inference. + +Ports the Qwen3.5 ViT encoder to run on Neuron. The vision encoder +architecture is identical to the MoE variant (same patch embed, same rotary, +same merger) -- only out_hidden_size changes (5120 vs 2048, read from config). + +The vision encoder runs as a separate compiled model from the text decoder, +compiled and loaded via NeuronBaseForImageToText. +""" + +import logging +import math +import os +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# CRITICAL: Use finite negative value instead of -inf for Neuron attention masks. +# The Neuron compiler's bfloat16 handling of -inf produces NaN that bleeds from +# padding positions into ALL positions through the transformer layers. +# -65504.0 is large enough for softmax masking but avoids NaN overflow. 
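The generic hazard the finite mask value guards against can be reproduced on CPU: an additive `-inf` mask turns a fully-masked softmax row into `0/0 = NaN`, while a large finite negative value degrades gracefully. This sketch shows only the generic `-inf`-in-softmax failure; the Neuron-specific bfloat16 NaN propagation described above is the motivating case:

```python
import torch

scores = torch.zeros(1, 4)  # one attention row with every position masked

# -inf mask: softmax's internal max-subtraction yields -inf - (-inf) = NaN.
nan_row = torch.softmax(scores + float("-inf"), dim=-1)

# Finite mask: all logits equal, softmax degrades to a uniform distribution.
safe_row = torch.softmax(scores + (-65504.0), dim=-1)
```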
+_MASK_NEG_INF = -65504.0 + +logger = logging.getLogger(__name__) + +# -- NxDI imports (available on Neuron instances) -- +try: + from neuronx_distributed_inference.models.application_base import ( + NeuronApplicationBase, + ) + from neuronx_distributed_inference.models.model_wrapper import ModelWrapper + from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, + ) + from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding + from neuronx_distributed.parallel_layers import layers as nxd_layers +except ImportError: + logger.warning( + "NxDI imports unavailable -- vision module can only be used on Neuron instances" + ) + +# -- HuggingFace imports for patch embed (runs on CPU) -- +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeVisionPatchEmbed, + Qwen3_5MoeVisionPatchMerger, + Qwen3_5MoeVisionRotaryEmbedding, + ) +except ImportError: + try: + # transformers 4.57+ uses Qwen3VL* class names + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLVisionPatchEmbed as Qwen3_5MoeVisionPatchEmbed, + Qwen3VLVisionPatchMerger as Qwen3_5MoeVisionPatchMerger, + Qwen3VLVisionRotaryEmbedding as Qwen3_5MoeVisionRotaryEmbedding, + ) + except ImportError: + try: + # Older transformers uses Qwen2VL* class names + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLVisionPatchEmbed as Qwen3_5MoeVisionPatchEmbed, + Qwen2VLVisionPatchMerger as Qwen3_5MoeVisionPatchMerger, + Qwen2VLVisionRotaryEmbedding as Qwen3_5MoeVisionRotaryEmbedding, + ) + except ImportError: + Qwen3_5MoeVisionPatchEmbed = None + Qwen3_5MoeVisionPatchMerger = None + Qwen3_5MoeVisionRotaryEmbedding = None + + +def apply_rotary_pos_emb_vision(q, k, cos, sin): + """Apply rotary position embeddings to vision Q and K tensors. 
+ + Uses rotate_half style (matching HF reference): + q_embed = (q * cos) + (rotate_half(q) * sin) + + Args: + q: (seq_len, num_heads, head_dim) + k: (seq_len, num_heads, head_dim) + cos: (seq_len, head_dim) + sin: (seq_len, head_dim) + """ + cos = cos.unsqueeze(-2) # (seq_len, 1, head_dim) + sin = sin.unsqueeze(-2) + + def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class NeuronQwen35VisionAttention(nn.Module): + """Vision attention for Qwen3.5 MoE. + + Uses fused QKV linear (no bias in Neuron port for efficiency). + Non-causal attention with block-diagonal mask for variable-length images. + """ + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.hidden_size // self.num_heads + self.scaling = self.head_dim**-0.5 + + # Fused QKV: (hidden_size -> 3 * hidden_size) with bias + self.qkv = nxd_layers.ColumnParallelLinear( + self.hidden_size, + 3 * self.hidden_size, + bias=True, + gather_output=True, + ) + self.proj = nxd_layers.RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + input_is_parallel=False, + ) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + """ + Args: + hidden_states: (seq_len, hidden_size) + attention_mask: (1, 1, seq_len, seq_len) block-diagonal mask + position_embeddings: (cos, sin) tuple + """ + seq_len = hidden_states.shape[0] + + # QKV projection + qkv = self.qkv(hidden_states) # (seq_len, 3 * hidden_size) + qkv = qkv.reshape(seq_len, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(1, 0, 2, 3) # (3, seq_len, num_heads, head_dim) + q, k, v = qkv.unbind(0) # each (seq_len, num_heads, head_dim) + + # Apply rotary embeddings + if position_embeddings is not None: + 
cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + # Reshape for batched attention: (1, num_heads, seq_len, head_dim) + q = q.transpose(0, 1).unsqueeze(0) + k = k.transpose(0, 1).unsqueeze(0) + v = v.transpose(0, 1).unsqueeze(0) + + # Scaled dot-product attention + attn_weights = torch.matmul(q, k.transpose(-1, -2)) * self.scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + # Reshape back: (seq_len, hidden_size) + attn_output = attn_output.squeeze(0).transpose(0, 1).reshape(seq_len, -1) + + # Output projection + attn_output = self.proj(attn_output) + return attn_output + + +class NeuronQwen35VisionMLP(nn.Module): + """Vision MLP with GELU activation.""" + + def __init__(self, config): + super().__init__() + self.linear_fc1 = nxd_layers.ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + gather_output=True, + ) + self.linear_fc2 = nxd_layers.RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + input_is_parallel=False, + ) + self.act_fn = nn.GELU() + + def forward(self, hidden_states): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_states))) + + +class NeuronQwen35VisionBlock(nn.Module): + """Single vision transformer block: LayerNorm + Attention + LayerNorm + MLP.""" + + def __init__(self, config): + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = NeuronQwen35VisionAttention(config) + self.mlp = NeuronQwen35VisionMLP(config) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + hidden_states = 
hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class NeuronQwen35VisionModel(nn.Module): + """Qwen3.5 MoE Vision Encoder for Neuron. + + This is the nn.Module that gets compiled and traced onto Neuron. + Patch embedding, positional embedding, and rotary embedding are computed + on CPU in the ModelWrapper and passed as inputs. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.blocks = nn.ModuleList( + [NeuronQwen35VisionBlock(config) for _ in range(config.depth)] + ) + # Merger: spatial_merge_size^2 * hidden_size -> out_hidden_size + self.merger_norm = nn.LayerNorm(config.hidden_size, eps=1e-6) + merger_hidden = config.hidden_size * (config.spatial_merge_size**2) + self.merger_fc1 = nn.Linear(merger_hidden, merger_hidden) + self.merger_act = nn.GELU() + self.merger_fc2 = nn.Linear(merger_hidden, config.out_hidden_size) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + """ + Args: + hidden_states: (seq_len, hidden_size) -- after patch_embed + pos_embed + attention_mask: (1, 1, seq_len, seq_len) block-diagonal mask + position_embeddings: (cos, sin) tuple for rotary + + Returns: + vision_embeddings: (merged_seq_len, out_hidden_size) + """ + for block in self.blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + + # Apply merger: norm -> spatial merge -> fc1 -> gelu -> fc2 + hidden_states = self.merger_norm(hidden_states) + merge_size = self.config.spatial_merge_size + merged_hidden = self.config.hidden_size * (merge_size**2) + hidden_states = hidden_states.view(-1, merged_hidden) + hidden_states = self.merger_fc2(self.merger_act(self.merger_fc1(hidden_states))) + + return hidden_states + + +class CPUVisionModel(nn.Module): + """CPU-only vision encoder (pure PyTorch, no Neuron dependencies). 
+ + Used when HBM is insufficient to load the vision encoder on Neuron + alongside the text decoder (e.g., 27B dense model on trn2.3xlarge). + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.blocks = nn.ModuleList( + [self._make_block(config) for _ in range(config.depth)] + ) + self.merger_norm = nn.LayerNorm(config.hidden_size, eps=1e-6) + merger_hidden = config.hidden_size * (config.spatial_merge_size**2) + self.merger_fc1 = nn.Linear(merger_hidden, merger_hidden) + self.merger_act = nn.GELU() + self.merger_fc2 = nn.Linear(merger_hidden, config.out_hidden_size) + + @staticmethod + def _make_block(config): + """Build a single vision block with standard nn.Linear (no TP).""" + block = nn.Module() + block.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + block.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + + # Attention + attn = nn.Module() + attn.hidden_size = config.hidden_size + attn.num_heads = config.num_heads + attn.head_dim = config.hidden_size // config.num_heads + attn.scaling = attn.head_dim**-0.5 + attn.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True) + attn.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True) + block.attn = attn + + # MLP + mlp = nn.Module() + mlp.linear_fc1 = nn.Linear( + config.hidden_size, config.intermediate_size, bias=True + ) + mlp.linear_fc2 = nn.Linear( + config.intermediate_size, config.hidden_size, bias=True + ) + mlp.act_fn = nn.GELU() + block.mlp = mlp + + return block + + def _forward_attention(self, attn, hidden_states, attention_mask, cos, sin): + seq_len = hidden_states.shape[0] + qkv = attn.qkv(hidden_states).reshape(seq_len, 3, attn.num_heads, attn.head_dim) + qkv = qkv.permute(1, 0, 2, 3) + q, k, v = qkv.unbind(0) + + if cos is not None and sin is not None: + cos_u = cos.unsqueeze(-2) + sin_u = sin.unsqueeze(-2) + + def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), 
dim=-1) + + q = (q * cos_u) + (rotate_half(q) * sin_u) + k = (k * cos_u) + (rotate_half(k) * sin_u) + + q = q.transpose(0, 1).unsqueeze(0) + k = k.transpose(0, 1).unsqueeze(0) + v = v.transpose(0, 1).unsqueeze(0) + + attn_weights = torch.matmul(q, k.transpose(-1, -2)) * attn.scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + out = torch.matmul(attn_weights, v) + out = out.squeeze(0).transpose(0, 1).reshape(seq_len, -1) + return attn.proj(out) + + def forward(self, hidden_states, attention_mask, cos, sin): + for block in self.blocks: + hidden_states = hidden_states + self._forward_attention( + block.attn, block.norm1(hidden_states), attention_mask, cos, sin + ) + hidden_states = hidden_states + block.mlp.linear_fc2( + block.mlp.act_fn(block.mlp.linear_fc1(block.norm2(hidden_states))) + ) + + hidden_states = self.merger_norm(hidden_states) + merge_size = self.config.spatial_merge_size + merged_hidden = self.config.hidden_size * (merge_size**2) + hidden_states = hidden_states.view(-1, merged_hidden) + hidden_states = self.merger_fc2(self.merger_act(self.merger_fc1(hidden_states))) + return hidden_states + + +class NeuronQwen35VisionModelWrapper(ModelWrapper): + """Wraps the vision encoder for NxDI tracing. + + Handles CPU-side operations that cannot be traced: + - Patch embedding (Conv3d) + - Positional embedding (Embedding + bilinear interpolation) + - Rotary position embedding computation + - Vision attention mask construction (block-diagonal) + - Sequence length bucketing and padding/unpadding + + Supports three modes: + 1. NxDI traced model (parallel layers) -- standard NxDI compilation + 2. Pre-compiled standalone model -- loaded from torch_neuronx.trace() output + 3. 
CPU-only model -- for when HBM is full (e.g., 27B dense on trn2.3xlarge) + """ + + def __init__(self, config, model_cls=None, **kwargs): + if model_cls is not None: + super().__init__(config, model_cls, **kwargs) + else: + # Standalone mode: no NxDI model_cls + nn.Module.__init__(self) + self.vision_config = config + self._compiled_model = None # Set by load_compiled() -- single bucket + self._compiled_buckets = None # Set by load_compiled() -- multi-bucket dict + self._cpu_model = None # Set by load_cpu_model() + + # These HF modules run on CPU, outside the traced graph + if Qwen3_5MoeVisionPatchEmbed is not None: + self.patch_embed = Qwen3_5MoeVisionPatchEmbed(config) + self.pos_embed = nn.Embedding( + config.num_position_embeddings, config.hidden_size + ) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3_5MoeVisionRotaryEmbedding(head_dim // 2) + else: + logger.warning("HF Qwen3.5 MoE vision classes not available") + + self.vision_seq_len_buckets = kwargs.get( + "vision_seq_len_buckets", [1024, 4096, 16384] + ) + + def load_compiled(self, compiled_model_path): + """Load pre-compiled standalone vision encoder(s). + + Supports two modes: + 1. Single .pt file: Legacy mode, loads one compiled model for one bucket size. + 2. Directory with multiple .pt files: Multi-bucket mode. Files must be named + 'vision_encoder_{bucket_size}.pt' (e.g., 'vision_encoder_256.pt'). + Falls back to single 'vision_encoder.pt' in the directory. + + Args: + compiled_model_path: Path to a .pt file or directory containing bucket .pt files. 
+ """ + import glob as glob_module + + logger.info(f"Loading pre-compiled vision encoder from {compiled_model_path}") + + if os.path.isfile(compiled_model_path): + # Single file mode (legacy) + self._compiled_model = torch.jit.load(compiled_model_path) + self._compiled_buckets = None + logger.info("Vision encoder loaded successfully (single bucket)") + elif os.path.isdir(compiled_model_path): + # Directory mode: look for bucket-specific files + bucket_files = sorted( + glob_module.glob( + os.path.join(compiled_model_path, "vision_encoder_*.pt") + ) + ) + if bucket_files: + self._compiled_buckets = {} + for bf in bucket_files: + # Extract bucket size from filename: vision_encoder_256.pt -> 256 + basename = os.path.basename(bf) + try: + bucket_size = int( + basename.replace("vision_encoder_", "").replace(".pt", "") + ) + self._compiled_buckets[bucket_size] = torch.jit.load(bf) + logger.info(f" Loaded vision bucket {bucket_size} from {bf}") + except ValueError: + logger.warning(f" Skipping unrecognized file: {bf}") + self._compiled_model = None + # Update vision_seq_len_buckets to match compiled buckets + self.vision_seq_len_buckets = sorted(self._compiled_buckets.keys()) + logger.info( + f"Vision encoder loaded with {len(self._compiled_buckets)} buckets: " + f"{self.vision_seq_len_buckets}" + ) + else: + # Fall back to single vision_encoder.pt in directory + single_path = os.path.join(compiled_model_path, "vision_encoder.pt") + if os.path.exists(single_path): + self._compiled_model = torch.jit.load(single_path) + self._compiled_buckets = None + logger.info( + "Vision encoder loaded successfully (single file in dir)" + ) + else: + raise FileNotFoundError( + f"No vision encoder files found in {compiled_model_path}" + ) + else: + raise FileNotFoundError( + f"Vision encoder path not found: {compiled_model_path}" + ) + + def load_vision_weights_from_hf(self, model_path): + """Load patch_embed and pos_embed weights from HF safetensors. 
+
+        Args:
+            model_path: Path to HF model directory
+        """
+        from pathlib import Path
+        from safetensors import safe_open
+
+        # glob("*.safetensors") already guarantees the suffix; no extra filter needed
+        st_files = sorted(Path(model_path).glob("*.safetensors"))
+        loaded = 0
+        for sf_path in st_files:
+            with safe_open(str(sf_path), framework="pt") as f:
+                for key in f.keys():
+                    if key == "model.visual.patch_embed.proj.weight":
+                        self.patch_embed.proj.weight.data.copy_(f.get_tensor(key))
+                        loaded += 1
+                    elif key == "model.visual.patch_embed.proj.bias":
+                        self.patch_embed.proj.bias.data.copy_(f.get_tensor(key))
+                        loaded += 1
+                    elif key == "model.visual.pos_embed.weight":
+                        self.pos_embed.weight.data.copy_(f.get_tensor(key))
+                        loaded += 1
+        logger.info(f"Loaded {loaded} CPU-side vision weight tensors from HF")
+
+    def load_cpu_model(self, model_path):
+        """Load a CPU-only vision encoder from HF safetensors.
+
+        Use this when HBM is insufficient for the Neuron-compiled vision encoder
+        (e.g., 27B dense model fills trn2.3xlarge HBM).
+ + Args: + model_path: Path to HF model directory with safetensors + """ + from pathlib import Path + from safetensors import safe_open + + config = self.vision_config + cpu_model = CPUVisionModel(config) + + # Build key mapping from HF safetensors to CPU model + key_map = {} + for i in range(config.depth): + hf_pre = f"model.visual.blocks.{i}" + loc_pre = f"blocks.{i}" + for suffix in [ + "attn.qkv.weight", + "attn.qkv.bias", + "attn.proj.weight", + "attn.proj.bias", + "mlp.linear_fc1.weight", + "mlp.linear_fc1.bias", + "mlp.linear_fc2.weight", + "mlp.linear_fc2.bias", + "norm1.weight", + "norm1.bias", + "norm2.weight", + "norm2.bias", + ]: + key_map[f"{hf_pre}.{suffix}"] = f"{loc_pre}.{suffix}" + + key_map["model.visual.merger.norm.weight"] = "merger_norm.weight" + key_map["model.visual.merger.norm.bias"] = "merger_norm.bias" + key_map["model.visual.merger.linear_fc1.weight"] = "merger_fc1.weight" + key_map["model.visual.merger.linear_fc1.bias"] = "merger_fc1.bias" + key_map["model.visual.merger.linear_fc2.weight"] = "merger_fc2.weight" + key_map["model.visual.merger.linear_fc2.bias"] = "merger_fc2.bias" + + st_files = sorted(Path(model_path).glob("model*.safetensors")) + loaded = 0 + state_dict = cpu_model.state_dict() + + for sf_path in st_files: + with safe_open(str(sf_path), framework="pt") as f: + for key in f.keys(): + if key in key_map: + local_key = key_map[key] + if local_key in state_dict: + state_dict[local_key].copy_(f.get_tensor(key)) + loaded += 1 + + cpu_model.load_state_dict(state_dict) + cpu_model = cpu_model.to(torch.bfloat16).eval() + self._cpu_model = cpu_model + logger.info( + f"Loaded CPU vision encoder: {loaded} weights, " + f"{sum(p.numel() for p in cpu_model.parameters()) / 1e6:.1f}M params" + ) + + def _get_vision_bucket(self, seq_len): + """Find the smallest bucket that fits the sequence length.""" + for bucket in sorted(self.vision_seq_len_buckets): + if seq_len <= bucket: + return bucket + return self.vision_seq_len_buckets[-1] + + 
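+    # Illustrative bucket selection with the default buckets [1024, 4096, 16384]
+    # (example values only, not part of the traced graph):
+    #   _get_vision_bucket(900)   -> 1024   # pad 900 tokens up to 1024
+    #   _get_vision_bucket(4096)  -> 4096   # exact fit, no padding
+    #   _get_vision_bucket(20000) -> 16384  # exceeds the largest bucket; forward()
+    #                                       # then skips padding, so a compiled
+    #                                       # model would hit a shape mismatch
+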
def rot_pos_emb(self, grid_thw): + """Compute rotary positional embeddings for vision tokens. + + Returns: (total_tokens, head_dim) tensor of rotary frequencies. + """ + merge_size = self.vision_config.spatial_merge_size + grid_thw_list = grid_thw.tolist() + + max_hw = max(max(h, w) for _, h, w in grid_thw_list) + freq_table = self.rotary_pos_emb(max_hw) + device = freq_table.device + + total_tokens = sum(t * h * w for t, h, w in grid_thw_list) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw_list: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) + block_cols = torch.arange(merged_w, device=device) + intra_row = torch.arange(merge_size, device=device) + intra_col = torch.arange(merge_size, device=device) + + row_idx = ( + block_rows[:, None, None, None] * merge_size + + intra_row[None, None, :, None] + ) + col_idx = ( + block_cols[None, :, None, None] * merge_size + + intra_col[None, None, None, :] + ) + + row_idx = row_idx.expand( + merged_h, merged_w, merge_size, merge_size + ).reshape(-1) + col_idx = col_idx.expand( + merged_h, merged_w, merge_size, merge_size + ).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + """Bilinear interpolation of positional embeddings for variable resolution.""" + grid_thw_list = grid_thw.tolist() + grid_ts = [row[0] for row in grid_thw_list] + grid_hs = [row[1] for row in grid_thw_list] + grid_ws = [row[2] for row in grid_thw_list] + device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for 
t, h, w in grid_thw_list: + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split( + [h * w for h, w in zip(grid_hs, grid_ws)] + ) + + merge_size = self.vision_config.spatial_merge_size + patch_pos_embeds_permute = [] + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view( + t, h // merge_size, merge_size, w // merge_size, merge_size, -1 + ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + + return torch.cat(patch_pos_embeds_permute) + + def _build_vision_attention_mask(self, grid_thw, seq_len, dtype): + 
"""Build block-diagonal attention mask for variable-length images. + + Each image gets its own attention block (no cross-image attention). + """ + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # Build block-diagonal mask + mask = torch.full((seq_len, seq_len), _MASK_NEG_INF, dtype=dtype) + for i in range(len(cu_seqlens) - 1): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + mask[start:end, start:end] = 0.0 + + return mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) + + def forward(self, pixel_values, image_grid_thw): + """Run vision encoding (CPU preprocessing + Neuron traced model). + + Args: + pixel_values: Raw pixel values from HF processor + image_grid_thw: (num_images, 3) -- temporal, height, width in patches + + Returns: + vision_embeddings: (total_merged_tokens, out_hidden_size) + """ + # 1. Patch embedding (CPU, Conv3d) + hidden_states = self.patch_embed(pixel_values) + + # 2. Positional embedding (CPU, bilinear interpolation) + pos_embeds = self.fast_pos_embed_interpolate(image_grid_thw) + hidden_states = hidden_states + pos_embeds + + # 3. Rotary position embeddings (CPU) + rotary_pos_emb = self.rot_pos_emb(image_grid_thw) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + # 4. Vision attention mask (block-diagonal) + seq_len = hidden_states.shape[0] + attention_mask = self._build_vision_attention_mask( + image_grid_thw, seq_len, hidden_states.dtype + ) + + # 5. 
Bucket and pad for Neuron compilation + bucket_len = self._get_vision_bucket(seq_len) + cos, sin = position_embeddings + if seq_len < bucket_len: + pad_len = bucket_len - seq_len + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len)) + cos = F.pad(cos, (0, 0, 0, pad_len)) + sin = F.pad(sin, (0, 0, 0, pad_len)) + # Extend mask with _MASK_NEG_INF for padded positions (NOT -inf, which causes NaN on Neuron) + mask = torch.full( + (1, 1, bucket_len, bucket_len), _MASK_NEG_INF, dtype=hidden_states.dtype + ) + mask[:, :, :seq_len, :seq_len] = attention_mask + attention_mask = mask + + # 6. Run vision model (Neuron compiled or CPU fallback) + if self._compiled_buckets is not None: + # Multi-bucket mode: select the compiled model for this bucket + if bucket_len not in self._compiled_buckets: + raise RuntimeError( + f"No compiled vision encoder for bucket size {bucket_len}. " + f"Available buckets: {sorted(self._compiled_buckets.keys())}. " + f"Input seq_len={seq_len} requires bucket {bucket_len}." 
+                )
+            compiled_model = self._compiled_buckets[bucket_len]
+            vision_output = compiled_model(
+                hidden_states.to(torch.bfloat16),
+                attention_mask.to(torch.bfloat16),
+                cos.to(torch.bfloat16),
+                sin.to(torch.bfloat16),
+            )
+        elif self._compiled_model is not None:
+            # Single compiled model (legacy)
+            vision_output = self._compiled_model(
+                hidden_states.to(torch.bfloat16),
+                attention_mask.to(torch.bfloat16),
+                cos.to(torch.bfloat16),
+                sin.to(torch.bfloat16),
+            )
+        elif self._cpu_model is not None:
+            # CPU-only mode: run the vision encoder on CPU. Bucketing/padding is
+            # not strictly required here, but reusing the padded inputs keeps the
+            # merger math identical to the compiled path.
+            with torch.no_grad():
+                vision_output = self._cpu_model(
+                    hidden_states.to(torch.bfloat16),
+                    attention_mask.to(torch.bfloat16),
+                    cos.to(torch.bfloat16),
+                    sin.to(torch.bfloat16),
+                )
+        else:
+            # NxDI traced model: takes (hidden_states, attention_mask, position_embeddings)
+            vision_output = self.model(hidden_states, attention_mask, (cos, sin))
+
+        # 7. Unpad: keep only the valid merged tokens
+        merge_size = self.vision_config.spatial_merge_size
+        total_merged_tokens = sum(
+            t * (h // merge_size) * (w // merge_size)
+            for t, h, w in image_grid_thw.tolist()
+        )
+        vision_output = vision_output[:total_merged_tokens]
+
+        return vision_output
+
+
+class NeuronQwen35VisionForImageEncoding(NeuronApplicationBase):
+    """Standalone application class for vision encoding (for testing)."""
+
+    model_cls = NeuronQwen35VisionModel
+    model_wrapper_cls = NeuronQwen35VisionModelWrapper
+
+    @staticmethod
+    def prepare_input_args(image_path, processor):
+        """Prepare vision inputs from an image path.
+ + Args: + image_path: Path to image file + processor: HF AutoProcessor + + Returns: + pixel_values, image_grid_thw + """ + from PIL import Image + + image = Image.open(image_path).convert("RGB") + inputs = processor(images=image, return_tensors="pt") + return inputs["pixel_values"], inputs["image_grid_thw"] diff --git a/contrib/models/Qwen3.5-27B/src/modeling_qwen35_vl.py b/contrib/models/Qwen3.5-27B/src/modeling_qwen35_vl.py new file mode 100644 index 00000000..8526833f --- /dev/null +++ b/contrib/models/Qwen3.5-27B/src/modeling_qwen35_vl.py @@ -0,0 +1,662 @@ +""" +Qwen3.5-27B Vision-Language Model Orchestrator for NeuronX Distributed Inference. + +This is the top-level VL model that wires together: +- The vision encoder (modeling_qwen35_vision.py) +- The text decoder (modeling_qwen35.py, dense model with vision injection) + +It handles: +- Multimodal RoPE (mRoPE) with interleaved layout +- Vision embedding injection via scatter_by_index_put +- Separate compilation and loading of vision and text models +- The CTE+TKG generation loop with vision inputs + +Architecture follows the NxDI NeuronBaseForImageToText pattern established +by Qwen3-VL in SDK 2.28, adapted for Qwen3.5 dense model's unique features: +- No deepstack (Qwen3.5 does not use intermediate vision feature injection) +- DeltaNet linear attention layers in the text decoder +- Dense SwiGLU MLP layers in the text decoder +- Interleaved mRoPE (THWTHW... 
layout) instead of Qwen3-VL's section-based layout +""" + +import logging +import os +from typing import Optional + +import torch +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + +# NxDI imports +try: + from neuronx_distributed_inference.models.image_to_text_model_base import ( + ImageToTextInferenceConfig, + NeuronBaseForImageToText, + ) + from neuronx_distributed_inference.models.config import NeuronConfig + + HAS_NXDI_VL = True +except ImportError: + HAS_NXDI_VL = False + logger.warning("NxDI VL base classes not available -- VL model requires SDK 2.28+") + +# Local imports +try: + from src.modeling_qwen35 import ( + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35InferenceConfig, + Qwen35ModelWrapper, + ) + from src.modeling_qwen35_vision import ( + NeuronQwen35VisionModel, + NeuronQwen35VisionModelWrapper, + ) +except ImportError: + from modeling_qwen35 import ( + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35InferenceConfig, + Qwen35ModelWrapper, + ) + from modeling_qwen35_vision import ( + NeuronQwen35VisionModel, + NeuronQwen35VisionModelWrapper, + ) + + +def get_rope_index( + input_ids, + image_grid_thw=None, + video_grid_thw=None, + attention_mask=None, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + spatial_merge_size=2, +): + """Compute 3D multimodal RoPE position IDs for Qwen3.5. + + Returns position_ids of shape (3, batch_size, seq_len) where: + - Axis 0: temporal position + - Axis 1: height position + - Axis 2: width position + + For text tokens, all 3 axes have the same sequential position. + For vision tokens, each axis encodes the spatial/temporal grid position. + + Also returns rope_deltas for use during TKG decoding. + + Adapted from HuggingFace Qwen3_5Model.get_rope_index(). 
+ """ + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave( + video_grid_thw, video_grid_thw[:, 0], dim=0 + ) + video_grid_thw[:, 0] = 1 + + image_grid_thw_list = ( + image_grid_thw.tolist() if image_grid_thw is not None else None + ) + video_grid_thw_list = ( + video_grid_thw.tolist() if video_grid_thw is not None else None + ) + + mrope_position_deltas = [] + total_input_ids = input_ids + + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + + position_ids = torch.zeros( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + for i, ids in enumerate(total_input_ids): + ids = ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + + vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1) + if len(vision_start_indices) > 0: + vision_tokens = ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + + input_tokens = ids.tolist() + llm_pos_ids_list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = image_grid_thw_list[image_index] + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = video_grid_thw_list[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t = t + llm_grid_h = h // spatial_merge_size + llm_grid_w = w // spatial_merge_size + + text_len = ed - st + st_idx = llm_pos_ids_list[-1].max() + 1 if 
len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + + +class Qwen35VLInferenceConfig: + """Configuration for the full VL model (text + vision). + + Wraps the existing Qwen35InferenceConfig for text and adds + vision-specific settings. + """ + + def __init__( + self, + text_config, + vision_config, + image_token_id=248056, + video_token_id=248057, + vision_start_token_id=248053, + vision_end_token_id=248054, + spatial_merge_size=2, + vision_seq_len_buckets=None, + **kwargs, + ): + """ + Args: + text_config: Qwen35InferenceConfig instance for the text decoder + vision_config: dict with vision encoder hyperparams (depth, hidden_size, etc.) 
+            image_token_id: Token ID for image placeholder tokens
+            video_token_id: Token ID for video placeholder tokens
+            vision_start_token_id: Token ID for <|vision_start|>
+            vision_end_token_id: Token ID for <|vision_end|>
+            spatial_merge_size: Spatial merge factor per dimension
+                (2 means a 2x2 block of 4 patches merges into one token)
+            vision_seq_len_buckets: List of vision sequence length buckets for compilation
+        """
+        self.text_config = text_config
+        self.vision_config = vision_config
+        self.image_token_id = image_token_id
+        self.video_token_id = video_token_id
+        self.vision_start_token_id = vision_start_token_id
+        self.vision_end_token_id = vision_end_token_id
+        self.spatial_merge_size = spatial_merge_size
+        self.vision_seq_len_buckets = vision_seq_len_buckets or [1024, 4096, 16384]
+
+
+class NeuronQwen35VLForCausalLM:
+    """Top-level VL model for Qwen3.5-27B on Neuron.
+
+    This class manages:
+    - Separate compilation/loading of vision encoder and text decoder
+    - CPU-side mRoPE computation
+    - Vision embedding injection into text decoder
+    - The CTE+TKG generation loop
+
+    Note: This is NOT a NeuronBaseForImageToText subclass because the
+    text decoder (NeuronQwen35ForCausalLM) has extensive custom overrides
+    (DeltaNet state management, custom forward, custom ModelWrapper) that
+    don't fit the base class pattern. Instead, this class composes the two
+    models and handles the VL orchestration directly.
+ """ + + def __init__(self, model_path, text_config, vision_config=None, processor=None): + """ + Args: + model_path: Path to HF model directory + text_config: Qwen35InferenceConfig for text decoder + vision_config: Qwen35VLInferenceConfig (or None for text-only) + processor: HF AutoProcessor for image preprocessing + """ + self.model_path = model_path + self.text_config = text_config + self.vl_config = vision_config + self.processor = processor + + # Text decoder (existing implementation) + self.text_model = NeuronQwen35ForCausalLM( + model_path=model_path, config=text_config + ) + + # Vision encoder (lazy init -- only built if vl_config provided) + self.vision_model_wrapper = None + if vision_config is not None: + self._init_vision_model(vision_config) + + # mRoPE state + self.rope_deltas = None + + def _init_vision_model(self, vl_config): + """Initialize the vision encoder wrapper.""" + from types import SimpleNamespace + + vision_cfg = SimpleNamespace(**vl_config.vision_config) + self.vision_model_wrapper = NeuronQwen35VisionModelWrapper( + config=vision_cfg, + model_cls=None, # Standalone mode (no NxDI parallel layers) + vision_seq_len_buckets=vl_config.vision_seq_len_buckets, + ) + self._vl_config = vl_config + + def compile(self, compiled_model_path): + """Compile both text and vision models. + + For the vision encoder, use compile_vision_encoder.py separately + (standalone torch_neuronx.trace compilation). Then use load() to + load the pre-compiled vision encoder. + """ + # Compile text decoder + text_path = os.path.join(compiled_model_path, "text_model") + os.makedirs(text_path, exist_ok=True) + self.text_model.compile(text_path) + + # Vision encoder is compiled separately via compile_vision_encoder.py + if self.vision_model_wrapper is not None: + logger.info( + "Vision encoder must be compiled separately using " + "compile_vision_encoder.py. Use load() to load the " + "pre-compiled vision encoder." 
+ ) + + def load(self, compiled_model_path, vision_compiled_path=None): + """Load both compiled models. + + Args: + compiled_model_path: Path to compiled text model (or parent dir) + vision_compiled_path: Path to compiled vision encoder .pt file. + If None, looks for 'vision_encoder.pt' in compiled_model_path. + """ + text_path = os.path.join(compiled_model_path, "text_model") + if os.path.exists(text_path): + self.text_model.load(text_path) + else: + # Backward compatibility: text model compiled at root + self.text_model.load(compiled_model_path) + + # Load vision encoder + if self.vision_model_wrapper is not None: + if vision_compiled_path is None: + vision_compiled_path = os.path.join( + compiled_model_path, "vision_encoder.pt" + ) + if os.path.exists(vision_compiled_path): + self.vision_model_wrapper.load_compiled(vision_compiled_path) + # Also load CPU-side weights (patch_embed, pos_embed) + self.vision_model_wrapper.load_vision_weights_from_hf(self.model_path) + logger.info("Vision encoder loaded from pre-compiled model") + else: + logger.warning( + f"No compiled vision encoder found at {vision_compiled_path}. " + "Vision encoding will not be available." + ) + + # Qwen3.5 stop token IDs (loaded from config/tokenizer) + _DEFAULT_EOS_TOKEN_IDS = { + 248044, # <|endoftext|> -- text config eos_token_id + 248046, # <|im_end|> -- tokenizer eos_token / end of assistant turn + } + + def generate( + self, + input_ids, + attention_mask=None, + pixel_values=None, + image_grid_thw=None, + video_grid_thw=None, + max_new_tokens=32, + temperature=0.0, + top_p=1.0, + top_k=0, + eos_token_ids=None, + **kwargs, + ): + """Generate text from text and/or vision inputs. 
+ + Args: + input_ids: (batch_size, seq_len) token IDs + attention_mask: (batch_size, seq_len) attention mask + pixel_values: Vision pixel values from HF processor (or None for text-only) + image_grid_thw: (num_images, 3) grid dimensions + video_grid_thw: (num_videos, 3) grid dimensions + max_new_tokens: Maximum new tokens to generate + temperature: Sampling temperature (0.0 = greedy/argmax) + top_p: Nucleus sampling threshold (1.0 = disabled) + top_k: Top-k sampling (0 = disabled) + eos_token_ids: Set of token IDs to stop generation on + (default: {248044, 248046}) + + Returns: + generated_ids: (batch_size, seq_len + max_new_tokens) token IDs + """ + if eos_token_ids is None: + eos_token_ids = self._DEFAULT_EOS_TOKEN_IDS + + # Reset text model state for a fresh generation. + # This ensures CTE runs (not TKG) even if a prior generate() was called. + # DeltaNet recurrent states don't need explicit zeroing because the CTE + # NKI kernel always starts from zero state. + self.text_model.reset() + + has_vision = pixel_values is not None and pixel_values.numel() > 0 + + # Step 1: Compute 3D mRoPE position IDs + if has_vision and self._vl_config is not None: + position_ids, self.rope_deltas = get_rope_index( + input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + image_token_id=self._vl_config.image_token_id, + video_token_id=self._vl_config.video_token_id, + vision_start_token_id=self._vl_config.vision_start_token_id, + spatial_merge_size=self._vl_config.spatial_merge_size, + ) + else: + # Text-only: use standard sequential position IDs + seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len).unsqueeze(0) + self.rope_deltas = None + + # Step 2: Run vision encoder and prepare injection args + llava_args = [] + batch_size = input_ids.shape[0] + if has_vision and self.vision_model_wrapper is not None: + # The vision encoder processes both image and video frames identically + # (they share the same ViT 
architecture). The HF processor outputs a + # single pixel_values tensor for images, and video frames are treated + # as multiple images with temporal grid > 1. + vision_embeddings = self.vision_model_wrapper(pixel_values, image_grid_thw) + # vision_embeddings: (total_merged_tokens, out_hidden_size) + + # Build vision_mask: boolean mask of ALL vision token positions + # (both image_token_id and video_token_id placeholders) + image_token_id = self._vl_config.image_token_id + video_token_id = self._vl_config.video_token_id + vision_bool_mask = (input_ids == image_token_id) | ( + input_ids == video_token_id + ) # (BS, seq_len) + + # For batch_size=1 (primary path): extract positions from batch element 0. + # For batch_size>1: each element may have different image token positions; + # we'd need per-element scatter. Currently only batch_size=1 is supported + # for VL (the compiled model uses batch_size=1 for CTE). + if batch_size > 1: + logger.warning( + "VL generation with batch_size > 1 is not fully supported. " + "Using batch element 0 for vision scatter positions." + ) + + positions = ( + vision_bool_mask[0].nonzero(as_tuple=False).squeeze(-1) + ) # (n_vision_tokens,) + + # Reshape vision_embeddings to (1, n_vision_tokens, hidden_size) + n_vis = positions.shape[0] + hidden_size = vision_embeddings.shape[-1] + vis_emb = vision_embeddings[:n_vis].unsqueeze(0) # (1, n_vis, hidden) + + # Pad to match input sequence length for compiled graph compatibility + seq_len = input_ids.shape[1] + pad_limit = seq_len # Must match the bucket size + + # Pad vision_embeddings to (1, pad_limit, hidden_size) + if n_vis < pad_limit: + pad_emb = torch.zeros( + (1, pad_limit - n_vis, hidden_size), + dtype=vis_emb.dtype, + ) + vis_emb_padded = torch.cat([vis_emb, pad_emb], dim=1) + else: + vis_emb_padded = vis_emb[:, :pad_limit] + + # Pad positions to (1, pad_limit, 1) with a SAFE fill value. + # CRITICAL: fill_value must be a valid index (within [0, pad_limit-1]). 
+ # Using pad_limit-1 targets the last position (always a padding slot) + # so index_put_ scatters zero embeddings there harmlessly. + # NOTE: Do NOT use large sentinel values (e.g., 2**30) as they cause + # DGE out-of-bounds crashes in the Neuron runtime. + positions_padded = torch.full( + (1, pad_limit, 1), + fill_value=pad_limit - 1, + dtype=torch.int32, + ) + positions_padded[0, :n_vis, 0] = positions[:pad_limit].to(torch.int32) + + llava_args = [vis_emb_padded, positions_padded] + + # Append 3D mRoPE position IDs for the text model. + # position_ids shape: (3, batch_size, seq_len) from get_rope_index. + # _get_model_outputs receives this at slot 21 and pre-computes + # mRoPE cos/sin in get_model_output() for all decoder layers. + if position_ids.ndim == 3: + mrope_pos = position_ids[:, :, :seq_len].to(torch.int32).contiguous() + llava_args.append(mrope_pos) + else: + vision_embeddings = None + + # Step 3: Context encoding (prefill) + generated_ids = input_ids.clone() + + # CRITICAL: Always pass an explicit attention_mask for CTE. + # The base class _infer_attention_mask() assumes sequential position_ids + # (position_ids[i] >= i). When position_ids come from mRoPE temporal + # axis (non-sequential, e.g., all vision tokens share position 4), + # the inferred mask incorrectly masks out most of the sequence. + # Fix: provide a real all-ones mask for the actual token positions. + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + # For slot 2 (position_ids): use SEQUENTIAL positions regardless of mRoPE. + # Slot 2 is only used for: (1) logit position selection via torch.max(), + # (2) attention mask inference (which we bypass with explicit mask above). + # The actual RoPE computation uses slot 21 (rotary_position_ids) from + # _get_model_outputs, NOT slot 2. Using sequential slot 2 ensures + # correct logit selection and avoids any position_ids-related issues. 
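The failure mode described in the comments above can be reproduced with a toy version of sequential-position mask inference (plain Python; `infer_mask_from_positions` is a made-up stand-in for illustration, not the real `_infer_attention_mask`):

```python
def infer_mask_from_positions(position_ids):
    # Toy "sequential positions" heuristic: token i is treated as a real
    # (unmasked) token iff its position id is >= its index.
    return [1 if p >= i else 0 for i, p in enumerate(position_ids)]

# Sequential positions: every token is kept.
assert infer_mask_from_positions([0, 1, 2, 3, 4, 5]) == [1] * 6

# mRoPE temporal axis: vision tokens can all share one position (here 2),
# so the heuristic wrongly masks out the tail of the sequence.
assert infer_mask_from_positions([0, 1, 2, 2, 2, 2]) == [1, 1, 1, 0, 0, 0]
```

Passing an explicit all-ones mask over the real tokens sidesteps the heuristic entirely, which is the fix applied here.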
+ seq_len = input_ids.shape[1] + cte_position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0) + + with torch.no_grad(): + output = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=cte_position_ids, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + llava_args=llava_args, + ) + + logits = output[0] if isinstance(output, tuple) else output.logits + next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k) + generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1) + + # Check EOS after first token + if next_token.item() in eos_token_ids: + return generated_ids + + # Step 4: Token generation (TKG) loop + for _ in range(max_new_tokens - 1): + pos_ids = torch.tensor([[generated_ids.shape[1] - 1]]) + if self.rope_deltas is not None: + pos_ids = pos_ids + self.rope_deltas + + last_token = generated_ids[:, -1:] + with torch.no_grad(): + output = self.text_model( + input_ids=last_token, + position_ids=pos_ids, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ) + logits = output[0] if isinstance(output, tuple) else output.logits + next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k) + generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1) + + # Stop on EOS + if next_token.item() in eos_token_ids: + break + + return generated_ids + + @staticmethod + def _sample_token(logits, temperature=0.0, top_p=1.0, top_k=0): + """Sample a token from logits with optional temperature/top-p/top-k. + + Args: + logits: (batch_size, vocab_size) unnormalized logits + temperature: Sampling temperature. 0.0 = greedy (argmax). + top_p: Nucleus sampling threshold. 1.0 = disabled. + top_k: Top-k filtering. 0 = disabled. 
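The shift-right nucleus rule implemented below can be checked against a tiny pure-Python model (illustrative only; `top_p_keep` is not part of this module): a token, in descending-probability order, is kept iff the cumulative probability BEFORE adding it is still <= top_p, and the highest-probability token is always kept.

```python
import math

def top_p_keep(logits, top_p):
    # Mirror of the torch shift-right masking: keep rank 0 unconditionally,
    # then keep rank r iff the cumulative probability of ranks < r <= top_p.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    z = sum(math.exp(x) for x in logits)
    kept, cum = [], 0.0
    for rank, idx in enumerate(order):
        if rank == 0 or cum <= top_p:
            kept.append(idx)
        cum += math.exp(logits[idx]) / z
    return sorted(kept)

# probs ~ [0.64, 0.24, 0.09, 0.03]
assert top_p_keep([2.0, 1.0, 0.0, -1.0], top_p=0.8) == [0, 1]
assert top_p_keep([2.0, 1.0, 0.0, -1.0], top_p=0.5) == [0]
```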
+ + Returns: + token_id: (batch_size,) sampled token IDs + """ + if temperature <= 0.0: + return torch.argmax(logits, dim=-1) + + # Apply temperature + logits = logits / temperature + + # Top-k filtering + if top_k > 0: + top_k = min(top_k, logits.shape[-1]) + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = float("-inf") + + # Top-p (nucleus) filtering + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + torch.softmax(sorted_logits, dim=-1), dim=-1 + ) + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift right so the first token above threshold is kept + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1 + ].clone() + sorted_indices_to_remove[..., 0] = False + # Scatter back to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + -1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = float("-inf") + + # Sample from the filtered distribution + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1).squeeze(-1) + + @staticmethod + def prepare_input_args(text_prompt, image_path, processor, role="user"): + """Prepare inputs for vision+text generation. 
+ + Args: + text_prompt: Text prompt string + image_path: Path to image file (or None for text-only) + processor: HF AutoProcessor + role: Message role (default "user") + + Returns: + input_ids, attention_mask, vision_inputs dict + """ + content = [] + if image_path is not None: + import base64 + from pathlib import Path + + image_data = Path(image_path).read_bytes() + b64 = base64.b64encode(image_data).decode("utf-8") + content.append( + { + "type": "image", + "url": f"data:image/jpeg;base64,{b64}", + } + ) + content.append({"type": "text", "text": text_prompt}) + + messages = [{"role": role, "content": content}] + inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + ) + + input_ids = inputs["input_ids"] + attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)) + + vision_inputs = {} + if "pixel_values" in inputs: + vision_inputs["pixel_values"] = inputs["pixel_values"] + if "image_grid_thw" in inputs: + vision_inputs["image_grid_thw"] = inputs["image_grid_thw"] + if "video_grid_thw" in inputs: + vision_inputs["video_grid_thw"] = inputs["video_grid_thw"] + + return input_ids, attention_mask, vision_inputs diff --git a/contrib/models/Qwen3.5-27B/src/nki_kernels/__init__.py b/contrib/models/Qwen3.5-27B/src/nki_kernels/__init__.py new file mode 100644 index 00000000..3952e26b --- /dev/null +++ b/contrib/models/Qwen3.5-27B/src/nki_kernels/__init__.py @@ -0,0 +1,10 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Custom NKI kernels for Qwen3.5-27B DeltaNet layers. 
+ +Contains three kernel implementations: +- nki_deltanet: Per-token recurrent kernel (used for token generation) +- nki_deltanet_chunked: Per-chunk kernel (legacy, superseded by fused) +- nki_deltanet_fused: Fused single-kernel chunked forward (used for context encoding) +""" diff --git a/contrib/models/Qwen3.5-27B/src/nki_kernels/nki_deltanet.py b/contrib/models/Qwen3.5-27B/src/nki_kernels/nki_deltanet.py new file mode 100644 index 00000000..a9994d54 --- /dev/null +++ b/contrib/models/Qwen3.5-27B/src/nki_kernels/nki_deltanet.py @@ -0,0 +1,334 @@ +"""NKI kernels for DeltaNet gated delta rule recurrent forward. + +NKI v3 (SDK 2.29, NKI 0.3.0). Processes a SINGLE (batch, head) pair per kernel call. +The caller loops over (B, H) in PyTorch and calls this kernel for each pair. + +Input layout: All inputs are 2D contiguous tensors (S, 128). +Each call processes one (batch, head) element's full sequence. + +k_dim = v_dim = 128, which matches SBUF tile partition dimension exactly. +g and beta are scalars per token, expanded to (S, 128) by the caller. + +Two kernel variants: + deltanet_recurrent_fwd -- returns output only (original) + deltanet_recurrent_fwd_state -- returns (output, final_state) for CTE->TKG carry-over +""" + +import nki +import nki.isa as nisa +import nki.language as nl + +# Partition dimension max (NeuronCore SBUF tile width) +P_MAX = 128 + +# Shuffle mask: broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +@nki.jit +def deltanet_recurrent_fwd( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +) -> nl.ndarray: + """NKI kernel for DeltaNet recurrent forward -- single (batch, head). + + Iterates over sequence tokens with sequential_range. + State matrix (128 x 128) lives in SBUF. 
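The five-step recurrence this kernel implements (decay, read, delta, rank-1 update, output) can be cross-checked at toy sizes against a naive pure-Python reference (`deltanet_reference` is an illustrative helper, not part of this module):

```python
import math

def deltanet_reference(q, k, v, g, beta):
    # Gated delta rule, single head, d-dim lists:
    #   state *= exp(g_t)
    #   kv    = state^T @ k_t
    #   delta = (v_t - kv) * beta_t
    #   state += outer(k_t, delta)
    #   o_t   = state^T @ q_t
    d = len(q[0])
    state = [[0.0] * d for _ in range(d)]
    outs = []
    for t in range(len(q)):
        decay = math.exp(g[t])
        state = [[decay * s for s in row] for row in state]
        kv = [sum(state[i][j] * k[t][i] for i in range(d)) for j in range(d)]
        delta = [(v[t][j] - kv[j]) * beta[t] for j in range(d)]
        for i in range(d):
            for j in range(d):
                state[i][j] += k[t][i] * delta[j]
        outs.append([sum(state[i][j] * q[t][i] for i in range(d)) for j in range(d)])
    return outs

# With g=0 (no decay), beta=1, and orthonormal keys, the state acts as an
# exact key->value store: querying with a stored key returns its value.
outs = deltanet_reference(
    q=[[1.0, 0.0], [0.0, 1.0]],
    k=[[1.0, 0.0], [0.0, 1.0]],
    v=[[2.0, 3.0], [5.0, 7.0]],
    g=[0.0, 0.0],
    beta=[1.0, 1.0],
)
assert outs == [[2.0, 3.0], [5.0, 7.0]]
```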
+ + Args: + query: (S, 128) float32 + key: (S, 128) float32 + value: (S, 128) float32 + g_in: (S, 128) float32 + beta_in: (S, 128) float32 + + Returns: + output: (S, 128) float32 + """ + seq_len, dim = query.shape + + # Output tensor in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read 
memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # ---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + # 1) Transpose delta (128,1) -> (1,128) in PSUM + # 2) Copy PSUM (1,128) -> SBUF (128,128) -- partition broadcast + # 3) Multiply by k_t (128,1) which broadcasts across free dim + # This avoids the nc_matmul P=1 outer product (wastes 127/128 TE lanes). 
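As a quick sanity check, the transpose / broadcast / scale sequence described in the Step 4 comments is exactly an outer product. A toy pure-Python model of the three steps (illustrative names, not kernel code):

```python
def outer_via_broadcast(k_col, delta_col):
    # (1) "transpose" delta into a row, (2) replicate the row to every
    # partition, (3) scale partition i by k[i] -- equals outer(k, delta).
    n = len(k_col)
    delta_row = list(delta_col)
    broadcast = [list(delta_row) for _ in range(n)]
    return [[k_col[i] * broadcast[i][j] for j in range(n)] for i in range(n)]

k = [1.0, 2.0, 3.0]
delta = [4.0, 5.0, 6.0]
direct = [[ki * dj for dj in delta] for ki in k]
assert outer_via_broadcast(k, delta) == direct
```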
+ + # Transpose delta to get values along free dimension + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + # Each partition row gets the same delta values + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + # Element-wise multiply: outer[i,j] = delta_broadcast[i,j] * k_t[i,0] + # tensor_scalar broadcasts (P,1) k_t across all F columns + outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + # Accumulate into state + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + return output + + +@nki.jit +def deltanet_recurrent_fwd_state( + query: nl.ndarray, # (S, 128) float32 + key: nl.ndarray, # (S, 128) float32 + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 128) 
float32, log-decay broadcast to 128 + beta_in: nl.ndarray, # (S, 128) float32, write gate broadcast to 128 +): + """NKI kernel for DeltaNet recurrent forward with final state output. + + Same recurrence as deltanet_recurrent_fwd, but ALSO writes the final + recurrent state (128, 128) to an output HBM buffer. This enables + CTE -> TKG state carry-over. + + Returns: + output: (S, 128) float32 -- per-token output + final_state: (128, 128) float32 -- recurrent state after last token + """ + seq_len, dim = query.shape + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1 + seq_stride = dim + + # Initialize recurrent state in SBUF: (128, 128) + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # Sequential loop over tokens (state-dependent) + for t in nl.sequential_range(seq_len): + tok_offset = t * seq_stride + + # ---- Load inputs for token t ---- + q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_t, + src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_t, + src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_t, + src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_t, + src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_t, + src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset), + ) + + # ---- Step 1: Decay state -- state = state * exp(g_t) ---- + exp_g = nl.ndarray((P_MAX, 1), 
dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0) + + state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_g, + engine=nisa.vector_engine, + ) + nisa.tensor_copy(dst=state, src=state_decayed) + + # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ---- + kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t) + kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum) + + # ---- Step 3: delta = (v_t - kv_mem) * beta_t ---- + v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract) + + delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=delta, + data=v_sub, + op0=nl.multiply, + operand0=beta_t, + engine=nisa.vector_engine, + ) + + # ---- Step 4: state += outer(k_t, delta) ---- + # Broadcast multiply: outer[i,j] = k_t[i] * delta[j] + delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=delta_row_psum, data=delta) + + # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims) + delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum) + + # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle + delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=delta_row_sb[0:1, 0:P_MAX], + dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=outer_prod, + 
data=delta_broadcast, + op0=nl.multiply, + operand0=k_t, + engine=nisa.vector_engine, + ) + + state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add) + nisa.tensor_copy(dst=state, src=state_new) + + # ---- Step 5: o_t = state^T @ q_t ---- + o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t) + o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_t, src=o_t_psum) + + # ---- Store output for token t ---- + nisa.dma_copy( + dst=output.ap(pattern=[[1, dim]], offset=tok_offset), + src=o_t, + ) + + # ---- Write final state to HBM ---- + # state is (128, 128) in SBUF, copy to final_state in HBM + # Use dma_copy with full tile: P_MAX rows, dim cols + nisa.dma_copy(dst=final_state, src=state) + + return output, final_state diff --git a/contrib/models/Qwen3.5-27B/src/nki_kernels/nki_deltanet_chunked.py b/contrib/models/Qwen3.5-27B/src/nki_kernels/nki_deltanet_chunked.py new file mode 100644 index 00000000..5c582f8d --- /dev/null +++ b/contrib/models/Qwen3.5-27B/src/nki_kernels/nki_deltanet_chunked.py @@ -0,0 +1,320 @@ +"""NKI per-chunk DeltaNet kernel for CTE (context encoding / prefill). + +Single-chunk kernel: processes one chunk (128 tokens) with Neumann-series +power-doubling for intra-chunk correction. The caller loops over chunks +in PyTorch, passing state between calls. + +Each kernel call: + - Takes one chunk of data: q, k, v, beta, g_cumsum, g_last (all 128x128) + - Takes recurrent state_in (128x128) + - Returns chunk output (128x128) and state_out (128x128) + +No sequence-indexed DMA inside the kernel -- all inputs/outputs are full tiles. +This avoids the DMA OOB issue seen with nl.sequential_range + slice indexing +in the NxDI model compilation context. + +NKI v3 (SDK 2.29, NKI 0.3.0). Uses nki.* namespace. 
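The Neumann power-doubling used for the intra-chunk correction can be checked at toy sizes: for a strictly lower-triangular A (nilpotent), the left-accumulated product (I+A), (I+A^2)(I+A), (I+A^4)(I+A^2)(I+A), ... converges to (I-A)^{-1} in log2(chunk) rounds, since powers of the same A commute. A pure-Python sketch (illustrative, not module code):

```python
def matmul(a, b):
    n = len(a)
    return [[sum(a[i][p] * b[p][j] for p in range(n)) for j in range(n)]
            for i in range(n)]

def neumann_power_doubling(A, rounds):
    n = len(A)
    I = [[float(i == j) for j in range(n)] for i in range(n)]
    add = lambda x, y: [[x[i][j] + y[i][j] for j in range(n)] for i in range(n)]
    P = add(I, A)                  # (I + A)
    Ap = A
    for _ in range(rounds):
        Ap = matmul(Ap, Ap)        # A^2, A^4, ...
        P = matmul(add(I, Ap), P)  # left-multiply the newest factor
    return P

# Strictly lower-triangular 3x3 A: A^3 = 0, so (I - A)^{-1} = I + A + A^2.
A = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
P = neumann_power_doubling(A, rounds=2)
assert P == [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
```

With 6 rounds the product covers powers up to A^64 (sums of distinct doubled exponents reach A^127), which is sufficient for a 128-token chunk where A^128 = 0.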
+""" + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 + + +@nki.jit +def deltanet_chunk_step( + query, # (128, 128) float32 -- one chunk, l2-normed+scaled + key, # (128, 128) float32 -- one chunk, l2-normed + value, # (128, 128) float32 -- one chunk + beta_broadcast, # (128, 128) float32 -- write gate broadcast to 128 + g_cumsum, # (128, 128) float32 -- cumsum of g within chunk, broadcast + g_last, # (128, 128) float32 -- g_cumsum[-1], constant in chunk, broadcast + state_in, # (128, 128) float32 -- recurrent state from previous chunk + lower_mask, # (128, 128) float32 -- strict lower triangular + identity, # (128, 128) float32 -- identity matrix + lower_mask_diag, # (128, 128) float32 -- lower tri with diagonal +): + """Process one chunk of DeltaNet. + + Returns: + output: (128, 128) float32 -- chunk output + state_out: (128, 128) float32 -- updated recurrent state + """ + C, dim = query.shape # C = 128, dim = 128 + + # Output tensors in HBM + output = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.shared_hbm) + state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # Load all inputs into SBUF + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=q_c, src=query) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=k_c, src=key) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=v_c, src=value) + + beta_c = nl.ndarray((P_MAX, dim), dtype=beta_broadcast.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=beta_c, src=beta_broadcast) + + gc_c = nl.ndarray((P_MAX, dim), dtype=g_cumsum.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gc_c, src=g_cumsum) + + gl_c = nl.ndarray((P_MAX, dim), dtype=g_last.dtype, buffer=nl.sbuf) + nisa.dma_copy(dst=gl_c, src=g_last) + + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=state, src=state_in) + + # Load masks + eye = nl.ndarray((P_MAX, 
P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_beta, data1=k_c, data2=beta_c, op=nl.multiply) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_beta, data1=v_c, data2=beta_c, op=nl.multiply) + + # ============================================================ + # exp(g_cumsum) and exp(-g_cumsum) + # ============================================================ + exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_gc, op=nl.exp, data=gc_c, bias=None, scale=1.0) + + neg_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_gc, + data=gc_c, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + exp_neg_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_neg_gc, op=nl.exp, data=neg_gc, bias=None, scale=1.0) + + # exp(g_last) for state decay + exp_gl = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation(dst=exp_gl, op=nl.exp, data=gl_c, bias=None, scale=1.0) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # QK = k_beta @ k^T -- contract over features + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kb_T_psum, stationary=k_beta, moving=eye) + kb_T = 
nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=k_T_psum, stationary=k_c, moving=eye) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # ============================================================ + # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) + # ============================================================ + QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_row, data1=QK, data2=exp_gc, op=nl.multiply) + + QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_r_T_psum, stationary=QK_row, moving=eye) + QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) + + QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_r_T_col, data1=QK_r_T, data2=exp_neg_gc, op=nl.multiply) + + QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_d_psum, stationary=QK_r_T_col, moving=eye) + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A, data1=neg_QK_decay, data2=Lmask, 
+                       op=nl.multiply)
+
+    # ============================================================
+    # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64})
+    # ============================================================
+    P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A, op=nl.add)
+
+    A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=A_pow, src=A)
+
+    for _round in nl.sequential_range(6):
+        Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=Ap_T_psum, stationary=A_pow, moving=eye)
+        Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum)
+
+        Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow)
+        nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum)
+
+        IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add)
+
+        IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=IpA_T_psum, stationary=IpA, moving=eye)
+        IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum)
+
+        Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc)
+        nisa.tensor_copy(dst=P_acc, src=Pacc_psum)
+
+    # ============================================================
+    # Apply N: value_corr = N @ v_beta, k_cumdecay = N @ (k_beta * exp_gc)
+    # ============================================================
+    N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=N_T_psum, stationary=P_acc, moving=eye)
+    N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=N_T, src=N_T_psum)
+
+    vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta)
+    value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=value_corr, src=vc_psum)
+
+    kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=kb_exp_gc, data1=k_beta, data2=exp_gc, op=nl.multiply)
+
+    kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc)
+    k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum)
+
+    # ============================================================
+    # Phase 2: Inter-chunk state propagation
+    # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag
+    # ============================================================
+    q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=q_T_psum, stationary=q_c, moving=eye)
+    q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=q_T, src=q_T_psum)
+
+    qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T)
+    qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=qk_raw, src=qk_psum)
+
+    qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=qk_row, data1=qk_raw, data2=exp_gc, op=nl.multiply)
+
+    qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=qk_r_T_psum, stationary=qk_row, moving=eye)
+    qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum)
+
+    qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=qk_r_T_col, data1=qk_r_T, data2=exp_neg_gc, op=nl.multiply)
+
+    qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=qk_d_psum, stationary=qk_r_T_col, moving=eye)
+    qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=qk_decay, src=qk_d_psum)
+
+    attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply)
+
+    # ============================================================
+    # v_prime = k_cumdecay @ state
+    # ============================================================
+    kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=kcd_T_psum, stationary=k_cumdecay, moving=eye)
+    kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum)
+
+    vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state)
+    v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=v_prime, src=vp_psum)
+
+    v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract)
+
+    # ============================================================
+    # attn_inter = (q * exp(g_cumsum)) @ state
+    # ============================================================
+    q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=q_exp, data1=q_c, data2=exp_gc, op=nl.multiply)
+
+    qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=qe_T_psum, stationary=q_exp, moving=eye)
+    qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=qe_T, src=qe_T_psum)
+
+    ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state)
+    attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=attn_inter, src=ai_psum)
+
+    # ============================================================
+    # attn_intra @ v_new
+    # ============================================================
+    ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=ai_T_psum, stationary=attn_intra, moving=eye)
+    ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=ai_T, src=ai_T_psum)
+
+    intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new)
+    intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=intra_out, src=intra_psum)
+
+    # ============================================================
+    # chunk_output = attn_inter + intra_out
+    # ============================================================
+    chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add)
+
+    nisa.dma_copy(dst=output, src=chunk_out)
+
+    # ============================================================
+    # State update: state_new = exp(g_last) * (state + k_raw_decay^T @ v_new)
+    # ============================================================
+    k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=k_raw_decay, data1=k_c, data2=exp_neg_gc, op=nl.multiply)
+
+    kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+    nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new)
+    kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_copy(dst=kv_outer, src=kv_psum)
+
+    state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add)
+
+    state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.tensor_tensor(dst=state_new, data1=state_plus, data2=exp_gl, op=nl.multiply)
+
+    nisa.dma_copy(dst=state_out, src=state_new)
+
+    return output, state_out
diff --git a/contrib/models/Qwen3.5-27B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.5-27B/src/nki_kernels/nki_deltanet_fused.py
new file mode 100644
index 00000000..7582356d
--- /dev/null
+++ b/contrib/models/Qwen3.5-27B/src/nki_kernels/nki_deltanet_fused.py
@@ -0,0 +1,574 @@
+"""Fused single-kernel DeltaNet chunked forward for CTE (context encoding).
+
+SSD-style architecture: processes ALL chunks for one (batch, head) pair in
+a single NKI kernel call. State (128x128) persists in SBUF across chunks —
+no HBM round-trips for inter-chunk state propagation.
+
+Key optimizations over nki_deltanet_chunked.py:
+  1. Single kernel call per (B,H) instead of B*H*num_chunks calls
+  2. State in SBUF across all chunks (no HBM state read/write per chunk)
+  3. In-kernel cumsum via tensor_tensor_scan (no PyTorch cumsum)
+  4. Masks and constants loaded once, reused across chunks
+  5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops)
+  6. nc_transpose (Vector Engine) for all 128x128 transposes instead of
+     nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math
+
+NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly.
+Chunk size = 128 = P_MAX (one tile per chunk).
+
+Mathematical framework (same as nki_deltanet_chunked.py):
+  Per-chunk Neumann-series power-doubling for intra-chunk correction:
+    A = -QK_decay * lower_mask
+    N = (I+A)(I+A^2)(I+A^4)...(I+A^64)   [6 rounds]
+    value_corr = N @ v_beta
+    k_cumdecay = N @ (k_beta * exp(gc))
+
+  Inter-chunk state propagation:
+    v_prime = k_cumdecay @ state
+    v_new = value_corr - v_prime
+    attn_inter = (q * exp(gc)) @ state
+    attn_intra = (q @ k^T) * decay_mask * lower_mask_diag
+    output = attn_inter + attn_intra @ v_new
+    state = exp(g_last) * (state + k_raw_decay^T @ v_new)
+"""
+
+import numpy as np
+
+import nki
+import nki.isa as nisa
+import nki.language as nl
+
+P_MAX = 128  # Partition dim = chunk_size = k_dim = v_dim
+CHUNK_SIZE = 128
+
+# Broadcast partition 0 to all partitions in a 32-wide group
+_BROADCAST_MASK = [0] * 32
+
+
+def _make_lower_mask():
+    """Strict lower triangular (128x128) as numpy constant."""
+    return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=-1)
+
+
+def _make_lower_mask_diag():
+    """Lower triangular with diagonal (128x128) as numpy constant."""
+    return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=0)
+
+
+def _make_identity():
+    """Identity matrix (128x128) as numpy constant."""
+    return np.eye(CHUNK_SIZE, dtype=np.float32)
+
+
+@nki.jit
+def deltanet_fused_chunked_fwd(
+    query: nl.ndarray,  # (S, 128) float32 — l2-normed and scaled
+    key: nl.ndarray,  # (S, 128) float32 — l2-normed
+    value: nl.ndarray,  # (S, 128) float32
+    g_in: nl.ndarray,  # (S, 1) float32 — per-token log-decay (NOT cumsum)
+    beta_in: nl.ndarray,  # (S, 1) float32 — per-token write gate
+    lower_mask: nl.ndarray,  # (128, 128) float32 — strict lower tri
+    identity: nl.ndarray,  # (128, 128) float32 — identity
+    lower_mask_diag: nl.ndarray,  # (128, 128) float32 — lower tri with diag
+):
+    """Fused chunked DeltaNet forward — single kernel call per (batch, head).
+
+    Processes all chunks sequentially within the kernel, keeping the recurrent
+    state (128x128) in SBUF across chunks. Returns per-token output and
+    final state.
+
+    Input requirements:
+      - S must be divisible by 128 (pad before calling)
+      - query must be l2-normed and scaled by 1/sqrt(k_dim)
+      - key must be l2-normed
+      - g_in is RAW log-decay (cumsum computed in-kernel via tensor_tensor_scan)
+      - beta_in is sigmoid(b) (write gate)
+
+    Returns:
+        output: (S, 128) float32
+        final_state: (128, 128) float32
+    """
+    seq_len = query.shape[0]
+    dim = query.shape[1]  # 128
+    num_chunks = seq_len // CHUNK_SIZE
+
+    # Output tensors in HBM
+    output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm)
+    final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm)
+
+    # ================================================================
+    # Load constant masks into SBUF once (reused across all chunks)
+    # ================================================================
+    eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.dma_copy(dst=eye, src=identity)
+
+    Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.dma_copy(dst=Lmask, src=lower_mask)
+
+    Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag)
+
+    # Ones vector for cumsum scan: (1, CHUNK_SIZE)
+    ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.memset(dst=ones_1xC, value=1.0)
+
+    # Zero initial for cumsum scan
+    zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.memset(dst=zero_11, value=0.0)
+
+    # ================================================================
+    # Initialize recurrent state in SBUF — persists across ALL chunks
+    # ================================================================
+    state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+    nisa.memset(dst=state, value=0.0)
+
+    # ================================================================
+    # Sequential chunk processing
+    # ================================================================
+    for i_chunk in nl.sequential_range(num_chunks):
+        chunk_start = i_chunk * CHUNK_SIZE
+
+        # ---- Load chunk data from HBM ----
+        q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf)
+        nisa.dma_copy(
+            dst=q_c,
+            src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim],
+        )
+
+        k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf)
+        nisa.dma_copy(
+            dst=k_c,
+            src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim],
+        )
+
+        v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf)
+        nisa.dma_copy(
+            dst=v_c,
+            src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim],
+        )
+
+        # g: (CHUNK_SIZE, 1) — raw log-decay per token
+        g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.dma_copy(
+            dst=g_chunk_p[0:CHUNK_SIZE, 0:1],
+            src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1],
+        )
+
+        # beta: (CHUNK_SIZE, 1) — write gate scalar per token
+        beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.dma_copy(
+            dst=beta_p[0:CHUNK_SIZE, 0:1],
+            src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1],
+        )
+
+        # ---- In-kernel cumsum of g via tensor_tensor_scan ----
+        # Need g as (1, CHUNK_SIZE) for scan along free dim.
+        # Transpose: (CHUNK_SIZE, 1) -> (1, CHUNK_SIZE) via nc_transpose
+        g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.memset(dst=g_padded, value=0.0)
+        nisa.tensor_copy(
+            dst=g_padded[0:CHUNK_SIZE, 0:1],
+            src=g_chunk_p[0:CHUNK_SIZE, 0:1],
+        )
+
+        g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=g_tp_psum, data=g_padded)
+
+        g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(
+            dst=g_row[0:1, 0:CHUNK_SIZE],
+            src=g_tp_psum[0:1, 0:CHUNK_SIZE],
+        )
+
+        # cumsum: gc_row[t] = 1.0 * gc_row[t-1] + g_row[t]
+        gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_tensor_scan(
+            dst=gc_row[0:1, 0:CHUNK_SIZE],
+            data0=ones_1xC[0:1, 0:CHUNK_SIZE],
+            data1=g_row[0:1, 0:CHUNK_SIZE],
+            initial=zero_11[0:1, 0:1],
+            op0=nl.multiply,
+            op1=nl.add,
+        )
+
+        # Transpose gc back to (CHUNK_SIZE, 1) partition layout
+        gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.memset(dst=gc_padded, value=0.0)
+        nisa.tensor_copy(
+            dst=gc_padded[0:1, 0:CHUNK_SIZE],
+            src=gc_row[0:1, 0:CHUNK_SIZE],
+        )
+
+        gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded)
+
+        # gc_p: (P_MAX, 1) — cumulative sum of g per token in this chunk
+        gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(
+            dst=gc_p[0:CHUNK_SIZE, 0:1],
+            src=gc_tp_psum[0:CHUNK_SIZE, 0:1],
+        )
+
+        # g_last = gc[-1] (scalar) — needed for state decay
+        gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(
+            dst=gl_11[0:1, 0:1],
+            src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE],
+        )
+
+        # ---- Compute exp(gc), exp(-gc), exp(g_last) as (P_MAX, 1) scalars ----
+        # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast
+        # across the free dimension without explicit (P_MAX, dim) copies.
+
+        exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.activation(
+            dst=exp_gc_p[0:P_MAX, 0:1],
+            op=nl.exp,
+            data=gc_p[0:P_MAX, 0:1],
+            bias=None,
+            scale=1.0,
+        )
+
+        neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=neg_gc_p,
+            data=gc_p,
+            op0=nl.multiply,
+            operand0=-1.0,
+            engine=nisa.vector_engine,
+        )
+        exp_neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.activation(
+            dst=exp_neg_gc_p[0:P_MAX, 0:1],
+            op=nl.exp,
+            data=neg_gc_p[0:P_MAX, 0:1],
+            bias=None,
+            scale=1.0,
+        )
+
+        # exp(g_last): scalar, then broadcast to (P_MAX, 1)
+        exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.activation(
+            dst=exp_gl_11,
+            op=nl.exp,
+            data=gl_11,
+            bias=None,
+            scale=1.0,
+        )
+
+        exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
+        for i_shuf in nl.static_range(P_MAX // 32):
+            nisa.nc_stream_shuffle(
+                src=exp_gl_11[0:1, 0:1],
+                dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1],
+                shuffle_mask=_BROADCAST_MASK,
+            )
+
+        # ============================================================
+        # k_beta = K * beta, v_beta = V * beta
+        # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim
+        # ============================================================
+        k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=k_beta,
+            data=k_c,
+            op0=nl.multiply,
+            operand0=beta_p,
+            engine=nisa.vector_engine,
+        )
+
+        v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=v_beta,
+            data=v_c,
+            op0=nl.multiply,
+            operand0=beta_p,
+            engine=nisa.vector_engine,
+        )
+
+        # ============================================================
+        # Phase 1: Build A matrix (intra-chunk correction)
+        # Transpose K and K_beta for matmul
+        # ============================================================
+        kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=kb_T_psum, data=k_beta)
+        kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=kb_T, src=kb_T_psum)
+
+        k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=k_T_psum, data=k_c)
+        k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=k_T, src=k_T_psum)
+
+        # QK = k_beta^T @ k (contract over features)
+        QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T)
+        QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=QK, src=QK_psum)
+
+        # ============================================================
+        # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j])
+        #
+        # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i])
+        # Then transpose, column scale, transpose back.
+        # Uses tensor_scalar with (P_MAX,1) operand for row scaling.
+        # ============================================================
+        QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=QK_row,
+            data=QK,
+            op0=nl.multiply,
+            operand0=exp_gc_p,
+            engine=nisa.vector_engine,
+        )
+
+        # Transpose to scale columns (now rows in transposed view)
+        QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=QK_r_T_psum, data=QK_row)
+        QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum)
+
+        QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=QK_r_T_col,
+            data=QK_r_T,
+            op0=nl.multiply,
+            operand0=exp_neg_gc_p,
+            engine=nisa.vector_engine,
+        )
+
+        # Transpose back
+        QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=QK_d_psum, data=QK_r_T_col)
+        QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=QK_decay, src=QK_d_psum)
+
+        # A = -QK_decay * lower_mask
+        neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=neg_QK_decay,
+            data=QK_decay,
+            op0=nl.multiply,
+            operand0=-1.0,
+            engine=nisa.vector_engine,
+        )
+        A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply)
+
+        # ============================================================
+        # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64})
+        # 6 rounds → the product expands to I + A + ... + A^127, the exact
+        # inverse of (I - A): A is strictly lower triangular, so A^128 = 0
+        # for chunk=128
+        # ============================================================
+        P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A_mat, op=nl.add)
+
+        A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=A_pow, src=A_mat)
+
+        for _round in nl.sequential_range(6):
+            # A_pow = A_pow^2: transpose A_pow, then matmul
+            Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+            nisa.nc_transpose(dst=Ap_T_psum, data=A_pow)
+            Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+            nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum)
+
+            Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+            nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow)
+            nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum)
+
+            # P_acc = (I + A_pow) @ P_acc: transpose IpA, then matmul
+            IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+            nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add)
+
+            IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+            nisa.nc_transpose(dst=IpA_T_psum, data=IpA)
+            IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+            nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum)
+
+            Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+            nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc)
+            nisa.tensor_copy(dst=P_acc, src=Pacc_psum)
+
+        # ============================================================
+        # Apply N: value_corr = N @ v_beta
+        #          k_cumdecay = N @ (k_beta * exp(gc))
+        # ============================================================
+        N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=N_T_psum, data=P_acc)
+        N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=N_T, src=N_T_psum)
+
+        vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta)
+        value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=value_corr, src=vc_psum)
+
+        # k_beta * exp(gc): row-scaled
+        kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=kb_exp_gc,
+            data=k_beta,
+            op0=nl.multiply,
+            operand0=exp_gc_p,
+            engine=nisa.vector_engine,
+        )
+
+        kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc)
+        k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum)
+
+        # ============================================================
+        # Phase 2: Inter-chunk state propagation
+        # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag
+        # ============================================================
+        q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=q_T_psum, data=q_c)
+        q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=q_T, src=q_T_psum)
+
+        qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T)
+        qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=qk_raw, src=qk_psum)
+
+        # Row-scale by exp(gc)
+        qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=qk_row,
+            data=qk_raw,
+            op0=nl.multiply,
+            operand0=exp_gc_p,
+            engine=nisa.vector_engine,
+        )
+
+        # Transpose, column-scale by exp(-gc), transpose back
+        qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=qk_r_T_psum, data=qk_row)
+        qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum)
+
+        qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=qk_r_T_col,
+            data=qk_r_T,
+            op0=nl.multiply,
+            operand0=exp_neg_gc_p,
+            engine=nisa.vector_engine,
+        )
+
+        qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=qk_d_psum, data=qk_r_T_col)
+        qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=qk_decay, src=qk_d_psum)
+
+        attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_tensor(
+            dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply
+        )
+
+        # ============================================================
+        # v_prime = k_cumdecay @ state   (state is in SBUF!)
+        # ============================================================
+        kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=kcd_T_psum, data=k_cumdecay)
+        kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum)
+
+        vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state)
+        v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=v_prime, src=vp_psum)
+
+        v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract)
+
+        # ============================================================
+        # attn_inter = (q * exp(gc)) @ state   (state is in SBUF!)
+        # ============================================================
+        q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=q_exp,
+            data=q_c,
+            op0=nl.multiply,
+            operand0=exp_gc_p,
+            engine=nisa.vector_engine,
+        )
+
+        qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=qe_T_psum, data=q_exp)
+        qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=qe_T, src=qe_T_psum)
+
+        ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state)
+        attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=attn_inter, src=ai_psum)
+
+        # ============================================================
+        # attn_intra @ v_new
+        # ============================================================
+        ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_transpose(dst=ai_T_psum, data=attn_intra)
+        ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=ai_T, src=ai_T_psum)
+
+        intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new)
+        intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=intra_out, src=intra_psum)
+
+        # ============================================================
+        # chunk_output = attn_inter + intra_out
+        # ============================================================
+        chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add)
+
+        # Store output chunk to HBM
+        nisa.dma_copy(
+            dst=output[chunk_start : chunk_start + CHUNK_SIZE, 0:dim],
+            src=chunk_out,
+        )
+
+        # ============================================================
+        # State update: state = exp(g_last) * (state + k_raw_decay^T @ v_new)
+        # state is updated IN-PLACE in SBUF — no HBM round-trip!
+        # ============================================================
+
+        # k_raw_decay = k * exp(-gc)
+        k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_scalar(
+            dst=k_raw_decay,
+            data=k_c,
+            op0=nl.multiply,
+            operand0=exp_neg_gc_p,
+            engine=nisa.vector_engine,
+        )
+
+        # k_raw_decay^T @ v_new → (dim, dim) outer product sum
+        # nc_matmul: result[M,N] = sum_K stationary[K,M] * moving[K,N]
+        # stationary=k_raw_decay (P_MAX, dim), moving=v_new (P_MAX, dim)
+        # Result: sum over tokens -> (dim, dim)
+        kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
+        nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new)
+        kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_copy(dst=kv_outer, src=kv_psum)
+
+        # state = state + kv_outer
+        state_plus = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.tensor_tensor(dst=state_plus, data1=state, data2=kv_outer, op=nl.add)
+
+        # state = state_plus * exp(g_last)
+        # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim
+        nisa.tensor_scalar(
+            dst=state,
+            data=state_plus,
+            op0=nl.multiply,
+            operand0=exp_gl_p,
+            engine=nisa.vector_engine,
+        )
+
+    # ---- Write final state to HBM ----
+    nisa.dma_copy(dst=final_state_out, src=state)
+
+    return output, final_state_out
diff --git a/contrib/models/Qwen3.5-27B/test/__init__.py b/contrib/models/Qwen3.5-27B/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrib/models/Qwen3.5-27B/test/integration/__init__.py b/contrib/models/Qwen3.5-27B/test/integration/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrib/models/Qwen3.5-27B/test/integration/test_model.py b/contrib/models/Qwen3.5-27B/test/integration/test_model.py
new file mode 100644
index 00000000..b4273f87
--- /dev/null
+++ b/contrib/models/Qwen3.5-27B/test/integration/test_model.py
@@ -0,0 +1,469 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Integration tests for Qwen3.5-27B on Neuron.
+
+Tests compilation, loading, inference accuracy, and performance using
+the full 27B model with pre-downloaded HuggingFace weights on a trn2 instance.
+
+Note: A mini model option is not provided because DeltaNet layers require NKI
+kernels that only execute on Neuron devices, and the hybrid DeltaNet + GQA
+architecture needs at least TP=4 for the full model to fit in HBM.
+
+Environment variables:
+    QWEN35_MODEL_PATH      Path to HF model weights (required)
+    QWEN35_COMPILED_PATH   Path to compiled artifacts (default: /tmp/qwen35_27b_traced)
+    QWEN35_TP_DEGREE       Tensor parallelism degree (default: 4)
+    QWEN35_SEQ_LEN         Max sequence length (default: 128)
+    TTFT_THRESHOLD_MS      Max TTFT in ms (default: 5000)
+    THROUGHPUT_THRESHOLD   Min throughput in tok/s (default: 5.0)
+
+Prerequisites:
+    - trn2.3xlarge or larger with TP >= 4 NeuronCores available
+    - NXDI installed (neuronx_distributed_inference)
+    - HuggingFace weights downloaded to QWEN35_MODEL_PATH
+    - SDK 2.29+ (NKI 0.3.0 required for DeltaNet kernels)
+
+Usage:
+    # Full model (trn2.3xlarge, TP=4):
+    QWEN35_MODEL_PATH=/mnt/models/Qwen3.5-27B \\
+    QWEN35_COMPILED_PATH=/mnt/models/qwen35_traced \\
+    pytest test/integration/test_model.py --capture=tee-sys
+"""
+
+import gc
+import os
+import sys
+import time
+
+import pytest
+import torch
+
+# Ensure the contrib root (Qwen3.5-27B/) is on sys.path
+_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if _CONTRIB_ROOT not in sys.path:
+    sys.path.insert(0, _CONTRIB_ROOT)
+
+# ── Configuration from environment ──────────────────────────────────────
+
+MODEL_PATH = os.environ.get("QWEN35_MODEL_PATH", "")
+COMPILED_PATH = os.environ.get("QWEN35_COMPILED_PATH", "/tmp/qwen35_27b_traced")
+TP_DEGREE = int(os.environ.get("QWEN35_TP_DEGREE", "4"))
+SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128"))
+TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000"))
+THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0"))
+
+requires_model_path = pytest.mark.skipif(
+    not MODEL_PATH,
+    reason=(
+        "QWEN35_MODEL_PATH not set. Integration tests require the full 27B model "
+        "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.5-27B to run these tests."
+    ),
+)
+
+
+# ── Fixtures ────────────────────────────────────────────────────────────
+
+
+@pytest.fixture(scope="module")
+def model_path():
+    """Return path to model weights."""
+    return MODEL_PATH
+
+
+@pytest.fixture(scope="module")
+def compiled_model(model_path):
+    """Compile and load the model on Neuron."""
+    import json
+
+    from neuronx_distributed_inference.models.config import (
+        NeuronConfig,
+        OnDeviceSamplingConfig,
+    )
+    from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM
+
+    neuron_config = NeuronConfig(
+        tp_degree=TP_DEGREE,
+        batch_size=1,
+        ctx_batch_size=1,
+        tkg_batch_size=1,
+        seq_len=SEQ_LEN,
+        torch_dtype=torch.bfloat16,
+        on_device_sampling_config=OnDeviceSamplingConfig(top_k=1),
+        enable_bucketing=False,
+        flash_decoding_enabled=False,
+        logical_nc_config=2,
+        save_sharded_checkpoint=True,
+    )
+
+    # Read config.json directly (model_type 'qwen3_5' may not be in
+    # AutoConfig registry for all transformers versions)
+    with open(os.path.join(model_path, "config.json")) as f:
+        full_config = json.load(f)
+    text_config = full_config.get("text_config", full_config)
+
+    config_dict = dict(text_config)
+    config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044)
+    if "rope_parameters" in text_config:
+        config_dict["rope_theta"] = text_config["rope_parameters"].get(
+            "rope_theta", 10000000
+        )
+    config_dict.setdefault("tie_word_embeddings", False)
+
+    inf_config = Qwen35InferenceConfig(
+        neuron_config=neuron_config,
+        **config_dict,
+    )
+
+    # Compile if no existing artifacts
+    compiled_path = COMPILED_PATH
+    neff_path = os.path.join(compiled_path, "model.pt")
+    if not os.path.exists(neff_path):
+        print(f"Compiling to {compiled_path}...")
+        model = NeuronQwen35ForCausalLM(model_path, inf_config)
+        model.compile(compiled_path)
+        del model
+        gc.collect()
+
+    # Load
+    print(f"Loading from {compiled_path}...")
+    model = NeuronQwen35ForCausalLM(compiled_path)
+    model.load(compiled_path)
+    return model
+@pytest.fixture(scope="module")
+def tokenizer(model_path):
+    """Load tokenizer."""
+    from transformers import AutoTokenizer
+
+    tok = AutoTokenizer.from_pretrained(model_path, padding_side="right")
+    if tok.pad_token is None:
+        tok.pad_token = tok.eos_token
+    return tok
+
+
+@pytest.fixture(scope="module")
+def generation_config(tokenizer):
+    """Create generation config."""
+    from transformers import GenerationConfig
+
+    return GenerationConfig(
+        do_sample=True,
+        top_k=1,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+
+
+def _generate(model, tokenizer, generation_config, prompt, max_new_tokens=20):
+    """Generate text using the NxDI model."""
+    from neuronx_distributed_inference.utils.hf_adapter import (
+        HuggingFaceGenerationAdapter,
+    )
+
+    inputs = tokenizer(prompt, padding=True, return_tensors="pt")
+    gen_model = HuggingFaceGenerationAdapter(model)
+    outputs = gen_model.generate(
+        inputs.input_ids,
+        generation_config=generation_config,
+        attention_mask=inputs.attention_mask,
+        max_new_tokens=max_new_tokens,
+    )
+    return outputs[0].tolist(), tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+def _is_repetitive(text, max_repeat=5):
+    """Check for excessive word repetition."""
+    words = text.split()
+    if len(words) < max_repeat:
+        return False
+    for i in range(len(words) - max_repeat + 1):
+        if len(set(words[i : i + max_repeat])) == 1:
+            return True
+    return False
+
+
+# ── Smoke Tests ─────────────────────────────────────────────────────────
+
+
+@requires_model_path
+def test_model_loads(compiled_model):
+    """Model compiles and loads successfully."""
+    assert compiled_model is not None
+    assert hasattr(compiled_model, "neuron_config")
+    print(" Model loaded successfully")
+
+
+@requires_model_path
+def test_model_generates(compiled_model, tokenizer, generation_config):
+    """Model generates at least 5 tokens."""
+    tokens, text = _generate(
+        compiled_model,
+        tokenizer,
+        generation_config,
+        "Hello, I am a language model",
+        max_new_tokens=20,
+    )
+    input_len = len(tokenizer.encode("Hello, I am a language model"))
+    new_tokens = len(tokens) - input_len
+    assert new_tokens >= 5, f"Expected >= 5 new tokens, got {new_tokens}"
+    print(f" Generated {new_tokens} tokens: {text[:100]}...")
+
+
+# ── Accuracy Tests ──────────────────────────────────────────────────────
+
+
+@requires_model_path
+def test_output_coherence(compiled_model, tokenizer, generation_config):
+    """Output should contain multiple words and not be excessively repetitive."""
+    _, text = _generate(
+        compiled_model,
+        tokenizer,
+        generation_config,
+        "The capital of France is",
+        max_new_tokens=30,
+    )
+    generated = text[len("The capital of France is") :].strip()
+    words = generated.split()
+    assert len(words) >= 3, f"Expected >= 3 words, got {len(words)}: '{generated}'"
+    assert not _is_repetitive(generated), (
+        f"Output is excessively repetitive: '{generated}'"
+    )
+    print(f" Output coherent ({len(words)} words): {generated[:80]}...")
+
+
+@requires_model_path
+def test_top_token_valid(compiled_model, tokenizer, generation_config):
+    """First generated token should be a valid decodable token."""
+    tokens, _ = _generate(
+        compiled_model,
+        tokenizer,
+        generation_config,
+        "Hello!",
+        max_new_tokens=1,
+    )
+    input_len = len(tokenizer.encode("Hello!"))
+    first_new = tokens[input_len]
+    assert 0 <= first_new < tokenizer.vocab_size, (
+        f"Token {first_new} out of vocab range"
+    )
+    decoded = tokenizer.decode([first_new])
+    assert len(decoded) > 0, f"Token {first_new} decoded to empty string"
+    print(f" First token: {first_new} -> '{decoded}'")
+
+
+@requires_model_path
+def test_capital_of_france(compiled_model, tokenizer, generation_config):
+    """'The capital of France is' should produce 'Paris' within the first few tokens."""
+    tokens, text = _generate(
+        compiled_model,
+        tokenizer,
+        generation_config,
+        "The capital of France is",
+        max_new_tokens=5,
+    )
+    generated = text[len("The capital of France is") :].strip()
+    assert "paris" in generated.lower(), (
+        f"Expected 'Paris' in output, got: '{generated}'"
+    )
+    print(f" Capital of France: {generated}")
+
+
+# ── Performance Tests ───────────────────────────────────────────────────
+
+
+@requires_model_path
+def test_performance_ttft(compiled_model, tokenizer, generation_config):
+    """Time to first token should be within threshold."""
+    prompt = "Hello, I am a language model"
+
+    # Warmup
+    _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1)
+
+    # Measure
+    times = []
+    for _ in range(3):
+        t0 = time.perf_counter()
+        _generate(
+            compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1
+        )
+        times.append((time.perf_counter() - t0) * 1000)
+
+    avg_ms = sum(times) / len(times)
+    print(f" TTFT: {avg_ms:.1f} ms (threshold: {TTFT_THRESHOLD_MS} ms)")
+    assert avg_ms < TTFT_THRESHOLD_MS, (
+        f"TTFT {avg_ms:.1f}ms > threshold {TTFT_THRESHOLD_MS}ms"
+    )
+
+
+@requires_model_path
+def test_performance_throughput(compiled_model, tokenizer, generation_config):
+    """Throughput should meet minimum threshold."""
+    prompt = "Once upon a time"
+    num_new_tokens = 20
+
+    # Warmup
+    _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=5)
+
+    # Measure
+    t0 = time.perf_counter()
+    tokens, _ = _generate(
+        compiled_model,
+        tokenizer,
+        generation_config,
+        prompt,
+        max_new_tokens=num_new_tokens,
+    )
+    elapsed = time.perf_counter() - t0
+
+    input_len = len(tokenizer.encode(prompt))
+    actual_new = len(tokens) - input_len
+    throughput = actual_new / elapsed if elapsed > 0 else 0
+
+    print(
+        f" Throughput: {throughput:.1f} tok/s ({actual_new} tokens in {elapsed:.2f}s)"
+    )
+    print(f" Threshold: {THROUGHPUT_THRESHOLD} tok/s")
+    assert throughput > THROUGHPUT_THRESHOLD, (
+        f"Throughput {throughput:.1f} tok/s < threshold {THROUGHPUT_THRESHOLD}"
+    )
+
+
+# ── Multi-Prompt Quality Test ──────────────────────────────────────────
+
+
+@requires_model_path
+def test_multi_prompt_generation(compiled_model, tokenizer, generation_config):
+    """Multiple prompts should produce coherent outputs."""
+    prompts = [
+        "The capital of France is",
+        "def fibonacci(n):",
+        "The largest ocean on Earth is",
+        "To make a chocolate cake, you need",
+    ]
+
+    for prompt in prompts:
+        _, text = _generate(
+            compiled_model,
+            tokenizer,
+            generation_config,
+            prompt,
+            max_new_tokens=30,
+        )
+        generated = text[len(prompt) :].strip()
+        words = generated.split()
+        assert len(words) >= 2, (
+            f"Prompt '{prompt}' generated too few words: '{generated}'"
+        )
+        assert not _is_repetitive(generated), (
+            f"Prompt '{prompt}' produced repetitive output: '{generated}'"
+        )
+        print(f" '{prompt[:30]}...' -> {generated[:60]}...")
+
+
+# ── Standalone runner ───────────────────────────────────────────────────
+
+if __name__ == "__main__":
+    print("=" * 60)
+    print("Qwen3.5-27B Integration Tests")
+    print("=" * 60)
+
+    if not MODEL_PATH:
+        print("\nQWEN35_MODEL_PATH not set. Provide the model path to run tests:")
+        print(" QWEN35_MODEL_PATH=/path/to/Qwen3.5-27B \\")
+        print(" QWEN35_COMPILED_PATH=/mnt/models/qwen35_traced \\")
+        print(" python -m pytest test/integration/test_model.py --capture=tee-sys")
+        sys.exit(0)
+
+    # Setup
+    from transformers import AutoTokenizer, GenerationConfig as GenConfig
+
+    tok = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right")
+    if tok.pad_token is None:
+        tok.pad_token = tok.eos_token
+    gen_cfg = GenConfig(
+        do_sample=True,
+        top_k=1,
+        pad_token_id=tok.pad_token_id,
+        eos_token_id=tok.eos_token_id,
+    )
+
+    # Build model
+    import json
+
+    from neuronx_distributed_inference.models.config import (
+        NeuronConfig,
+        OnDeviceSamplingConfig,
+    )
+    from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM
+
+    nc = NeuronConfig(
+        tp_degree=TP_DEGREE,
+        batch_size=1,
+        ctx_batch_size=1,
+        tkg_batch_size=1,
+        seq_len=SEQ_LEN,
+        torch_dtype=torch.bfloat16,
+        on_device_sampling_config=OnDeviceSamplingConfig(top_k=1),
+        enable_bucketing=False,
+        flash_decoding_enabled=False,
+        logical_nc_config=2,
+        save_sharded_checkpoint=True,
+    )
+
+    with open(os.path.join(MODEL_PATH, "config.json")) as f:
+        full_config = json.load(f)
+    text_config = full_config.get("text_config", full_config)
+    config_dict = dict(text_config)
+    config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044)
+    if "rope_parameters" in text_config:
+        config_dict["rope_theta"] = text_config["rope_parameters"].get(
+            "rope_theta", 10000000
+        )
+    config_dict.setdefault("tie_word_embeddings", False)
+    ic = Qwen35InferenceConfig(neuron_config=nc, **config_dict)
+
+    cp = COMPILED_PATH
+    if not os.path.exists(os.path.join(cp, "model.pt")):
+        print(f"Compiling to {cp}...")
+        m = NeuronQwen35ForCausalLM(MODEL_PATH, ic)
+        m.compile(cp)
+        del m
+        gc.collect()
+
+    print(f"Loading from {cp}...")
+    model = NeuronQwen35ForCausalLM(cp)
+    model.load(cp)
+
+    tests = [
+        ("model_loads", lambda: test_model_loads(model)),
+        ("model_generates", lambda: test_model_generates(model, tok, gen_cfg)),
+        ("output_coherence", lambda: test_output_coherence(model, tok, gen_cfg)),
+        ("top_token_valid", lambda: test_top_token_valid(model, tok, gen_cfg)),
+        ("capital_of_france", lambda: test_capital_of_france(model, tok, gen_cfg)),
+        ("performance_ttft", lambda: test_performance_ttft(model, tok, gen_cfg)),
+        (
+            "performance_throughput",
+            lambda: test_performance_throughput(model, tok, gen_cfg),
+        ),
+        (
+            "multi_prompt_generation",
+            lambda: test_multi_prompt_generation(model, tok, gen_cfg),
+        ),
+    ]
+
+    passed = 0
+    for name, fn in tests:
+        print(f"\n--- {name} ---")
+        try:
+            fn()
+            print(" PASS")
+            passed += 1
+        except Exception as e:
+            print(f" FAIL: {e}")
+
+    print(f"\n{'=' * 60}")
+    print(f"Results: {passed}/{len(tests)} passed")
+    print(f"{'=' * 60}")
diff --git a/contrib/models/Qwen3.5-27B/test/unit/__init__.py b/contrib/models/Qwen3.5-27B/test/unit/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/contrib/models/Qwen3.5-27B/test/unit/test_config.py b/contrib/models/Qwen3.5-27B/test/unit/test_config.py
new file mode 100644
index 00000000..44ae4622
--- /dev/null
+++ b/contrib/models/Qwen3.5-27B/test/unit/test_config.py
@@ -0,0 +1,200 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for Qwen3.5-27B inference configuration.
+
+CPU-only tests that validate config parsing, layer type setup,
+DeltaNet parameter defaults, RoPE configuration, and weight conversion logic.
+"""
+
+import os
+import sys
+import unittest
+from unittest.mock import MagicMock
+
+import torch
+
+# Ensure the contrib root (Qwen3.5-27B/) is on sys.path
+_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if _CONTRIB_ROOT not in sys.path:
+    sys.path.insert(0, _CONTRIB_ROOT)
+
+from src.modeling_qwen35 import (
+    Qwen35InferenceConfig,
+    convert_qwen35_hf_to_neuron_state_dict,
+)
+from neuronx_distributed_inference.models.config import NeuronConfig
+
+
+def _make_config(**overrides):
+    """Create a Qwen35InferenceConfig with reasonable defaults."""
+    neuron_config = NeuronConfig(
+        tp_degree=overrides.pop("tp_degree", 4),
+        batch_size=1,
+        seq_len=128,
+        torch_dtype=torch.bfloat16,
+    )
+    defaults = dict(
+        hidden_size=5120,
+        num_hidden_layers=64,
+        num_attention_heads=24,
+        num_key_value_heads=4,
+        head_dim=256,
+        intermediate_size=17408,
+        vocab_size=248320,
+        rms_norm_eps=1e-6,
+        max_position_embeddings=131072,
+        rope_theta=10000,
+        hidden_act="silu",
+        # DeltaNet-specific
+        linear_num_value_heads=48,
+        linear_num_key_heads=16,
+        linear_key_head_dim=128,
+        linear_value_head_dim=128,
+        linear_conv_kernel_dim=4,
+    )
+    defaults.update(overrides)
+    config = Qwen35InferenceConfig(neuron_config=neuron_config, **defaults)
+    return config
+
+
+class TestConfigParsing(unittest.TestCase):
+    """Test basic config attribute initialization."""
+
+    def test_hidden_size(self):
+        config = _make_config()
+        self.assertEqual(config.hidden_size, 5120)
+
+    def test_num_hidden_layers(self):
+        config = _make_config()
+        self.assertEqual(config.num_hidden_layers, 64)
+
+    def test_num_attention_heads(self):
+        config = _make_config()
+        self.assertEqual(config.num_attention_heads, 24)
+
+    def test_num_key_value_heads(self):
+        config = _make_config()
+        self.assertEqual(config.num_key_value_heads, 4)
+
+    def test_head_dim(self):
+        config = _make_config()
+        self.assertEqual(config.head_dim, 256)
+
+    def test_intermediate_size(self):
+        config = _make_config()
+        self.assertEqual(config.intermediate_size, 17408)
+
+    def test_vocab_size(self):
+        config = _make_config()
+        self.assertEqual(config.vocab_size, 248320)
+
+    def test_hidden_act(self):
+        config = _make_config()
+        self.assertEqual(config.hidden_act, "silu")
+
+
+class TestLayerTypes(unittest.TestCase):
+    """Test hybrid layer type assignment (3 DeltaNet + 1 GQA) x 16."""
+
+    def test_layer_types_length(self):
+        config = _make_config()
+        self.assertEqual(len(config.layer_types), 64)
+
+    def test_layer_types_pattern(self):
+        """Every 4th layer (3, 7, 11, ...) should be full_attention."""
+        config = _make_config()
+        for i in range(64):
+            expected = "full_attention" if i % 4 == 3 else "linear_attention"
+            self.assertEqual(config.layer_types[i], expected, f"Layer {i} mismatch")
+
+    def test_deltanet_layer_count(self):
+        config = _make_config()
+        dn_count = sum(1 for t in config.layer_types if t == "linear_attention")
+        self.assertEqual(dn_count, 48)
+
+    def test_gqa_layer_count(self):
+        config = _make_config()
+        gqa_count = sum(1 for t in config.layer_types if t == "full_attention")
+        self.assertEqual(gqa_count, 16)
+
+
+class TestDeltaNetConfig(unittest.TestCase):
+    """Test DeltaNet-specific configuration defaults."""
+
+    def test_linear_num_value_heads(self):
+        config = _make_config()
+        self.assertEqual(config.linear_num_value_heads, 48)
+
+    def test_linear_num_key_heads(self):
+        config = _make_config()
+        self.assertEqual(config.linear_num_key_heads, 16)
+
+    def test_linear_key_head_dim(self):
+        config = _make_config()
+        self.assertEqual(config.linear_key_head_dim, 128)
+
+    def test_linear_value_head_dim(self):
+        config = _make_config()
+        self.assertEqual(config.linear_value_head_dim, 128)
+
+    def test_linear_conv_kernel_dim(self):
+        config = _make_config()
+        self.assertEqual(config.linear_conv_kernel_dim, 4)
+
+
+class TestRoPEConfig(unittest.TestCase):
+    """Test partial RoPE configuration."""
+
+    def test_partial_rotary_factor(self):
+        config = _make_config()
+        self.assertAlmostEqual(config.partial_rotary_factor, 0.25)
+
+    def test_rope_dim(self):
+        """rope_dim = head_dim * partial_rotary_factor = 256 * 0.25 = 64."""
+        config = _make_config()
+        self.assertEqual(config.rope_dim, 64)
+
+    def test_attn_output_gate(self):
+        config = _make_config()
+        self.assertTrue(config.attn_output_gate)
+
+    def test_mrope_section(self):
+        config = _make_config()
+        self.assertEqual(config.mrope_section, [11, 11, 10])
+
+    def test_mrope_interleaved(self):
+        config = _make_config()
+        self.assertTrue(config.mrope_interleaved)
+
+
+class TestNeuronConfig(unittest.TestCase):
+    """Test Neuron-specific configuration settings."""
+
+    def test_neuron_config_cls(self):
+        """Qwen3.5-27B is dense -- uses NeuronConfig, NOT MoENeuronConfig."""
+        self.assertEqual(
+            Qwen35InferenceConfig.get_neuron_config_cls(),
+            NeuronConfig,
+        )
+
+    def test_required_attributes(self):
+        config = _make_config()
+        required = config.get_required_attributes()
+        self.assertIn("hidden_size", required)
+        self.assertIn("num_hidden_layers", required)
+        self.assertIn("linear_num_value_heads", required)
+        self.assertIn("linear_key_head_dim", required)
+        self.assertIn("layer_types", required)
+
+    def test_output_attentions_default(self):
+        config = _make_config()
+        self.assertFalse(config.output_attentions)
+
+    def test_output_hidden_states_default(self):
+        config = _make_config()
+        self.assertFalse(config.output_hidden_states)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/contrib/models/Qwen3.5-27B/test/unit/test_weight_conversion.py b/contrib/models/Qwen3.5-27B/test/unit/test_weight_conversion.py
new file mode 100644
index 00000000..eb8d8045
--- /dev/null
+++ b/contrib/models/Qwen3.5-27B/test/unit/test_weight_conversion.py
@@ -0,0 +1,434 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for Qwen3.5-27B HF-to-NxDI weight conversion.
+
+CPU-only tests that validate:
+- RMSNorm (+1 convention) weight conversion
+- GQA q_proj interleaved split (query + gate)
+- QK norm key renaming (q_norm -> q_layernorm, k_norm -> k_layernorm)
+- Fused QKV concatenation
+- DeltaNet layer weights pass through unchanged
+- VL wrapper prefix stripping
+- rank_util injection
+"""
+
+import os
+import sys
+import unittest
+
+import torch
+
+_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if _CONTRIB_ROOT not in sys.path:
+    sys.path.insert(0, _CONTRIB_ROOT)
+
+from src.modeling_qwen35 import (
+    Qwen35InferenceConfig,
+    NeuronQwen35ForCausalLM,
+    convert_qwen35_hf_to_neuron_state_dict,
+)
+from neuronx_distributed_inference.models.config import NeuronConfig
+
+
+def _make_mini_config(num_layers=4, tp_degree=2, fused_qkv=True):
+    """Create a small Qwen35InferenceConfig for testing."""
+    neuron_config = NeuronConfig(
+        tp_degree=tp_degree,
+        batch_size=1,
+        seq_len=128,
+        torch_dtype=torch.bfloat16,
+        fused_qkv=fused_qkv,
+    )
+    config = Qwen35InferenceConfig(
+        neuron_config=neuron_config,
+        hidden_size=256,
+        num_hidden_layers=num_layers,
+        num_attention_heads=4,
+        num_key_value_heads=2,
+        head_dim=64,
+        intermediate_size=512,
+        vocab_size=1000,
+        rms_norm_eps=1e-6,
+        max_position_embeddings=4096,
+        rope_theta=10000,
+        hidden_act="silu",
+        linear_num_value_heads=8,
+        linear_num_key_heads=4,
+        linear_key_head_dim=32,
+        linear_value_head_dim=32,
+        linear_conv_kernel_dim=4,
+    )
+    return config
+
+
+def _make_mini_state_dict(config):
+    """Create a minimal HF-style state dict for conversion testing."""
+    sd = {}
+    H = config.hidden_size  # 256
+    I = config.intermediate_size  # 512
+    V = config.vocab_size  # 1000
+    num_heads = config.num_attention_heads  # 4
+    num_kv = config.num_key_value_heads  # 2
+    head_dim = config.head_dim  # 64
+
+    sd["embed_tokens.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02
+    sd["lm_head.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02
+    sd["norm.weight"] = torch.zeros(H, dtype=torch.bfloat16)  # +1 convention: zeros
+
+    for l in range(config.num_hidden_layers):
+        sd[f"layers.{l}.input_layernorm.weight"] = torch.zeros(H, dtype=torch.bfloat16)
+        sd[f"layers.{l}.post_attention_layernorm.weight"] = torch.zeros(
+            H, dtype=torch.bfloat16
+        )
+
+        # Dense MLP (all layers)
+        sd[f"layers.{l}.mlp.gate_proj.weight"] = (
+            torch.randn(I, H, dtype=torch.bfloat16) * 0.02
+        )
+        sd[f"layers.{l}.mlp.up_proj.weight"] = (
+            torch.randn(I, H, dtype=torch.bfloat16) * 0.02
+        )
+        sd[f"layers.{l}.mlp.down_proj.weight"] = (
+            torch.randn(H, I, dtype=torch.bfloat16) * 0.02
+        )
+
+        if config.layer_types[l] == "full_attention":
+            # GQA layer: q_proj is interleaved [head0_q | head0_gate | head1_q | ...]
+            q_proj = (
+                torch.randn(num_heads * head_dim * 2, H, dtype=torch.bfloat16) * 0.02
+            )
+            sd[f"layers.{l}.self_attn.q_proj.weight"] = q_proj
+            sd[f"layers.{l}.self_attn.k_proj.weight"] = (
+                torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02
+            )
+            sd[f"layers.{l}.self_attn.v_proj.weight"] = (
+                torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02
+            )
+            sd[f"layers.{l}.self_attn.o_proj.weight"] = (
+                torch.randn(H, num_heads * head_dim, dtype=torch.bfloat16) * 0.02
+            )
+            sd[f"layers.{l}.self_attn.q_norm.weight"] = torch.zeros(
+                head_dim, dtype=torch.bfloat16
+            )
+            sd[f"layers.{l}.self_attn.k_norm.weight"] = torch.zeros(
+                head_dim, dtype=torch.bfloat16
+            )
+        else:
+            # DeltaNet layer: minimal required weights
+            key_dim = config.linear_num_key_heads * config.linear_key_head_dim  # 128
+            value_dim = (
+                config.linear_num_value_heads * config.linear_value_head_dim
+            )  # 256
+            conv_dim = key_dim * 2 + value_dim  # 512
+            sd[f"layers.{l}.linear_attn.in_proj_qkv.weight"] = (
+                torch.randn(conv_dim, H, dtype=torch.bfloat16) * 0.02
+            )
+            sd[f"layers.{l}.linear_attn.in_proj_z.weight"] = (
+                torch.randn(value_dim, H, dtype=torch.bfloat16) * 0.02
+            )
+            sd[f"layers.{l}.linear_attn.in_proj_a.weight"] = (
+                torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16)
+                * 0.02
+            )
+            sd[f"layers.{l}.linear_attn.in_proj_b.weight"] = (
+                torch.randn(config.linear_num_value_heads, H, dtype=torch.bfloat16)
+                * 0.02
+            )
+            sd[f"layers.{l}.linear_attn.conv1d.weight"] = (
+                torch.randn(
+                    conv_dim, 1, config.linear_conv_kernel_dim, dtype=torch.bfloat16
+                )
+                * 0.02
+            )
+            sd[f"layers.{l}.linear_attn.A_log"] = torch.randn(
+                config.linear_num_value_heads, dtype=torch.bfloat16
+            )
+            sd[f"layers.{l}.linear_attn.dt_bias"] = torch.randn(
+                config.linear_num_value_heads, dtype=torch.bfloat16
+            )
+            sd[f"layers.{l}.linear_attn.norm.weight"] = (
+                torch.randn(value_dim, dtype=torch.bfloat16) * 0.5
+            )
+            sd[f"layers.{l}.linear_attn.out_proj.weight"] = (
+                torch.randn(H, value_dim, dtype=torch.bfloat16) * 0.02
+            )
+
+    return sd
+
+
+class TestNormConversion(unittest.TestCase):
+    """Test (+1 convention) RMSNorm weight conversion."""
+
+    def test_norm_weight_adds_one(self):
+        """Weights initialized to zero should become 1.0 after conversion."""
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        # norm.weight was zeros -> should now be ones
+        torch.testing.assert_close(
+            result["norm.weight"],
+            torch.ones_like(result["norm.weight"]),
+        )
+
+    def test_input_layernorm_adds_one(self):
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        for l in range(config.num_hidden_layers):
+            w = result[f"layers.{l}.input_layernorm.weight"]
+            self.assertTrue(
+                torch.allclose(w, torch.ones_like(w)),
+                f"Layer {l} input_layernorm not converted",
+            )
+
+    def test_post_attn_layernorm_adds_one(self):
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        for l in range(config.num_hidden_layers):
+            w = result[f"layers.{l}.post_attention_layernorm.weight"]
+            self.assertTrue(
+                torch.allclose(w, torch.ones_like(w)),
+                f"Layer {l} post_attention_layernorm not converted",
+            )
+
+    def test_qk_norm_adds_one(self):
+        """Q/K norms on GQA layers should also get +1 applied."""
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                q_w = result[f"layers.{l}.self_attn.q_layernorm.weight"]
+                k_w = result[f"layers.{l}.self_attn.k_layernorm.weight"]
+                self.assertTrue(
+                    torch.allclose(q_w, torch.ones_like(q_w)),
+                    f"Layer {l} q_layernorm not converted",
+                )
+                self.assertTrue(
+                    torch.allclose(k_w, torch.ones_like(k_w)),
+                    f"Layer {l} k_layernorm not converted",
+                )
+
+
+class TestQProjSplit(unittest.TestCase):
+    """Test q_proj interleaved split into query + gate."""
+
+    def test_q_proj_split_shapes(self):
+        """q_proj (num_heads * head_dim * 2, H) -> separate query and gate."""
+        config = _make_mini_config(fused_qkv=False)
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                # After split: q_proj should be (num_heads * head_dim, H) = (256, 256)
+                q_w = result[f"layers.{l}.self_attn.q_proj.weight"]
+                gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"]
+                expected_shape = (
+                    config.num_attention_heads * config.head_dim,
+                    config.hidden_size,
+                )
+                self.assertEqual(
+                    q_w.shape, expected_shape, f"Layer {l} q_proj shape wrong"
+                )
+                self.assertEqual(
+                    gate_w.shape, expected_shape, f"Layer {l} gate shape wrong"
+                )
+
+    def test_q_proj_deinterleave_correct(self):
+        """Verify the interleaved split correctly separates query and gate."""
+        config = _make_mini_config(fused_qkv=False)
+        sd = _make_mini_state_dict(config)
+
+        # Create a known pattern: head h query filled with (h + 1), gate with (h + 100)
+
+        l = 3  # First full_attention layer (layer 3)
+        num_heads = config.num_attention_heads
+        head_dim = config.head_dim
+        H = config.hidden_size
+
+        interleaved = torch.zeros(num_heads * head_dim * 2, H, dtype=torch.bfloat16)
+        for h in range(num_heads):
+            interleaved[h * head_dim * 2 : h * head_dim * 2 + head_dim, :] = float(
+                h + 1
+            )  # query
+            interleaved[h * head_dim * 2 + head_dim : (h + 1) * head_dim * 2, :] = (
+                float(h + 100)
+            )  # gate
+
+        sd[f"layers.{l}.self_attn.q_proj.weight"] = interleaved
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+        q_w = result[f"layers.{l}.self_attn.q_proj.weight"]
+        gate_w = result[f"layers.{l}.self_attn.output_gate_proj.weight"]
+
+        for h in range(num_heads):
+            q_head = q_w[h * head_dim : (h + 1) * head_dim, :]
+            gate_head = gate_w[h * head_dim : (h + 1) * head_dim, :]
+            self.assertTrue(
+                torch.all(q_head == float(h + 1)), f"Head {h} query values wrong"
+            )
+            self.assertTrue(
+                torch.all(gate_head == float(h + 100)), f"Head {h} gate values wrong"
+            )
+
+
+class TestQKNormRename(unittest.TestCase):
+    """Test q_norm -> q_layernorm and k_norm -> k_layernorm renaming."""
+
+    def test_old_keys_removed(self):
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                self.assertNotIn(f"layers.{l}.self_attn.q_norm.weight", result)
+                self.assertNotIn(f"layers.{l}.self_attn.k_norm.weight", result)
+
+    def test_new_keys_present(self):
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                self.assertIn(f"layers.{l}.self_attn.q_layernorm.weight", result)
+                self.assertIn(f"layers.{l}.self_attn.k_layernorm.weight", result)
+
+
+class TestFusedQKV(unittest.TestCase):
+    """Test fused QKV concatenation for attention layers."""
+
+    def test_fused_qkv_shape(self):
+        config = _make_mini_config(fused_qkv=True)
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                fused_key = f"layers.{l}.self_attn.Wqkv.weight"
+                self.assertIn(fused_key, result, f"Layer {l} missing Wqkv")
+
+                q_dim = config.num_attention_heads * config.head_dim
+                k_dim = config.num_key_value_heads * config.head_dim
+                v_dim = config.num_key_value_heads * config.head_dim
+                expected_rows = q_dim + k_dim + v_dim
+                self.assertEqual(result[fused_key].shape[0], expected_rows)
+
+    def test_fused_qkv_removes_individual_keys(self):
+        config = _make_mini_config(fused_qkv=True)
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                self.assertNotIn(f"layers.{l}.self_attn.q_proj.weight", result)
+                self.assertNotIn(f"layers.{l}.self_attn.k_proj.weight", result)
+                self.assertNotIn(f"layers.{l}.self_attn.v_proj.weight", result)
+
+
+class TestDeltaNetPassthrough(unittest.TestCase):
+    """Test that DeltaNet layer weights pass through conversion unchanged."""
+
+    def test_deltanet_weights_unchanged(self):
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+
+        # Record original DeltaNet weights
+        originals = {}
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "linear_attention":
+                key = f"layers.{l}.linear_attn.in_proj_qkv.weight"
+                originals[key] = sd[key].clone()
+
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+        for key, orig in originals.items():
+            self.assertIn(key, result, f"Missing: {key}")
+            torch.testing.assert_close(
+                result[key], orig, msg=f"DeltaNet weight changed: {key}"
+            )
+
+    def test_deltanet_norm_not_converted(self):
+        """DeltaNet layers use standard RMSNorm (NOT +1 convention).
+        The norm weight should NOT be changed."""
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+
+        # Set DeltaNet norm to a known non-zero value
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "linear_attention":
+                sd[f"layers.{l}.linear_attn.norm.weight"] = torch.full(
+                    (config.linear_num_value_heads * config.linear_value_head_dim,),
+                    0.87,
+                    dtype=torch.bfloat16,
+                )
+
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "linear_attention":
+                w = result[f"layers.{l}.linear_attn.norm.weight"]
+                # Should still be ~0.87, NOT 1.87
+                self.assertTrue(
+                    torch.allclose(w, torch.full_like(w, 0.87), atol=0.01),
+                    f"Layer {l} DeltaNet norm was incorrectly modified",
+                )
+
+
+class TestRankUtil(unittest.TestCase):
+    """Test rank_util tensor injection."""
+
+    def test_rank_util_present(self):
+        config = _make_mini_config(tp_degree=4)
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        self.assertIn("rank_util.rank", result)
+        expected = torch.arange(0, 4, dtype=torch.int32)
+        torch.testing.assert_close(result["rank_util.rank"], expected)
+
+    def test_gqa_layer_rank_util(self):
+        config = _make_mini_config(tp_degree=4)
+        sd = _make_mini_state_dict(config)
+        result = convert_qwen35_hf_to_neuron_state_dict(sd, config)
+        for l in range(config.num_hidden_layers):
+            if config.layer_types[l] == "full_attention":
+                key = f"layers.{l}.self_attn.rank_util.rank"
+                self.assertIn(key, result)
+                expected = torch.arange(0, 4, dtype=torch.int32)
+                torch.testing.assert_close(result[key], expected)
+
+
+class TestVLPrefixStripping(unittest.TestCase):
+    """Test VL wrapper prefix stripping in convert_hf_to_neuron_state_dict."""
+
+    def test_language_model_prefix_stripped(self):
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+
+        # Wrap with VL prefix
+        vl_sd = {}
+        for k, v in sd.items():
+            vl_sd[f"language_model.{k}"] = v
+        vl_sd["visual.encoder.weight"] = torch.zeros(10)  # should be skipped
+        vl_sd["mtp.something"] = torch.zeros(5)  # should be skipped
+
+        result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config)
+        self.assertNotIn("visual.encoder.weight", result)
+        self.assertNotIn("mtp.something", result)
+        self.assertIn("norm.weight", result)
+
+    def test_model_language_model_prefix_stripped(self):
+        config = _make_mini_config()
+        sd = _make_mini_state_dict(config)
+
+        vl_sd = {}
+        for k, v in sd.items():
+            vl_sd[f"model.language_model.{k}"] = v
+
+        result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config)
+        self.assertIn("norm.weight", result)
+
+
+if __name__ == "__main__":
+    unittest.main()
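Reviewer note: the +1 RMSNorm convention that the `TestNormConversion` cases assert can be reproduced standalone. This is a minimal sketch of the identity the conversion relies on (the `rmsnorm` helper here is illustrative, not part of the patch); per the model README, HF stores `output = norm(x) * (1 + weight)` with `weight` initialized to zeros, and the converter folds the `+1` into the weight once at load time so a standard RMSNorm kernel can be used.

```python
import torch


def rmsnorm(x, weight, eps=1e-6):
    # Standard RMSNorm: normalize by RMS over the last dim, then scale by weight.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


x = torch.randn(4, 8)
hf_weight = torch.zeros(8)   # +1 convention: zeros mean "identity scale"
converted = hf_weight + 1.0  # what the converter does to each RMSNorm weight

# HF semantics: norm(x) * (1 + w); NxDI semantics after conversion: norm(x) * w'.
hf_output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * (1 + hf_weight)
nxdi_output = rmsnorm(x, converted)
assert torch.allclose(hf_output, nxdi_output)
```

The same identity explains why DeltaNet-internal norms are excluded: they already use the standard `norm(x) * weight` convention, so adding 1.0 there would change the output.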