From 3bf7a4757c579e031d9ca8b5792528b2153408a4 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Thu, 26 Mar 2026 23:19:07 -0400 Subject: [PATCH 1/4] Add whisper-large-v3-turbo contrib model Encoder-decoder Whisper Large V3 Turbo for speech-to-text on Neuron. Includes 6 optimizations: cross-attention KV cache, fused QKV projections, NKI flash attention, NKI Conv1D+GELU fusion, LNC-aware compilation, and batch size > 1 support. Experimental NKI megakernel gated behind WHISPER_USE_MEGAKERNEL=1 environment variable. --- .../models/whisper-large-v3-turbo/README.md | 168 +++ .../whisper-large-v3-turbo/src/__init__.py | 13 + .../src/modeling_whisper.py | 1297 +++++++++++++++++ .../src/utils/__init__.py | 3 + .../src/utils/config.py | 30 + .../src/utils/decoding.py | 100 ++ .../src/utils/state_dict.py | 194 +++ .../src/whisper_encoder_megakernel.py | 923 ++++++++++++ .../whisper-large-v3-turbo/test/__init__.py | 0 .../test/integration/__init__.py | 0 .../test/integration/test_model.py | 193 +++ .../test/unit/__init__.py | 0 12 files changed, 2921 insertions(+) create mode 100644 contrib/models/whisper-large-v3-turbo/README.md create mode 100644 contrib/models/whisper-large-v3-turbo/src/__init__.py create mode 100644 contrib/models/whisper-large-v3-turbo/src/modeling_whisper.py create mode 100644 contrib/models/whisper-large-v3-turbo/src/utils/__init__.py create mode 100644 contrib/models/whisper-large-v3-turbo/src/utils/config.py create mode 100644 contrib/models/whisper-large-v3-turbo/src/utils/decoding.py create mode 100644 contrib/models/whisper-large-v3-turbo/src/utils/state_dict.py create mode 100644 contrib/models/whisper-large-v3-turbo/src/whisper_encoder_megakernel.py create mode 100644 contrib/models/whisper-large-v3-turbo/test/__init__.py create mode 100644 contrib/models/whisper-large-v3-turbo/test/integration/__init__.py create mode 100644 contrib/models/whisper-large-v3-turbo/test/integration/test_model.py create mode 100644 
contrib/models/whisper-large-v3-turbo/test/unit/__init__.py diff --git a/contrib/models/whisper-large-v3-turbo/README.md b/contrib/models/whisper-large-v3-turbo/README.md new file mode 100644 index 00000000..ef22f7ea --- /dev/null +++ b/contrib/models/whisper-large-v3-turbo/README.md @@ -0,0 +1,168 @@ +# Contrib Model: whisper-large-v3-turbo + +OpenAI Whisper Large V3 Turbo (openai/whisper-large-v3-turbo) speech-to-text model for NxD Inference on AWS Neuron (Trainium2 and Inferentia2). + +This is an encoder-decoder model with separate encoder and decoder compilation. It uses the [OpenAI Whisper](https://github.com/openai/whisper) package as its base class (not HuggingFace Transformers). + +## Model Information + +| Field | Value | +|-------|-------| +| **HuggingFace ID** | [openai/whisper-large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) | +| **Model Type** | Encoder-Decoder (Speech-to-Text) | +| **Parameters** | 809M | +| **License** | MIT | +| **Architecture** | 32 encoder layers + 4 decoder layers (turbo), 1280 hidden, 20 heads | + +## Architecture Details + +- **Encoder**: 32-layer bidirectional transformer with Conv1D frontend (128 mel bins -> 1280 hidden) +- **Decoder**: 4-layer causal transformer with cross-attention to encoder output +- **Key optimizations in this implementation**: + 1. **Cross-attention K/V cache**: Skip redundant K/V projections during decode (~2.5x decode speedup, saves 19.7B FLOPs/token) + 2. **Fused QKV projections**: 3 matmuls -> 1 for self-attention + 3. **NKI flash attention (encoder)**: Bidirectional flash attention for all 32 encoder layers + 4. **NKI fused Conv1D+GELU (encoder)**: Fused conv1d kernel for encoder frontend (optional, graceful fallback) + 5. **LNC flag**: Compiler args pass `--lnc=` for LNC=1 support on trn2 + 6. 
**Batch size >1**: Batched decode with per-sample positional embedding and logit extraction + +## Validation Results + +| Test | Result | +|------|--------| +| Transcription Accuracy | 0% WER on reference audio | +| Cosine Similarity | N/A (speech model, validated by WER) | + +## Performance Metrics + +### Single-Stream (BS=1, trn2.3xlarge, LNC=2, bfloat16) + +| Audio Duration | Latency | Real-Time Factor | +|---------------|---------|-----------------| +| 5.0s | 180.2ms | 27.8x | +| 15.0s | 229.1ms | 65.5x | +| 30.0s | 462.9ms | 64.8x | +| 90.0s | 1102.2ms | 81.7x | + +### Batched (BS=8, trn2.3xlarge, LNC=2, bfloat16) + +| Audio Duration | Batch Latency | Per-Sample | Throughput | +|---------------|--------------|------------|------------| +| 5.0s | 630.2ms | 78.8ms | 12.69 audio-sec/wall-sec | +| 30.0s | 675.5ms | 84.4ms | 11.84 audio-sec/wall-sec | +| 90.0s | 675.0ms | 84.4ms | 11.85 audio-sec/wall-sec | + +### Data Parallel (DP=4 x BS=8, trn2.3xlarge, LNC=2, bfloat16) + +| Audio Duration | Aggregate Throughput | +|---------------|---------------------| +| 5.0s | **46.65 audio-sec/wall-sec** | +| 30.0s | **43.75 audio-sec/wall-sec** | +| 90.0s | **43.27 audio-sec/wall-sec** | + +## Usage + +```python +import os +import sys +import torch + +# Add the contrib src directory to the Python path +sys.path.insert(0, "/path/to/contrib/models/whisper-large-v3-turbo/src") + +from neuronx_distributed_inference.models.config import NeuronConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +from modeling_whisper import WhisperInferenceConfig, NeuronApplicationWhisper + +DTYPE = torch.bfloat16 +BATCH_SIZE = 1 +TP_DEGREE = 1 +MODEL_PATH = "/home/ubuntu/models/whisper-large-v3-turbo/" +COMPILED_MODEL_PATH = "/home/ubuntu/compiled_models/whisper-large-v3-turbo/" + +# Define configs +neuron_config = NeuronConfig( + batch_size=BATCH_SIZE, + torch_dtype=DTYPE, + tp_degree=TP_DEGREE, +) +inference_config = WhisperInferenceConfig( + 
neuron_config, + load_config=load_pretrained_config(MODEL_PATH), +) + +# Compile model (one-time, ~75s) +if not os.path.exists(COMPILED_MODEL_PATH): + neuron_model = NeuronApplicationWhisper(MODEL_PATH, config=inference_config) + neuron_model.compile(COMPILED_MODEL_PATH) + +# Load from compiled checkpoint (~8s) +neuron_model = NeuronApplicationWhisper(COMPILED_MODEL_PATH, config=inference_config) +neuron_model.load(COMPILED_MODEL_PATH) + +# Transcribe an audio file +result = neuron_model.transcribe("audio-sample.mp3", verbose=True) +print(result["text"]) +``` + +## Compatibility Matrix + +| Instance Type | SDK Version | TP Degree | Dtype | Status | +|--------------|-------------|-----------|-------|--------| +| trn2.3xlarge | 2.28 | 1 | bfloat16 | Validated | +| trn2.3xlarge | 2.28 | 1 | float16 | Validated | +| inf2.xlarge | 2.28 | 1 | bfloat16 | Expected compatible | +| inf2.xlarge | 2.28 | 1 | float16 | Expected compatible | + +**Notes**: +- TP=1 is recommended. Whisper (809M params) fits on a single NeuronCore. +- Higher TP degrees are supported for head-sharding but provide no benefit for this model size. +- For maximum throughput on trn2.3xlarge, use DP=4 x BS=8 with LNC=2 (4 independent model instances). +- Each batch size requires separate compilation (BS is baked into the traced graph). + +## Testing + +### Prerequisites + +```bash +pip install openai-whisper pytest +``` + +### Run integration tests + +```bash +# From the whisper-large-v3-turbo directory +pytest test/integration/test_model.py -v + +# Or run manually +python test/integration/test_model.py +``` + +### Test details + +The integration test: +1. Compiles the model (encoder + decoder) if not already compiled +2. Loads the compiled model +3. Transcribes a reference audio file +4. Validates that the transcription produces non-empty text +5. 
Measures transcription latency + +## Example Checkpoints + +- [openai/whisper-large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) (809M, 4 decoder layers, recommended) +- [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) (1.5B, 32 decoder layers) + +## Dependencies + +- `openai-whisper` (provides base `Whisper` class and decoding loop) +- `transformers` (for `WhisperModel.from_pretrained` weight loading and `sinusoids`) +- `neuronx-distributed-inference` (NxDI base classes, model wrapper, config) +- `nkilib` (optional, for fused Conv1D+GELU kernel) + +## Maintainer + +Jim Burtoft (jimburtoft) + +## Last Updated + +2026-03-26 diff --git a/contrib/models/whisper-large-v3-turbo/src/__init__.py b/contrib/models/whisper-large-v3-turbo/src/__init__.py new file mode 100644 index 00000000..fe27efe6 --- /dev/null +++ b/contrib/models/whisper-large-v3-turbo/src/__init__.py @@ -0,0 +1,13 @@ +from .modeling_whisper import ( + WhisperInferenceConfig, + NeuronApplicationWhisper, + NeuronApplicationWhisperEncoder, + NeuronApplicationWhisperDecoder, +) + +__all__ = [ + "WhisperInferenceConfig", + "NeuronApplicationWhisper", + "NeuronApplicationWhisperEncoder", + "NeuronApplicationWhisperDecoder", +] diff --git a/contrib/models/whisper-large-v3-turbo/src/modeling_whisper.py b/contrib/models/whisper-large-v3-turbo/src/modeling_whisper.py new file mode 100644 index 00000000..33d99ad7 --- /dev/null +++ b/contrib/models/whisper-large-v3-turbo/src/modeling_whisper.py @@ -0,0 +1,1297 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Modified from https://github.com/openai/whisper/blob/main/whisper/model.py +# +# Whisper (openai/whisper-large-v3-turbo) for NxD Inference. +# +# Optimizations included: +# 1. Cross-attention K/V cache: skip redundant K/V projections during decode (~2.5x decode speedup) +# 2. Fused QKV projections: 3 matmuls → 1 for self-attention +# 3. 
NKI flash attention (encoder): bidirectional flash attention for all 32 encoder layers +# 4. NKI fused Conv1D+GELU (encoder): fused conv1d kernel for encoder frontend +# 5. LNC flag: compiler args pass --lnc= for LNC=1 support on trn2 +# 6. Batch size >1: batched decode with per-sample positional embedding and logit extraction + +import math +import os +from typing import Optional, Iterable, List, Tuple + +import torch +from torch import Tensor, nn +import torch.nn.functional as F + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed_inference.models.config import InferenceConfig +from neuronx_distributed_inference.models.model_wrapper import ( + BaseModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.models.application_base import NeuronApplicationBase + +from neuronx_distributed_inference.experimental.functional.attention.causal_attention_functions import ( + scaled_dot_product_attention_kernel, +) + +from utils.config import get_dims_from_config +from utils.decoding import decode as decode_function +from utils.state_dict import convert_hf_state_dict_to_neuron, expand_state_dict + +# NKI fused Conv1D+GELU kernel (optional — falls back to PyTorch if nkilib not available) +try: + from nkilib.experimental.conv.conv1d import conv1d as nki_conv1d + from nkilib.core.utils.common_types import ActFnType + + _HAS_NKI_CONV1D = True +except ImportError: + _HAS_NKI_CONV1D = False + +# NKI encoder megakernel (optional, experimental — gated behind WHISPER_USE_MEGAKERNEL=1). +# This fuses an entire Whisper encoder layer into a single @nki.jit kernel. +# STATUS: Negative result in benchmarks (16.9x slower than compiler-generated code due to +# weight transfer overhead and hand-written matmuls being ~13x slower than compiler-optimized). 
# Preserved here as a reference implementation and starting point for future NKI kernel work.
try:
    from whisper_encoder_megakernel import (
        whisper_encoder_layer_fwd,
        P_MAX,
        D_MODEL as MK_D_MODEL,
        SEQ_PAD as MK_SEQ_PAD,
        N_HEADS as MK_N_HEADS,
        HEAD_DIM as MK_HEAD_DIM,
        MLP_DIM as MK_MLP_DIM,
    )

    _HAS_MEGAKERNEL = True
except ImportError:
    _HAS_MEGAKERNEL = False
    # Dummy constants so module-level _MK_WEIGHT_SHAPES can be defined
    # (only used when megakernel is active, which requires _HAS_MEGAKERNEL=True)
    P_MAX = 128
    MK_D_MODEL = 1280
    MK_SEQ_PAD = 1536
    MK_N_HEADS = 20
    MK_HEAD_DIM = 64
    MK_MLP_DIM = 5120

from transformers import WhisperModel
from transformers.models.whisper.modeling_whisper import sinusoids
from whisper import Whisper


def ceil_div(a: int, b: int) -> int:
    """Integer division with ceiling."""
    # -(-a // b) is the standard floor-division trick for ceil(a / b) on ints.
    return -(-a // b)


def _tile_ln_weight(w: Tensor, p_max: int = 128) -> Tensor:
    """Tile a LayerNorm weight [F] to [P_MAX, F] by repeating each row.

    The NKI megakernel expects LN weight/bias pre-tiled so that element-wise
    multiply/add works on [P_MAX, F] tiles without broadcasting.
    """
    # expand() creates a broadcast view; contiguous() materializes the copy
    # so the result can be packed into a flat buffer later.
    return w.unsqueeze(0).expand(p_max, -1).contiguous()


def _prepare_megakernel_weights(block):
    """Extract and reshape weights from a NeuronResidualAttentionBlock for the megakernel.

    Returns a dict of tensors ready to pass to whisper_encoder_layer_fwd.
    All weights are bf16 and contiguous.

    The megakernel expects:
    - LayerNorm weights tiled to [P_MAX, D_MODEL]
    - QKV weight as [3*D_MODEL, D_MODEL] (already fused by NxDI state_dict conversion)
    - Biases pre-tiled to [P_MAX, dim] (NKI tensor_tensor requires same shapes)
    - FC1/FC2 weights in their original [out, in] layout
    """
    dtype = torch.bfloat16

    # Pre-attention LayerNorm (tiled)
    attn_ln_w = _tile_ln_weight(block.attn_ln.weight.to(dtype))
    attn_ln_b = _tile_ln_weight(block.attn_ln.bias.to(dtype))

    # Fused QKV: ColumnParallelLinear stores weight as [out_features, in_features]
    # After TP sharding, this is [3*n_heads_per_tp*head_dim, n_state].
    # For TP=1: [3*1280, 1280] = [3840, 1280]
    qkv_w = block.attn.qkv_proj.weight.to(dtype).contiguous()
    qkv_b = _tile_ln_weight(block.attn.qkv_proj.bias.to(dtype))  # [P_MAX, 3840]

    # Output projection: RowParallelLinear, weight [n_state, n_heads_per_tp*head_dim]
    # For TP=1: [1280, 1280]
    out_w = block.attn.out.weight.to(dtype).contiguous()
    out_b = _tile_ln_weight(block.attn.out.bias.to(dtype))  # [P_MAX, 1280]

    # Pre-MLP LayerNorm (tiled)
    mlp_ln_w = _tile_ln_weight(block.mlp_ln.weight.to(dtype))
    mlp_ln_b = _tile_ln_weight(block.mlp_ln.bias.to(dtype))

    # MLP FC1 (up_proj): ColumnParallelLinear [MLP_DIM/TP, D_MODEL]
    # For TP=1: [5120, 1280]
    fc1_w = block.mlp.up_proj.weight.to(dtype).contiguous()
    fc1_b = _tile_ln_weight(block.mlp.up_proj.bias.to(dtype))  # [P_MAX, 5120]

    # MLP FC2 (down_proj): RowParallelLinear [D_MODEL, MLP_DIM/TP]
    # For TP=1: [1280, 5120]
    fc2_w = block.mlp.down_proj.weight.to(dtype).contiguous()
    fc2_b = _tile_ln_weight(block.mlp.down_proj.bias.to(dtype))  # [P_MAX, 1280]

    return {
        "attn_ln_w": attn_ln_w,
        "attn_ln_b": attn_ln_b,
        "qkv_w": qkv_w,
        "qkv_b": qkv_b,
        "out_w": out_w,
        "out_b": out_b,
        "mlp_ln_w": mlp_ln_w,
        "mlp_ln_b": mlp_ln_b,
        "fc1_w": fc1_w,
        "fc1_b": fc1_b,
        "fc2_w": fc2_w,
        "fc2_b": fc2_b,
    }


# Canonical order of weight keys for packing/unpacking
_MK_WEIGHT_KEYS = [
    "attn_ln_w",
    "attn_ln_b",
    "qkv_w",
    "qkv_b",
    "out_w",
    "out_b",
    "mlp_ln_w",
    "mlp_ln_b",
    "fc1_w",
    "fc1_b",
    "fc2_w",
    "fc2_b",
]

# Pre-computed shapes for each weight (TP=1, Whisper large-v3-turbo)
_MK_WEIGHT_SHAPES = {
    "attn_ln_w": (P_MAX, MK_D_MODEL),  # [128, 1280]
    "attn_ln_b": (P_MAX, MK_D_MODEL),  # [128, 1280]
    "qkv_w": (3 * MK_D_MODEL, MK_D_MODEL),  # [3840, 1280]
    "qkv_b": (P_MAX, 3 * MK_D_MODEL),  # [128, 3840]
    "out_w": (MK_D_MODEL, MK_D_MODEL),  # [1280, 1280]
    "out_b": (P_MAX, MK_D_MODEL),  # [128, 1280]
    "mlp_ln_w": (P_MAX, MK_D_MODEL),  # [128, 1280]
    "mlp_ln_b": (P_MAX, MK_D_MODEL),  # [128, 1280]
    "fc1_w": (MK_MLP_DIM, MK_D_MODEL),  # [5120, 1280]
    "fc1_b": (P_MAX, MK_MLP_DIM),  # [128, 5120]
    "fc2_w": (MK_D_MODEL, MK_MLP_DIM),  # [1280, 5120]
    "fc2_b": (P_MAX, MK_D_MODEL),  # [128, 1280]
}

# Total elements per layer (sum of all weight numel)
_MK_ELEMENTS_PER_LAYER = sum(s[0] * s[1] for s in _MK_WEIGHT_SHAPES.values())


def _pack_all_layer_weights(blocks) -> Tensor:
    """Pack all megakernel weights for all layers into a single 1D bf16 tensor.

    Args:
        blocks: nn.ModuleList of NeuronResidualAttentionBlock (encoder layers)

    Returns:
        packed: 1D bf16 tensor of shape [n_layers * elements_per_layer]
    """
    n_layers = len(blocks)
    packed = torch.empty(n_layers * _MK_ELEMENTS_PER_LAYER, dtype=torch.bfloat16)
    offset = 0
    # Layers are written back-to-back in _MK_WEIGHT_KEYS order so that
    # _unpack_layer_weights can slice them out by fixed offsets.
    for block in blocks:
        weights = _prepare_megakernel_weights(block)
        for key in _MK_WEIGHT_KEYS:
            w = weights[key].contiguous().view(-1)
            packed[offset : offset + w.numel()] = w
            offset += w.numel()
    assert offset == packed.numel(), f"Packing mismatch: {offset} != {packed.numel()}"
    return packed


def _unpack_layer_weights(packed: Tensor, layer_idx: int):
    """Unpack the 12 weight tensors for a single layer from the packed tensor.
    Args:
        packed: 1D bf16 tensor from _pack_all_layer_weights()
        layer_idx: which layer (0-indexed)

    Returns:
        tuple of 12 tensors in canonical order (attn_ln_w, attn_ln_b, qkv_w, ...)
    """
    base = layer_idx * _MK_ELEMENTS_PER_LAYER
    offset = base
    tensors = []
    for key in _MK_WEIGHT_KEYS:
        shape = _MK_WEIGHT_SHAPES[key]
        numel = shape[0] * shape[1]
        # view() returns a slice into the packed buffer -- no copy here.
        t = packed[offset : offset + numel].view(shape)
        tensors.append(t)
        offset += numel
    return tuple(tensors)


class WhisperInferenceConfig(InferenceConfig):
    # NxDI InferenceConfig extended with Whisper's ModelDimensions (self.dims),
    # derived from the loaded config by utils.config.get_dims_from_config.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dims = get_dims_from_config(self)


class LayerNorm(nn.LayerNorm):
    """
    Converts input to float32 before applying LayerNorm to avoid precision issues.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in fp32, then cast back to the caller's dtype (e.g. bf16).
        return super().forward(x.float()).type(x.dtype)


class NeuronMLP(torch.nn.Module):
    # Two-layer GELU MLP, sharded column-then-row for tensor parallelism.
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        assert parallel_state.model_parallel_is_initialized(), (
            "Model parallel not initialized"
        )
        self.up_proj = ColumnParallelLinear(
            hidden_size, intermediate_size, bias=True, gather_output=False, dtype=dtype
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            input_is_parallel=True,
            dtype=dtype,
        )

    def forward(self, x):
        return self.down_proj(F.gelu(self.up_proj(x)))


class NeuronAttention(nn.Module):
    # Self-attention with a fused QKV projection. With kvcache=True (used by the
    # decoder blocks) K/V caches are kept as frozen nn.Parameters; with
    # kvcache=False (encoder blocks) forward() takes the NKI flash-attention path.
    def __init__(
        self,
        n_state: int,
        n_head: int,
        batch_size: int,
        seq_len: int,
        dtype: torch.dtype = torch.float32,
        kvcache=True,
    ):
        super().__init__()

        assert n_state % n_head == 0, (
            f"n_state ({n_state}) must be divisible by n_head ({n_head})"
        )
        self.head_dim = n_state // n_head

        assert parallel_state.model_parallel_is_initialized(), (
            "Model parallel not initialized"
        )
        tp_degree = parallel_state.get_tensor_model_parallel_group().size()

        # head
per core
        self.n_heads = ceil_div(n_head, tp_degree)
        self.n_kv_heads = self.n_heads  # Whisper doesn't use GQA

        # Fused QKV projection: single matmul instead of 3 separate ones.
        # Bias is included for all 3 (K portion is zeroed in state dict conversion).
        self.qkv_proj = ColumnParallelLinear(
            n_state,
            3 * self.n_heads * tp_degree * self.head_dim,
            bias=True,
            gather_output=False,
            dtype=dtype,
        )
        self.out = RowParallelLinear(
            self.n_heads * tp_degree * self.head_dim,
            n_state,
            bias=True,
            input_is_parallel=True,
            dtype=dtype,
        )

        # KV caches as frozen nn.Parameters so the NxDI model instance can alias
        # them to the decoder's cache outputs (see WhisperModelDecoderInstance.get).
        self.cache_k = (
            nn.Parameter(
                torch.zeros(
                    (batch_size, self.n_kv_heads, seq_len, self.head_dim), dtype=dtype
                ),
                requires_grad=False,
            )
            if kvcache
            else None
        )
        self.cache_v = (
            nn.Parameter(
                torch.zeros(
                    (batch_size, self.n_kv_heads, seq_len, self.head_dim), dtype=dtype
                ),
                requires_grad=False,
            )
            if kvcache
            else None
        )

    def forward(
        self,
        x: Tensor,
        last_pos: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ):
        bsz, seq_len, hidden_dim = x.shape

        # Fused QKV: single matmul, then split into Q, K, V (contiguous layout)
        qkv = self.qkv_proj(x)
        n_state_per_tp = self.n_heads * self.head_dim
        q, k, v = torch.tensor_split(qkv, (n_state_per_tp, 2 * n_state_per_tp), dim=2)
        q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        if self.cache_k is not None and self.cache_v is not None:
            if seq_len > 1:  # prefill: save all to cache
                indices = torch.arange(
                    start=0, end=seq_len, dtype=torch.int64, device=q.device
                )
                indices = indices.view(1, 1, seq_len, 1)
                indices = indices.expand(bsz, self.n_kv_heads, seq_len, self.head_dim)
            else:  # decode: save only the last token [last_pos] to cache
                indices = last_pos.view(bsz, 1, 1, 1).expand_as(k).to(torch.int64)

            updated_kcache =
torch.scatter(self.cache_k, 2, indices, k)
            updated_vcache = torch.scatter(self.cache_v, 2, indices, v)

            # Attend over the full cache (prefill writes all positions; decode
            # writes only position last_pos -- masking hides unwritten slots).
            k = updated_kcache
            v = updated_vcache

        if self.cache_k is None:
            # Encoder path: use NKI flash attention kernel (avoids materializing
            # the full 1500x1500 score matrix across all 32 encoder layers).
            # Q, K, V are already in (B, H, S, d) layout from lines above.
            output = scaled_dot_product_attention_kernel(
                q, k, v, is_causal=False, scale=1.0 / math.sqrt(self.head_dim)
            )
            # Output is (B, H, S, d) -- transpose to (B, S, H*d)
            output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        else:
            # Decoder path: standard matmul attention (KV cache changes seq dims)
            scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
            if mask is not None:
                # mask is True at valid positions; invalid ones get dtype-min.
                scores = torch.where(mask, scores, torch.finfo(scores.dtype).min)
            scores = F.softmax(scores.float(), dim=-1).type_as(q)
            output = torch.matmul(scores, v)
            output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)

        if self.cache_k is not None and self.cache_v is not None:
            # Return updated caches so NxDI can alias them back into the params.
            return self.out(output), updated_kcache, updated_vcache
        else:
            return self.out(output)


class NeuronCrossAttention(nn.Module):
    # Decoder cross-attention over the encoder output. K/V are projected once
    # during prefill and cached; decode steps reuse the cache (optimization #1).
    def __init__(
        self,
        n_state: int,
        n_head: int,
        batch_size: int,
        kv_seq_len: int,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()

        assert n_state % n_head == 0, (
            f"n_state ({n_state}) must be divisible by n_head ({n_head})"
        )
        self.head_dim = n_state // n_head

        assert parallel_state.model_parallel_is_initialized(), (
            "Model parallel not initialized"
        )
        tp_degree = parallel_state.get_tensor_model_parallel_group().size()

        # head per core
        self.n_heads = ceil_div(n_head, tp_degree)
        self.n_kv_heads = self.n_heads  # Whisper doesn't use GQA

        self.query = ColumnParallelLinear(
            n_state,
            self.n_heads * tp_degree * self.head_dim,
            bias=True,
            gather_output=False,
            dtype=dtype,
        )
        self.key =
ColumnParallelLinear(
            n_state,
            self.n_kv_heads * tp_degree * self.head_dim,
            bias=False,  # No bias for key projection
            gather_output=False,
            dtype=dtype,
        )
        self.value = ColumnParallelLinear(
            n_state,
            self.n_kv_heads * tp_degree * self.head_dim,
            bias=True,
            gather_output=False,
            dtype=dtype,
        )
        self.out = RowParallelLinear(
            self.n_heads * tp_degree * self.head_dim,
            n_state,
            bias=True,
            input_is_parallel=True,
            dtype=dtype,
        )

        # Cross-attention K/V caches: filled once at prefill, reused at decode.
        self.cache_k = nn.Parameter(
            torch.zeros(
                (batch_size, self.n_kv_heads, kv_seq_len, self.head_dim), dtype=dtype
            ),
            requires_grad=False,
        )
        self.cache_v = nn.Parameter(
            torch.zeros(
                (batch_size, self.n_kv_heads, kv_seq_len, self.head_dim), dtype=dtype
            ),
            requires_grad=False,
        )

    def forward(
        self,
        x: Tensor,
        xa: Tensor,
        is_prefill: bool = True,
    ):
        bsz, seq_len, hidden_dim = x.shape

        # Q projection (always needed for both prefill and decode)
        q = (
            self.query(x)
            .view(bsz, seq_len, self.n_heads, self.head_dim)
            .transpose(1, 2)
        )

        if is_prefill:
            # Prefill: compute K/V from encoder output and populate cache
            kv_seq_len = xa.shape[1]
            k = (
                self.key(xa)
                .view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim)
                .transpose(1, 2)
            )
            v = (
                self.value(xa)
                .view(bsz, kv_seq_len, self.n_kv_heads, self.head_dim)
                .transpose(1, 2)
            )

            indices = torch.arange(
                start=0, end=kv_seq_len, dtype=torch.int64, device=q.device
            )
            indices = indices.view(1, 1, kv_seq_len, 1)
            indices = indices.expand(bsz, self.n_kv_heads, kv_seq_len, self.head_dim)

            updated_kcache = torch.scatter(self.cache_k, 2, indices, k)
            updated_vcache = torch.scatter(self.cache_v, 2, indices, v)
        else:
            # Decode: use cached K/V directly (no K/V projection needed, xa is unused)
            updated_kcache = self.cache_k
            updated_vcache = self.cache_v

        k = updated_kcache
        v = updated_vcache

        # Q.K^T/√d
        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores =
F.softmax(scores.float(), dim=-1).type_as(q)
        output = torch.matmul(scores, v)
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.out(output), updated_kcache, updated_vcache


class NeuronResidualAttentionBlock(nn.Module):
    # One transformer block: self-attn (+ optional cross-attn) + MLP, each with
    # a pre-LayerNorm and a residual connection. cross_attention=True selects the
    # decoder variant, which also threads the updated KV caches through forward().
    def __init__(
        self,
        n_state: int,
        n_head: int,
        batch_size: int,
        seq_len: int,
        cross_attention: bool = False,
        cross_attn_seq_len: int = None,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()

        # Note: self-attention KV caching is enabled exactly when this is a
        # decoder block (kvcache=cross_attention).
        self.attn = NeuronAttention(
            n_state, n_head, batch_size, seq_len, dtype=dtype, kvcache=cross_attention
        )
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            NeuronCrossAttention(
                n_state, n_head, batch_size, cross_attn_seq_len, dtype=dtype
            )
            if cross_attention
            else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = NeuronMLP(n_state, n_mlp, dtype=dtype)
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,  # "a" for audio
        last_pos: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
    ):
        if self.cross_attn:
            # Decoder block: self.attn also returns its updated KV caches.
            h, self_attn_cache_k, self_attn_cache_v = self.attn(
                self.attn_ln(x), last_pos=last_pos, mask=mask
            )
        else:
            h = self.attn(self.attn_ln(x), last_pos=last_pos, mask=mask)
        x = x + h
        if self.cross_attn:
            # Prefill iff more than one token is being processed.
            h, cross_attn_cache_k, cross_attn_cache_v = self.cross_attn(
                self.cross_attn_ln(x), xa, is_prefill=x.shape[1] > 1
            )
            x = x + h
        x = x + self.mlp(self.mlp_ln(x))

        if self.cross_attn:
            return (
                x,
                self_attn_cache_k,
                self_attn_cache_v,
                cross_attn_cache_k,
                cross_attn_cache_v,
            )
        else:
            return x


class NeuronAudioEncoder(nn.Module):
    # Whisper audio encoder: Conv1D frontend + sinusoidal positions + N
    # bidirectional attention blocks + final LayerNorm.
    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        batch_size: int,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        seq_len = n_ctx
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1,
dtype=dtype)
        self.conv2 = nn.Conv1d(
            n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype
        )
        # Fixed (non-trainable) sinusoidal positional embedding.
        self.positional_embedding = nn.Parameter(
            sinusoids(n_ctx, n_state), requires_grad=False
        )

        self.blocks: Iterable[NeuronResidualAttentionBlock] = nn.ModuleList(
            [
                NeuronResidualAttentionBlock(
                    n_state, n_head, batch_size, seq_len, dtype=dtype
                )
                for _ in range(n_layer)
            ]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor, packed_weights: Optional[Tensor] = None):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        packed_weights : torch.Tensor, optional, shape = (n_layers * elements_per_layer,)
            Packed megakernel weights (1D bf16 tensor). Required when megakernel is active.
            Created by _pack_all_layer_weights(). Passed as a forward() argument (not
            nn.Parameter) because NKI kernel tracing resolves forward args but not
            nn.Parameter attributes.
        """
        if _HAS_NKI_CONV1D:
            # NKI fused Conv1D+GELU: single kernel call per layer instead of
            # separate Conv1D + GELU ops. Weights transposed from PyTorch
            # (C_out, C_in, K) to NKI (K, C_in, C_out) layout.
            x = nki_conv1d(
                x,
                self.conv1.weight.permute(2, 1, 0),
                self.conv1.bias,
                stride=1,
                padding=(1, 1),
                activation_fn=ActFnType.GELU,
            )
            x = nki_conv1d(
                x,
                self.conv2.weight.permute(2, 1, 0),
                self.conv2.bias,
                stride=2,
                padding=(1, 1),
                activation_fn=ActFnType.GELU,
            )
        else:
            x = F.gelu(self.conv1(x))
            x = F.gelu(self.conv2(x))
        # NOTE(review): indentation reconstructed from a collapsed diff -- this
        # permute is assumed to apply to BOTH conv paths (i.e. nki_conv1d returns
        # the same (B, C, T) layout as nn.Conv1d); confirm against nkilib docs.
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        use_megakernel = (
            _HAS_MEGAKERNEL
            and os.environ.get("WHISPER_USE_MEGAKERNEL", "0") == "1"
            and packed_weights is not None
        )

        if use_megakernel:
            # Megakernel path: fuse entire encoder layer into single NKI kernel.
            # Requires TP=1 and matching dimensions.
            # NOTE: This is experimental and currently 16.9x slower than the default
            # path due to weight transfer overhead and hand-written matmul inefficiency.
            n_layers = len(self.blocks)
            bsz, seq_len, hidden = x.shape
            # Pad sequence from 1500 to 1536 (12 * 128) for tile alignment
            if seq_len < MK_SEQ_PAD:
                x = F.pad(x, (0, 0, 0, MK_SEQ_PAD - seq_len))  # pad seq dim

            # Process each batch item (megakernel is single-batch)
            outputs = []
            for b in range(bsz):
                x_b = x[b]  # [S, D_MODEL]
                for layer_idx in range(n_layers):
                    (
                        attn_ln_w,
                        attn_ln_b,
                        qkv_w,
                        qkv_b,
                        out_w,
                        out_b,
                        mlp_ln_w,
                        mlp_ln_b,
                        fc1_w,
                        fc1_b,
                        fc2_w,
                        fc2_b,
                    ) = _unpack_layer_weights(packed_weights, layer_idx)
                    # NOTE(review): the [1] subscript appears to be the NKI kernel
                    # launch grid; confirm in whisper_encoder_megakernel.py.
                    x_b = whisper_encoder_layer_fwd[1](
                        x_b,
                        attn_ln_w,
                        attn_ln_b,
                        qkv_w,
                        qkv_b,
                        out_w,
                        out_b,
                        mlp_ln_w,
                        mlp_ln_b,
                        fc1_w,
                        fc1_b,
                        fc2_w,
                        fc2_b,
                    )
                outputs.append(x_b)
            x = torch.stack(outputs, dim=0)  # [bsz, S, D_MODEL]

            # Strip padding back to original seq_len
            if seq_len < MK_SEQ_PAD:
                x = x[:, :seq_len, :]
        else:
            for block in self.blocks:
                x = block(x)

        x = self.ln_post(x)
        return x


class NeuronTextDecoder(nn.Module):
    # Whisper text decoder: token + learned positional embeddings, N causal
    # blocks with cross-attention to the encoder output, tied output head.
    def __init__(
        self,
        n_vocab: int,
        n_text_ctx: int,
        n_audio_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        batch_size: int,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.seq_len = n_text_ctx
        self.vocab_size = n_vocab

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Embedding(n_text_ctx, n_state)

        self.blocks: Iterable[NeuronResidualAttentionBlock] = nn.ModuleList(
            [
                NeuronResidualAttentionBlock(
                    n_state,
                    n_head,
                    self.batch_size,
                    self.seq_len,
                    cross_attention=True,
                    cross_attn_seq_len=n_audio_ctx,
                    dtype=dtype,
                )
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

    def forward(
        self, x: Tensor, xa: Tensor, last_pos: torch.Tensor, pad_mask:
torch.Tensor
    ):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        last_pos : torch.Tensor, shape = (batch_size,)
            indices of the last valid token position for each sequence in the batch
        pad_mask : torch.Tensor, shape = (batch_size, n_ctx)
            boolean mask indicating valid positions (True) vs padded positions (False)
        """
        assert x.shape[1] == 1 or x.shape[1] == self.seq_len, (
            f"Input sequence length {x.shape[1]} must be 1 (decode) or {self.seq_len} (prefill)"
        )

        is_prefill = x.shape[1] > 1
        if is_prefill:
            pe = self.positional_embedding.weight
        else:
            # last_pos shape: (batch_size,) — index PE per sample, unsqueeze
            # to (batch_size, 1, n_state) for broadcast with token embedding
            pe = self.positional_embedding(last_pos).unsqueeze(1)
        x = self.token_embedding(x) + pe
        x = x.to(xa.dtype)

        mask = None
        if is_prefill:
            # Causal (lower-triangular) mask ANDed with the per-sample pad mask.
            mask = torch.full(
                (self.seq_len, self.seq_len), True, device=pad_mask.device
            ).tril(diagonal=0)
            input_mask = (
                pad_mask[:, None, None, :]
                .expand(self.batch_size, 1, self.seq_len, self.seq_len)
                .to(torch.bool)
            )
            mask = torch.logical_and(mask, input_mask)
        else:
            # Decode: only the pad mask matters (single query position).
            mask = (
                pad_mask[:, None, None, :]
                .expand(self.batch_size, 1, 1, self.seq_len)
                .to(torch.bool)
            )

        # Collect per-layer updated caches; they are returned after the logits so
        # the NxDI model instance can alias them back onto the cache parameters.
        self_attn_k_caches = []
        self_attn_v_caches = []
        cross_attn_k_caches = []
        cross_attn_v_caches = []

        for block in self.blocks:
            x, sk, sv, ck, cv = block(x, xa, last_pos=last_pos, mask=mask)
            self_attn_k_caches.append(sk)
            self_attn_v_caches.append(sv)
            cross_attn_k_caches.append(ck)
            cross_attn_v_caches.append(cv)

        x = self.ln(x)
        # Tied output head: project through the transposed token embedding.
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return (
            logits,
            *self_attn_k_caches,
            *self_attn_v_caches,
            *cross_attn_k_caches,
            *cross_attn_v_caches,
        )


class
class WhisperModelEncoderInstance(BaseModelInstance):
    """BaseModelInstance that instantiates the NeuronAudioEncoder for tracing."""

    def __init__(self, config):
        self.module = None
        self.config = config
        self.neuron_config = config.neuron_config

    def load_module(self):
        """Build the encoder module from the model dimensions in config."""
        dims = self.config.dims
        self.module = NeuronAudioEncoder(
            dims.n_mels,
            dims.n_audio_ctx,
            dims.n_audio_state,
            dims.n_audio_head,
            dims.n_audio_layer,
            batch_size=self.neuron_config.batch_size,
            dtype=self.neuron_config.torch_dtype,
        )

    def get(self, bucket_rank, **kwargs):
        # Encoder has no in-place cache tensors, so no aliases.
        aliases = {}
        return self.module, aliases


class WhisperModelDecoderInstance(BaseModelInstance):
    """BaseModelInstance that instantiates the NeuronTextDecoder for tracing."""

    def __init__(self, config):
        self.module = None
        self.config = config
        self.neuron_config = config.neuron_config

    def load_module(self):
        """Build the decoder module from the model dimensions in config."""
        dims = self.config.dims
        self.module = NeuronTextDecoder(
            dims.n_vocab,
            dims.n_text_ctx,
            dims.n_audio_ctx,
            dims.n_text_state,
            dims.n_text_head,
            dims.n_text_layer,
            batch_size=self.neuron_config.batch_size,
            dtype=self.neuron_config.torch_dtype,
        )

    def get(self, bucket_rank, **kwargs):
        # Alias each cache tensor to its position in the forward() output
        # tuple so the runtime updates caches in place. Output 0 is logits;
        # the four groups below MUST match NeuronTextDecoder.forward's
        # return order: self-K, self-V, cross-K, cross-V (per layer).
        aliases = {}
        output_index = 1
        for layer in self.module.blocks:
            aliases[layer.attn.cache_k] = output_index
            output_index += 1
        for layer in self.module.blocks:
            aliases[layer.attn.cache_v] = output_index
            output_index += 1
        for layer in self.module.blocks:
            aliases[layer.cross_attn.cache_k] = output_index
            output_index += 1
        for layer in self.module.blocks:
            aliases[layer.cross_attn.cache_v] = output_index
            output_index += 1
        return self.module, aliases


class ModelWrapperWhisperEncoder(ModelWrapper):
    """ModelWrapper for the encoder graph (optionally the NKI megakernel path)."""

    def __init__(
        self,
        config,
        model_cls,
        tag="",
        compiler_args=None,
        priority_model_idx=None,
        model_init_kwargs=None,
    ):
        # None sentinel instead of a mutable {} default (a shared dict
        # default would be aliased across every wrapper instance).
        super().__init__(
            config,
            model_cls,
            tag,
            compiler_args,
            priority_model_idx,
            model_init_kwargs if model_init_kwargs is not None else {},
        )
        self.bucket_config = None  # Set to None if no bucketing needed
        self._use_megakernel = (
            _HAS_MEGAKERNEL and os.environ.get("WHISPER_USE_MEGAKERNEL", "0") == "1"
        )

    def input_generator(self) -> List[Tuple[torch.Tensor]]:
        # Generate example inputs for tracing. The mel input carries
        # 2 * n_audio_ctx frames; the conv frontend halves it to n_audio_ctx.
        audio = torch.randn(
            self.neuron_config.batch_size,
            self.config.dims.n_mels,
            self.config.dims.n_audio_ctx * 2,
            dtype=self.neuron_config.torch_dtype,
        )
        if self._use_megakernel:
            # Packed weights tensor: 1D bf16 dummy with correct total size
            n_layers = self.config.dims.n_audio_layer
            packed_weights = torch.zeros(
                n_layers * _MK_ELEMENTS_PER_LAYER, dtype=torch.bfloat16
            )
            inputs = [(audio, packed_weights)]
        else:
            inputs = [(audio,)]
        return inputs

    def get_model_instance(self):
        return WhisperModelEncoderInstance(self.config)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class ModelWrapperWhisperDecoderPrefill(ModelWrapper):
    """ModelWrapper for the full-context (prefill) decoder graph."""

    def __init__(
        self,
        config,
        model_cls,
        tag="",
        compiler_args=None,
        priority_model_idx=None,
        model_init_kwargs=None,
    ):
        # None sentinel instead of a mutable {} default.
        super().__init__(
            config,
            model_cls,
            tag,
            compiler_args,
            priority_model_idx,
            model_init_kwargs if model_init_kwargs is not None else {},
        )
        self.bucket_config = None  # Set to None if no bucketing needed

    def input_generator(self) -> List[Tuple[torch.Tensor]]:
        # Generate example inputs for tracing
        audio_embed = torch.randn(
            self.neuron_config.batch_size,
            self.config.dims.n_audio_ctx,
            self.config.dims.n_audio_state,
            dtype=self.neuron_config.torch_dtype,
        )
        padded_tokens = torch.zeros(
            (self.neuron_config.batch_size, self.config.dims.n_text_ctx),
            dtype=torch.int32,
        )
        last_pos = torch.zeros(self.neuron_config.batch_size, dtype=torch.int32)
        pad_mask = torch.zeros(
            (self.neuron_config.batch_size, self.config.dims.n_text_ctx),
            dtype=torch.int32,
        )
        inputs = [
            (padded_tokens, audio_embed, last_pos, pad_mask),
        ]
        return inputs

    def get_model_instance(self):
        return WhisperModelDecoderInstance(self.config)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class ModelWrapperWhisperDecoderDecode(ModelWrapper):
    """ModelWrapper for the single-token (decode) decoder graph."""

    def __init__(
        self,
        config,
        model_cls,
        tag="",
        compiler_args=None,
        priority_model_idx=None,
        model_init_kwargs=None,
    ):
        # None sentinel instead of a mutable {} default.
        super().__init__(
            config,
            model_cls,
            tag,
            compiler_args,
            priority_model_idx,
            model_init_kwargs if model_init_kwargs is not None else {},
        )
        self.bucket_config = None  # Set to None if no bucketing needed

    def input_generator(self) -> List[Tuple[torch.Tensor]]:
        # Generate example inputs for tracing.
        # Use minimal dummy xa (1 token instead of n_audio_ctx) since decode reads
        # cross-attention K/V from cache, not from xa. The xa tensor must be present
        # for forward signature compatibility but is unused in the decode graph.
        audio_embed = torch.randn(
            self.neuron_config.batch_size,
            1,
            self.config.dims.n_audio_state,
            dtype=self.neuron_config.torch_dtype,
        )
        padded_tokens = torch.zeros(
            (self.neuron_config.batch_size, 1), dtype=torch.int32
        )
        last_pos = torch.zeros(self.neuron_config.batch_size, dtype=torch.int32)
        pad_mask = torch.zeros(
            (self.neuron_config.batch_size, self.config.dims.n_text_ctx),
            dtype=torch.int32,
        )
        inputs = [
            (padded_tokens, audio_embed, last_pos, pad_mask),
        ]
        return inputs

    def get_model_instance(self):
        return WhisperModelDecoderInstance(self.config)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class NeuronApplicationWhisperEncoder(NeuronApplicationBase):
    """Application wrapper that compiles/loads/runs the Whisper encoder."""

    _model_cls = NeuronAudioEncoder

    def __init__(self, model_path, config, *args, **kwargs):
        super().__init__(model_path, config, *args, **kwargs)
        self.dims = config.dims
        self._use_megakernel = (
            _HAS_MEGAKERNEL and os.environ.get("WHISPER_USE_MEGAKERNEL", "0") == "1"
        )
        self.encoder_model = ModelWrapperWhisperEncoder(
            config=self.config,
            model_cls=self._model_cls,
            tag="Encoder",
            compiler_args=self.get_compiler_args(),
        )
        self.models.append(self.encoder_model)

        # Packed megakernel weights (populated after weight loading)
        # Flat bf16 tensor of all encoder-layer weights for the megakernel
        # path; built by _build_packed_weights() after load(), stays None
        # when the megakernel is disabled.
        self._packed_weights = None

        # workaround for whisper PyTorchInference init, dummy blocks
        self.blocks = []

    def get_compiler_args(self):
        """Build the neuronx-cc argument string for the encoder graph."""
        compiler_args = "--model-type=transformer"
        compiler_args += " --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2'"
        if self.config.neuron_config.torch_dtype == torch.float32:
            # fp32 runs disable auto-casting to preserve precision.
            compiler_args += " --auto-cast=none"
        # LNC-aware compilation (e.g. --lnc=1 on trn2).
        compiler_args += f" --lnc={self.config.neuron_config.logical_nc_config}"
        return compiler_args

    @staticmethod
    def load_hf_model(model_path):
        # Base checkpoint comes from HuggingFace Transformers' WhisperModel.
        return WhisperModel.from_pretrained(model_path)

    @staticmethod
    def update_state_dict_for_tied_weights(state_dict):
        # No tied weights need materializing for the encoder.
        pass

    @staticmethod
    def convert_hf_to_neuron_state_dict(
        state_dict: dict, config: WhisperInferenceConfig
    ) -> dict:
        # Rename HF keys to this module's layout, then pad attention heads
        # to a multiple of the TP degree.
        state_dict = convert_hf_state_dict_to_neuron(state_dict, type="encoder")
        state_dict = expand_state_dict(
            state_dict, config.dims, config.neuron_config.tp_degree
        )
        return state_dict

    def _build_packed_weights(self):
        """Build packed megakernel weights from the HF checkpoint.

        Called after load() to create the packed weight tensor that will be
        passed as a forward() input at inference time. Extracts weights directly
        from the neuron state dict (avoiding instantiation of parallel layers)
        and packs them into a single 1D bf16 tensor.
        """
        if not self._use_megakernel:
            return

        import logging

        logger = logging.getLogger(__name__)
        logger.info("Building packed megakernel weights from checkpoint...")

        # Load and convert the state dict (same path as normal weight loading)
        model_sd = self.checkpoint_loader_fn()

        # Extract and pack weights directly from the state dict.
        dtype = torch.bfloat16
        n_layers = self.config.dims.n_audio_layer
        packed = torch.empty(n_layers * _MK_ELEMENTS_PER_LAYER, dtype=dtype)
        offset = 0

        for i in range(n_layers):
            prefix = f"blocks.{i}"

            # Helper to get, cast, and optionally tile a weight
            def _get(key):
                return model_sd[f"{prefix}.{key}"].to(dtype).contiguous()

            def _get_tiled(key):
                return _tile_ln_weight(_get(key))

            # Pack in canonical order: attn_ln_w, attn_ln_b, qkv_w, qkv_b, ...
            # This order must match _unpack_layer_weights on the device side.
            weights_in_order = [
                _get_tiled("attn_ln.weight"),  # attn_ln_w [P_MAX, D_MODEL]
                _get_tiled("attn_ln.bias"),  # attn_ln_b [P_MAX, D_MODEL]
                _get("attn.qkv_proj.weight"),  # qkv_w [3*D_MODEL, D_MODEL]
                _get_tiled("attn.qkv_proj.bias"),  # qkv_b [P_MAX, 3*D_MODEL]
                _get("attn.out.weight"),  # out_w [D_MODEL, D_MODEL]
                _get_tiled("attn.out.bias"),  # out_b [P_MAX, D_MODEL]
                _get_tiled("mlp_ln.weight"),  # mlp_ln_w [P_MAX, D_MODEL]
                _get_tiled("mlp_ln.bias"),  # mlp_ln_b [P_MAX, D_MODEL]
                _get("mlp.up_proj.weight"),  # fc1_w [MLP_DIM, D_MODEL]
                _get_tiled("mlp.up_proj.bias"),  # fc1_b [P_MAX, MLP_DIM]
                _get("mlp.down_proj.weight"),  # fc2_w [D_MODEL, MLP_DIM]
                _get_tiled("mlp.down_proj.bias"),  # fc2_b [P_MAX, D_MODEL]
            ]

            for w in weights_in_order:
                flat = w.contiguous().view(-1)
                packed[offset : offset + flat.numel()] = flat
                offset += flat.numel()

        assert offset == packed.numel(), (
            f"Packing mismatch: {offset} != {packed.numel()}"
        )
        self._packed_weights = packed
        logger.info(
            f"Packed megakernel weights: {self._packed_weights.shape} "
            f"({self._packed_weights.numel() * 2 / 1024 / 1024:.1f} MB)"
        )

    def load(self, compiled_model_path, *args, **kwargs):
        """Override load to also build packed megakernel weights."""
        super().load(compiled_model_path, *args, **kwargs)
        self._build_packed_weights()

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Whisper encoder.
        :param audio: Tensor of shape (batch_size, n_mels, n_frames); the
            traced graph expects n_frames = 2 * n_audio_ctx mel frames (the
            stride-2 conv frontend halves it — see input_generator).
        :return: Encoded audio features, cast back to the input dtype
        """
        audio_typed = audio.to(self.config.neuron_config.torch_dtype)
        if self._use_megakernel and self._packed_weights is not None:
            # Megakernel path: packed weights travel as a runtime input.
            return self.traced_model(audio_typed, self._packed_weights).to(audio.dtype)
        else:
            return self.traced_model(audio_typed).to(audio.dtype)


class NeuronApplicationWhisperDecoder(NeuronApplicationBase):
    """Application wrapper that owns the prefill and decode decoder graphs."""

    _model_cls = NeuronTextDecoder

    def __init__(self, model_path, config, *args, **kwargs):
        super().__init__(model_path, config, *args, **kwargs)
        self.dims = config.dims
        self.decoder_prefill_model = ModelWrapperWhisperDecoderPrefill(
            config=self.config,
            model_cls=self._model_cls,
            tag="DecoderPrefill",
            compiler_args=self.get_compiler_args(),
        )
        self.decoder_decode_model = ModelWrapperWhisperDecoderDecode(
            config=self.config,
            model_cls=self._model_cls,
            tag="DecoderDecode",
            compiler_args=self.get_compiler_args(),
        )
        self.models.append(self.decoder_prefill_model)
        self.models.append(self.decoder_decode_model)

        # workaround for whisper PyTorchInference init, dummy blocks
        self.blocks = []

    def get_compiler_args(self):
        """Build the neuronx-cc argument string (same flags as the encoder)."""
        compiler_args = "--model-type=transformer"
        compiler_args += " --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2'"
        if self.config.neuron_config.torch_dtype == torch.float32:
            compiler_args += " --auto-cast=none"
        compiler_args += f" --lnc={self.config.neuron_config.logical_nc_config}"
        return compiler_args

    @staticmethod
    def load_hf_model(model_path):
        return WhisperModel.from_pretrained(model_path)

    @staticmethod
    def update_state_dict_for_tied_weights(state_dict):
        # Decoder logits reuse the token embedding at runtime; nothing to copy.
        pass

    @staticmethod
    def convert_hf_to_neuron_state_dict(
        state_dict: dict, config: WhisperInferenceConfig
    ) -> dict:
        state_dict = convert_hf_state_dict_to_neuron(state_dict, type="decoder")
        state_dict = expand_state_dict(
            state_dict, config.dims, config.neuron_config.tp_degree
        )
        return state_dict

    def forward(
        self,
        text: torch.Tensor,
        audio: torch.Tensor,
        last_pos: torch.Tensor,
        pad_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass for the Whisper decoder.
        :param text: Tensor of shape (batch_size, <= n_text_ctx)
        :param audio: Encoded audio features of shape (batch_size, n_audio_ctx, n_audio_state)
        :param last_pos: Tensor of shape (batch_size,) indicating the last valid token position per sample
        :param pad_mask: Tensor of shape (batch_size, n_text_ctx) indicating valid positions
        :return: Logits for the next token prediction
        """
        return self.traced_model(text, audio, last_pos, pad_mask)


class NeuronApplicationWhisper(Whisper):
    """Top-level Whisper application combining the Neuron encoder and decoder.

    Subclasses openai-whisper's Whisper so upstream transcription utilities
    (decoding loop, language detection) work against the Neuron graphs.
    """

    def __init__(self, model_path, config, *args, **kwargs):
        super().__init__(config.dims)
        self.config = config
        self.dims = config.dims
        self.encoder_path_suffix = "encoder"
        self.decoder_path_suffix = "decoder"
        self.encoder = NeuronApplicationWhisperEncoder(
            model_path=os.path.join(model_path, self.encoder_path_suffix),
            config=config,
            *args,
            **kwargs,
        )
        self.decoder = NeuronApplicationWhisperDecoder(
            model_path=os.path.join(model_path, self.decoder_path_suffix),
            config=config,
            *args,
            **kwargs,
        )

    def compile(self, compiled_model_path, *args, **kwargs):
        # Encoder and decoder are compiled into separate artifact directories.
        self.encoder.compile(
            os.path.join(compiled_model_path, self.encoder_path_suffix), *args, **kwargs
        )
        self.decoder.compile(
            os.path.join(compiled_model_path, self.decoder_path_suffix), *args, **kwargs
        )

    def load(self, compiled_model_path, *args, **kwargs):
        self.encoder.load(
            os.path.join(compiled_model_path, self.encoder_path_suffix), *args, **kwargs
        )
        self.decoder.load(
            os.path.join(compiled_model_path, self.decoder_path_suffix), *args, **kwargs
        )

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        """Run the decoder and return logits gathered at each sample's last
        valid position (prefill path).

        NOTE(review): _prepare_decoder_inputs always pads tokens to
        n_text_ctx, so padded_tokens.shape[1] > 1 holds for every call and
        the decode (else) branch below appears unreachable from this method;
        single-token decode goes through NeuronInference.logits instead —
        confirm before relying on the dummy-xa branch here.
        """
        tokens = tokens.to(torch.int32)
        padded_tokens, last_pos, pad_mask = self._prepare_decoder_inputs(tokens)
        is_prefill = padded_tokens.shape[1] > 1
        if is_prefill:
            xa = audio_features.to(self.config.neuron_config.torch_dtype)
        else:
            # Decode: pass minimal dummy xa since cross-attention K/V caches
            # were populated during prefill. xa is unused in the decode graph.
            xa = torch.zeros(
                audio_features.shape[0],
                1,
                audio_features.shape[2],
                dtype=self.config.neuron_config.torch_dtype,
            )
        logits = self.decoder(padded_tokens, xa, last_pos, pad_mask)
        if is_prefill:
            # Gather logits at each sample's last valid position.
            # last_pos shape: (batch_size,) — each value is the index of the
            # last real token for that sample in the padded sequence.
            idx = (
                last_pos.to(torch.int64).view(-1, 1, 1).expand(-1, 1, logits.shape[-1])
            )
            logits = torch.gather(logits, 1, idx)
        return logits

    def _prepare_decoder_inputs(self, tokens: torch.Tensor):
        """Pad a (batch, seq) token tensor to n_text_ctx and derive last_pos
        and pad_mask.

        Note: last_pos is computed from each row's length; since `tokens` is
        a rectangular tensor, every row yields the same value (batched
        whisper decoding keeps sequences in lockstep).
        """
        # Sentinel pad id; replaced by 0 after the mask is derived so the
        # embedding lookup never sees a negative index.
        pad_token = -1
        last_pos = torch.tensor(
            [len(prompt) - 1 for prompt in tokens], dtype=torch.int32
        )
        padded_tokens = F.pad(
            tokens, (0, self.dims.n_text_ctx - tokens.shape[1]), value=pad_token
        )
        pad_mask = torch.where(padded_tokens != pad_token, 1, 0).to(torch.int32)
        padded_tokens = torch.where(padded_tokens == pad_token, 0, padded_tokens)
        return padded_tokens, last_pos, pad_mask

    @property
    def device(self):
        # All host-side tensors live on CPU; the Neuron runtime handles devices.
        return torch.device("cpu")

    # Bind the Neuron-aware decode loop (src/utils/decoding.py) as a method.
    decode = decode_function


# ---------------------------------------------------------------------------
# File: contrib/models/whisper-large-v3-turbo/src/utils/__init__.py
# ---------------------------------------------------------------------------
from .config import get_dims_from_config, LargeV3Turbo
from .decoding import decode
from .state_dict import convert_hf_state_dict_to_neuron, expand_state_dict
# ---------------------------------------------------------------------------
# File: contrib/models/whisper-large-v3-turbo/src/utils/config.py
# ---------------------------------------------------------------------------
from whisper.model import ModelDimensions


# Static dimensions for openai/whisper-large-v3-turbo: 32 encoder layers,
# 4 decoder layers, 1280 hidden, 20 heads, 128 mel bins.
LargeV3Turbo = ModelDimensions(
    n_mels=128,
    n_audio_ctx=1500,
    n_audio_state=1280,
    n_audio_head=20,
    n_audio_layer=32,
    n_vocab=51866,
    n_text_ctx=448,
    n_text_state=1280,
    n_text_head=20,
    n_text_layer=4,
)


def get_dims_from_config(config) -> ModelDimensions:
    """Map a HuggingFace WhisperConfig to openai-whisper's ModelDimensions.

    :param config: HF config exposing num_mel_bins, max_source_positions,
        d_model, encoder/decoder layer and head counts, vocab_size, and
        max_target_positions.
    """
    return ModelDimensions(
        n_mels=config.num_mel_bins,
        n_audio_ctx=config.max_source_positions,
        n_audio_state=config.d_model,
        n_audio_head=config.encoder_attention_heads,
        n_audio_layer=config.encoder_layers,
        n_vocab=config.vocab_size,
        n_text_ctx=config.max_target_positions,
        n_text_state=config.d_model,
        n_text_head=config.decoder_attention_heads,
        n_text_layer=config.decoder_layers,
    )


# ---------------------------------------------------------------------------
# File: contrib/models/whisper-large-v3-turbo/src/utils/decoding.py
# ---------------------------------------------------------------------------
# Modified from https://github.com/openai/whisper/blob/main/whisper/decoding.py

from dataclasses import replace
from typing import TYPE_CHECKING, List, Union

import torch
from torch import Tensor

from whisper.decoding import DecodingOptions, DecodingResult, DecodingTask, Inference

if TYPE_CHECKING:
    from modeling_whisper import NeuronApplicationWhisper as Whisper


class NeuronInference(Inference):
    """Inference adapter that routes whisper's decoding loop through the
    traced Neuron decoder graphs (prefill vs. single-token decode)."""

    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        # Length of the initial prompt; anything longer means we are in the
        # incremental decode phase.
        self.initial_token_length = initial_token_length

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        tokens = tokens.to(torch.int32)
        model_dtype = self.model.config.neuron_config.torch_dtype
        padded_tokens, last_pos, pad_mask = self.model._prepare_decoder_inputs(tokens)

        if tokens.shape[-1] > self.initial_token_length:
            # Decode: only the newest token is fed; cross-attention K/V
            # caches were populated during prefill, so xa can be a dummy.
            tokens = tokens[:, -1:]
            dummy_audio = torch.zeros(
                audio_features.shape[0],
                1,
                audio_features.shape[2],
                dtype=model_dtype,
            )
            return self.model.decoder(tokens, dummy_audio, last_pos, pad_mask)
        else:
            # Prefill: return logits for all real (non-padded) token positions.
            # The upstream _main_loop indexes logits[:, sot_index] and logits[:, -1],
            # so we must return the full unpadded sequence (not just last_pos).
            xa = audio_features.to(model_dtype)
            tokens = padded_tokens
            logits = self.model.decoder(tokens, xa, last_pos, pad_mask)
            # Slice to real token length (last_pos is 0-indexed, so +1 for length).
            # For batched decoding, all samples share the same initial_token_length,
            # so last_pos is uniform — use the first sample's value.
            seq_len = last_pos[0].item() + 1
            return logits[:, :seq_len, :]


class NeuronDecodingTask(DecodingTask):
    """DecodingTask that swaps the stock inference object for NeuronInference."""

    def __init__(self, model: "Whisper", options: DecodingOptions):
        super().__init__(model, options)
        self.inference = NeuronInference(model, len(self.initial_tokens))


@torch.no_grad()
def decode(
    model: "Whisper",
    mel: Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    if single := mel.ndim == 2:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    dtype = model.config.neuron_config.torch_dtype
    assert dtype in [torch.float16, torch.bfloat16, torch.float32], (
        f"Unsupported dtype: {dtype}"
    )
    # For fp16, set fp16=True so upstream whisper casts audio_features to float16.
    # For bfloat16, set fp16=False — we cast to bfloat16 ourselves in NeuronInference.logits().
    # (Upstream whisper only knows float16/float32; setting fp16=True with bfloat16
    # would cast to float16, causing a dtype mismatch with the traced model.)
    options = replace(options, fp16=(dtype == torch.float16))

    result = NeuronDecodingTask(model, options).run(mel)

    return result[0] if single else result


# ---------------------------------------------------------------------------
# File: contrib/models/whisper-large-v3-turbo/src/utils/state_dict.py
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from collections import OrderedDict
+ + new_state_dict = OrderedDict() + + # First pass: rename keys (skip self-attn Q/K/V which will be fused) + for name, param in hf_state_dict.items(): + # Self-attention Q/K/V: skip individual keys (fused below) + if any(k in name for k in ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]): + continue + + # Self-attention output + if "self_attn.out_proj" in name: + name = name.replace("self_attn.out_proj", "attn.out") + + # Cross attention layers (not fused, separate Q/K/V) + elif "encoder_attn.q_proj" in name: + name = name.replace("encoder_attn.q_proj", "cross_attn.query") + elif "encoder_attn.k_proj" in name: + name = name.replace("encoder_attn.k_proj", "cross_attn.key") + elif "encoder_attn.v_proj" in name: + name = name.replace("encoder_attn.v_proj", "cross_attn.value") + elif "encoder_attn.out_proj" in name: + name = name.replace("encoder_attn.out_proj", "cross_attn.out") + + # LayerNorms + elif "self_attn_layer_norm" in name: + name = name.replace("self_attn_layer_norm", "attn_ln") + elif "final_layer_norm" in name: + name = name.replace("final_layer_norm", "mlp_ln") + elif "encoder_attn_layer_norm" in name: + name = name.replace("encoder_attn_layer_norm", "cross_attn_ln") + + # MLPs + elif "fc1" in name: + name = name.replace("fc1", "mlp.up_proj") + elif "fc2" in name: + name = name.replace("fc2", "mlp.down_proj") + + # Embedding + elif "decoder.embed_tokens" in name: + name = name.replace("decoder.embed_tokens", "decoder.token_embedding") + elif "decoder.embed_positions" in name: + name = name.replace("decoder.embed_positions.weight", "decoder.positional_embedding.weight") + elif "encoder.embed_positions" in name: + name = name.replace("encoder.embed_positions.weight", "encoder.positional_embedding") + + # Conv + elif "encoder.conv1" in name: + name = name.replace("encoder.conv1", "encoder.conv1") + elif "encoder.conv2" in name: + name = name.replace("encoder.conv2", "encoder.conv2") + + # Top-level layer norm + elif 
name.startswith("encoder.layer_norm"): + name = name.replace("encoder.layer_norm", "encoder.ln_post") + elif name.startswith("decoder.layer_norm"): + name = name.replace("decoder.layer_norm", "decoder.ln") + + # Layers + name = name.replace("encoder.layers.", "encoder.blocks.") + name = name.replace("decoder.layers.", "decoder.blocks.") + + prefix = type + "." + if name.startswith(prefix): + name = name[len(prefix) :] + new_state_dict[name] = param + + # Second pass: fuse self-attention Q/K/V into qkv_proj + _fuse_self_attn_qkv(hf_state_dict, new_state_dict, type) + + return new_state_dict + + +def _fuse_self_attn_qkv(hf_state_dict, new_state_dict, type): + """Fuse separate Q/K/V weights and biases into a single qkv_proj.""" + import re + + # Find all layer indices that have self-attention Q/K/V + layer_type = "encoder.layers" if type == "encoder" else "decoder.layers" + block_type = "encoder.blocks" if type == "encoder" else "decoder.blocks" + prefix = type + "." + + layer_indices = set() + pattern = re.compile(rf"{layer_type}\.(\d+)\.self_attn\.[qkv]_proj\.") + for name in hf_state_dict: + m = pattern.search(name) + if m: + layer_indices.add(int(m.group(1))) + + for idx in sorted(layer_indices): + # Fuse weights: cat([q_weight, k_weight, v_weight], dim=0) + q_w = hf_state_dict[f"{layer_type}.{idx}.self_attn.q_proj.weight"] + k_w = hf_state_dict[f"{layer_type}.{idx}.self_attn.k_proj.weight"] + v_w = hf_state_dict[f"{layer_type}.{idx}.self_attn.v_proj.weight"] + fused_w = torch.cat([q_w, k_w, v_w], dim=0) + key_w = f"{block_type}.{idx}.attn.qkv_proj.weight" + if key_w.startswith(prefix): + new_state_dict[key_w[len(prefix):]] = fused_w + + # Fuse biases: cat([q_bias, zeros_for_k, v_bias], dim=0) + # Q has bias, K has no bias, V has bias + q_b_key = f"{layer_type}.{idx}.self_attn.q_proj.bias" + v_b_key = f"{layer_type}.{idx}.self_attn.v_proj.bias" + if q_b_key in hf_state_dict: + q_b = hf_state_dict[q_b_key] + v_b = hf_state_dict[v_b_key] + k_b = 
torch.zeros_like(q_b) # K has no bias, use zeros + fused_b = torch.cat([q_b, k_b, v_b], dim=0) + key_b = f"{block_type}.{idx}.attn.qkv_proj.bias" + if key_b.startswith(prefix): + new_state_dict[key_b[len(prefix):]] = fused_b + + +def expand_state_dict(state_dict, dims, TP): + """ + Pad attention heads so that the number of heads is a multiple of TP. + This is necessary for the model to work correctly with tensor parallelism. + """ + if dims.n_audio_head % TP == 0: + # no need to pad + return state_dict + + new_state_dict = OrderedDict() + + d = dims.n_audio_state # embedding dim + head_dim = d // dims.n_audio_head + n_padded_heads = ((dims.n_audio_head + TP - 1) // TP) * TP + padded_d = head_dim * n_padded_heads + + for name, param in state_dict.items(): + if not isinstance(param, torch.Tensor): + new_state_dict[name] = param + continue + + shape = param.shape + + # Case 1a: Fused "qkv_proj.weight" —> [3*d, d] → [3*padded_d, d] + if "qkv_proj.weight" in name: + if shape == (3 * d, d): + q_w, k_w, v_w = torch.tensor_split(param, 3, dim=0) + q_w = F.pad(q_w, (0, 0, 0, padded_d - d)) + k_w = F.pad(k_w, (0, 0, 0, padded_d - d)) + v_w = F.pad(v_w, (0, 0, 0, padded_d - d)) + padded = torch.cat([q_w, k_w, v_w], dim=0) + new_state_dict[name] = padded + print(f"Padded {name}: {shape} → {padded.shape}") + continue + + # Case 1b: Fused "qkv_proj.bias" —> [3*d] → [3*padded_d] + if "qkv_proj.bias" in name: + if shape == (3 * d,): + q_b, k_b, v_b = torch.tensor_split(param, 3, dim=0) + q_b = F.pad(q_b, (0, padded_d - d)) + k_b = F.pad(k_b, (0, padded_d - d)) + v_b = F.pad(v_b, (0, padded_d - d)) + padded = torch.cat([q_b, k_b, v_b], dim=0) + new_state_dict[name] = padded + print(f"Padded {name}: {shape} → {padded.shape}") + continue + + # Case 2: Cross-attn "query.weight", "key.weight", "value.weight" —> [d, d] → [padded_d, d] + if any(k in name for k in ["query.weight", "key.weight", "value.weight"]): + if shape == (d, d): + padded = F.pad(param, (0, 0, 0, padded_d - d)) # pad 
rows + new_state_dict[name] = padded + print(f"Padded {name}: {shape} → {padded.shape}") + continue + + # Case 3: Cross-attn "query.bias", "value.bias" —> [d] → [padded_d] + if any(k in name for k in ["query.bias", "value.bias"]): + if shape == (d,): + padded = F.pad(param, (0, padded_d - d)) # pad 1D + new_state_dict[name] = padded + print(f"Padded {name}: {shape} → {padded.shape}") + continue + + # Case 4: "out.weight" —> [d, d] → [d, padded_d] + if "out.weight" in name: + if shape == (d, d): + padded = F.pad(param, (0, padded_d - d, 0, 0)) # pad columns + new_state_dict[name] = padded + print(f"Padded {name}: {shape} → {padded.shape}") + continue + + # Default: unchanged + new_state_dict[name] = param + + return new_state_dict diff --git a/contrib/models/whisper-large-v3-turbo/src/whisper_encoder_megakernel.py b/contrib/models/whisper-large-v3-turbo/src/whisper_encoder_megakernel.py new file mode 100644 index 00000000..560a87d0 --- /dev/null +++ b/contrib/models/whisper-large-v3-turbo/src/whisper_encoder_megakernel.py @@ -0,0 +1,923 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Whisper encoder layer megakernel — hand-written NKI ISA. + +Fuses an entire Whisper encoder layer into a single @nki.jit kernel call, +following the Boltz-2 full_pairformer_layer_spmd approach. 
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Whisper encoder layer megakernel — hand-written NKI ISA.

Fuses an entire Whisper encoder layer into a single @nki.jit kernel call,
following the Boltz-2 full_pairformer_layer_spmd approach.

Per-layer operation sequence (32 layers total):
    Phase 1: LayerNorm (pre-attention)
    Phase 2: Fused QKV projection (single tiled matmul)
    Phase 3: Bidirectional flash attention (online softmax, no causal mask)
    Phase 4: Output projection + residual add
    Phase 5: LayerNorm (pre-MLP)
    Phase 6: MLP up-projection (FC1)
    Phase 7: GELU activation
    Phase 8: MLP down-projection (FC2) + residual add

Whisper large-v3-turbo encoder dimensions:
    hidden_dim (d_model) = 1280 (10 tiles of 128)
    num_heads = 20
    head_dim = 64 (50% PE utilization on Q@K^T)
    seq_len = 1500 (pad to 1536 = 12 tiles for kernel)
    MLP intermediate = 5120 (40 tiles of 128)
    num_layers = 32
    dtype = bf16

Hardware: NeuronCore v3 (trn2), 128-wide partition dimension.
"""

import nki
import nki.isa as nisa
import nki.language as nl

# Partition (SBUF/PSUM) width of a NeuronCore tile.
P_MAX = 128

# Whisper encoder constants
D_MODEL = 1280  # hidden dimension
N_HEADS = 20  # attention heads
HEAD_DIM = 64  # per-head dimension
MLP_DIM = 5120  # MLP intermediate dimension
SEQ_PAD = 1536  # padded sequence length (1500 -> 1536 = 12 * 128)
SEQ_ACTUAL = 1500  # actual sequence length

# Tile counts
N_HIDDEN = D_MODEL // P_MAX  # 10 tiles for hidden dim
N_MLP = MLP_DIM // P_MAX  # 40 tiles for MLP intermediate
N_SEQ = SEQ_PAD // P_MAX  # 12 tiles for padded sequence


# ============================================================================
# Helper: Transpose SBUF -> PSUM -> SBUF
# ============================================================================
def _transpose_to_sbuf(x):
    """nc_transpose x from SBUF -> PSUM -> SBUF.

    nc_transpose writes to PSUM, so the result is copied back to SBUF
    before it can feed later SBUF-resident ops.
    """
    x_t_psum = nl.ndarray((P_MAX, P_MAX), dtype=x.dtype, buffer=nl.psum)
    nisa.nc_transpose(dst=x_t_psum, data=x)
    x_t = nl.ndarray((P_MAX, P_MAX), dtype=x.dtype, buffer=nl.sbuf)
    nisa.tensor_copy(dst=x_t, src=x_t_psum)
    return x_t


# ============================================================================
# Helper: Prepare weight (load from HBM + transpose)
# ============================================================================
def _prepare_weight(w_hbm):
    """Load and transpose a [P_MAX, P_MAX] weight tile from HBM.

    The transposed copy is what nc_matmul consumes as its operand
    (see _matmul_with_w_t).
    """
    w = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf)
    nisa.dma_copy(dst=w, src=w_hbm)
    w_t_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.psum)
    nisa.nc_transpose(dst=w_t_psum, data=w)
    w_t = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf)
    nisa.tensor_copy(dst=w_t, src=w_t_psum)
    return w_t


# ============================================================================
# Helper: matmul with pre-transposed weight
# ============================================================================
def _matmul_with_w_t(x_t, w_t):
    """Compute x @ W^T using pre-transposed weight in SBUF. Returns bf16.

    Accumulation happens in fp32 PSUM; the result is narrowed to bf16 on
    the copy back to SBUF.
    """
    result_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=result_psum, stationary=x_t, moving=w_t)
    result = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf)
    nisa.tensor_copy(dst=result, src=result_psum)
    return result


# ============================================================================
# Helper: LayerNorm on a tile [P_MAX, F] in SBUF
# ============================================================================
def _layer_norm_tile(x_tile, weight_tiled, bias_tiled, F, eps=1e-5):
    """LayerNorm on a [P_MAX, F] tile entirely in SBUF.

    Args:
        x_tile: [P_MAX, F] bf16 in SBUF
        weight_tiled: [P_MAX, F] bf16 -- pre-tiled (each row identical)
        bias_tiled: [P_MAX, F] bf16 -- pre-tiled (each row identical)
        F: int, free dimension size
        eps: float
    Returns:
        normalized: [P_MAX, F] bf16 in SBUF
    """
    inv_F = 1.0 / float(F)

    # Cast to f32 — statistics are accumulated in full precision.
    x_f32 = nl.ndarray((P_MAX, F), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=x_f32, src=x_tile)

    # Mean: reduce over free dim
    sum_x = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_reduce(dst=sum_x, op=nl.add, data=x_f32, axis=1)
    mean = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=mean, data=sum_x, op0=nl.multiply, operand0=inv_F, engine=nisa.vector_engine
    )

    # Center: x - mean
    centered = nl.ndarray((P_MAX, F), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=centered,
        data=x_f32,
        op0=nl.subtract,
        operand0=mean,
        engine=nisa.vector_engine,
    )

    # Variance (biased, matching torch LayerNorm)
    sq = nl.ndarray((P_MAX, F), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=sq, data1=centered, data2=centered, op=nl.multiply)
    sum_sq = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_reduce(dst=sum_sq, op=nl.add, data=sq, axis=1)
    var = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=var, data=sum_sq, op0=nl.multiply, operand0=inv_F, engine=nisa.vector_engine
    )

    # rsqrt(var + eps)
    var_eps = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=var_eps, data=var, op0=nl.add, operand0=eps, engine=nisa.vector_engine
    )
    rsqrt_std = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.activation(dst=rsqrt_std, op=nl.rsqrt, data=var_eps, bias=None, scale=1.0)

    # Normalize
    norm_f32 = nl.ndarray((P_MAX, F), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=norm_f32,
        data=centered,
        op0=nl.multiply,
        operand0=rsqrt_std,
        engine=nisa.vector_engine,
    )

    # Scale + bias (both [P_MAX, F], pre-tiled on host)
    w_f32 = nl.ndarray((P_MAX, F), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=w_f32, src=weight_tiled)
    scaled = nl.ndarray((P_MAX, F), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=scaled, data1=norm_f32, data2=w_f32, op=nl.multiply)

    b_f32 = nl.ndarray((P_MAX, F), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=b_f32, src=bias_tiled)
    result_f32 = nl.ndarray((P_MAX, F), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=result_f32, data1=scaled, data2=b_f32, op=nl.add)

    result = nl.ndarray((P_MAX, F), dtype=nl.bfloat16, buffer=nl.sbuf)
    nisa.tensor_copy(dst=result, src=result_f32)
    return result
+ """ + Hd = n_heads * head_dim + j_start = j_tile * P_MAX + + for h in nl.affine_range(n_heads): + hd_start = h * head_dim + + # Load Q tile: [P_MAX, head_dim] from q_hbm + q_tile = nl.ndarray((P_MAX, head_dim), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_tile, + src=q_hbm[j_start : j_start + P_MAX, hd_start : hd_start + head_dim], + ) + + # Pad Q to [P_MAX, P_MAX] for nc_matmul (head_dim=64 < P_MAX=128) + q_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.memset(dst=q_padded, value=0.0) + nisa.tensor_copy(dst=q_padded[0:P_MAX, 0:head_dim], src=q_tile) + q_t = _transpose_to_sbuf(q_padded) + + # Online softmax accumulators + m_prev = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=m_prev, value=-1e30) + l_prev = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=l_prev, value=0.0) + o_acc = nl.ndarray((P_MAX, head_dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=o_acc, value=0.0) + + # Sequential over key tiles (required for online softmax) + for k_tile_idx in nl.sequential_range(n_seq_tiles): + k_start = k_tile_idx * P_MAX + + # Load K tile: [P_MAX, head_dim] + k_tile_sb = nl.ndarray((P_MAX, head_dim), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_tile_sb, + src=k_hbm[k_start : k_start + P_MAX, hd_start : hd_start + head_dim], + ) + + # Load V tile: [P_MAX, head_dim] + v_tile_sb = nl.ndarray((P_MAX, head_dim), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_tile_sb, + src=v_hbm[k_start : k_start + P_MAX, hd_start : hd_start + head_dim], + ) + + # Pad K to [P_MAX, P_MAX], transpose + k_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.memset(dst=k_padded, value=0.0) + nisa.tensor_copy(dst=k_padded[0:P_MAX, 0:head_dim], src=k_tile_sb) + k_t = _transpose_to_sbuf(k_padded) + + # Q @ K^T -> logits [P_MAX, P_MAX] + logits_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + 
nisa.nc_matmul(dst=logits_psum, stationary=q_t, moving=k_t) + logits = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=logits, src=logits_psum) + + # Scale by 1/sqrt(d) + logits_scaled = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=logits_scaled, + data=logits, + op0=nl.multiply, + operand0=scale, + engine=nisa.vector_engine, + ) + + # Online softmax step 1: tile max + tile_max = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=tile_max, op=nl.maximum, data=logits_scaled, axis=1) + + # Step 2: running max + m_new = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=m_new, data1=m_prev, data2=tile_max, op=nl.maximum) + + # Step 3: correction = exp(m_prev - m_new) + m_diff = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=m_diff, data1=m_prev, data2=m_new, op=nl.subtract) + correction = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=correction, op=nl.exp, data=m_diff, bias=None, scale=1.0 + ) + + # Step 4: exp(logits - m_new) + logits_shifted = nl.ndarray( + (P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf + ) + nisa.tensor_scalar( + dst=logits_shifted, + data=logits_scaled, + op0=nl.subtract, + operand0=m_new, + engine=nisa.vector_engine, + ) + exp_logits = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_logits, op=nl.exp, data=logits_shifted, bias=None, scale=1.0 + ) + + # Step 5: update l = l * correction + sum(exp_logits) + l_corrected = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=l_corrected, data1=l_prev, data2=correction, op=nl.multiply + ) + tile_sum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_reduce(dst=tile_sum, op=nl.add, data=exp_logits, axis=1) + l_new = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + 
nisa.tensor_tensor(dst=l_new, data1=l_corrected, data2=tile_sum, op=nl.add) + + # Step 6: rescale output accumulator + o_scaled = nl.ndarray((P_MAX, head_dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=o_scaled, + data=o_acc, + op0=nl.multiply, + operand0=correction, + engine=nisa.vector_engine, + ) + + # exp_logits @ V: [P_MAX, P_MAX] @ [P_MAX, head_dim] -> [P_MAX, head_dim] + exp_bf16 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=exp_bf16, src=exp_logits) + exp_t = _transpose_to_sbuf(exp_bf16) + + pv_psum = nl.ndarray((P_MAX, head_dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=pv_psum, stationary=exp_t, moving=v_tile_sb) + pv_sbuf = nl.ndarray((P_MAX, head_dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=pv_sbuf, src=pv_psum) + + # Step 7: accumulate + nisa.tensor_tensor(dst=o_acc, data1=o_scaled, data2=pv_sbuf, op=nl.add) + + # Update running state + nisa.tensor_copy(dst=m_prev, src=m_new) + nisa.tensor_copy(dst=l_prev, src=l_new) + + # Finalize: output = o_acc / l + inv_l = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.reciprocal(dst=inv_l, data=l_prev) + o_final = nl.ndarray((P_MAX, head_dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=o_final, + data=o_acc, + op0=nl.multiply, + operand0=inv_l, + engine=nisa.vector_engine, + ) + + # Cast to bf16 and store + o_out = nl.ndarray((P_MAX, head_dim), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=o_out, src=o_final) + nisa.dma_copy( + dst=out_hbm[j_start : j_start + P_MAX, hd_start : hd_start + head_dim], + src=o_out, + ) + + +# ============================================================================ +# Main entry point: Whisper encoder layer megakernel +# ============================================================================ +@nki.jit +def whisper_encoder_layer_fwd( + x_in, # [N_SEQ * P_MAX, D_MODEL] bf16 -- input activations + # Pre-attention LayerNorm weights 
(pre-tiled to [P_MAX, D_MODEL]) + attn_ln_w, # [P_MAX, D_MODEL] bf16 + attn_ln_b, # [P_MAX, D_MODEL] bf16 + # Fused QKV weight: [3 * D_MODEL, D_MODEL] + qkv_w, # [3 * D_MODEL, D_MODEL] bf16 + qkv_b, # [P_MAX, 3 * D_MODEL] bf16 (pre-tiled, each row identical) + # Output projection: [D_MODEL, D_MODEL] + out_w, # [D_MODEL, D_MODEL] bf16 + out_b, # [P_MAX, D_MODEL] bf16 (pre-tiled, each row identical) + # Pre-MLP LayerNorm weights (pre-tiled to [P_MAX, D_MODEL]) + mlp_ln_w, # [P_MAX, D_MODEL] bf16 + mlp_ln_b, # [P_MAX, D_MODEL] bf16 + # MLP FC1 (up projection): [MLP_DIM, D_MODEL] + fc1_w, # [MLP_DIM, D_MODEL] bf16 + fc1_b, # [P_MAX, MLP_DIM] bf16 (pre-tiled, each row identical) + # MLP FC2 (down projection): [D_MODEL, MLP_DIM] + fc2_w, # [D_MODEL, MLP_DIM] bf16 + fc2_b, # [P_MAX, D_MODEL] bf16 (pre-tiled, each row identical) + # Scalars + n_seq_tiles: int = N_SEQ, + n_heads: int = N_HEADS, + head_dim: int = HEAD_DIM, + eps: float = 1e-5, +) -> nl.ndarray: + """Execute one Whisper encoder layer as a single fused kernel. + + Sequence: LN -> QKV -> FlashAttn -> OutProj -> Residual -> + LN -> FC1 -> GELU -> FC2 -> Residual + + Input x_in is [n_seq_tiles * P_MAX, D_MODEL] stored flat in HBM. + Returns x_out of the same shape. 
+ """ + S = n_seq_tiles * P_MAX # padded sequence length + Hd = n_heads * head_dim + scale = 1.0 / (head_dim**0.5) + + # Output tensor + x_out = nl.ndarray((S, D_MODEL), dtype=nl.bfloat16, buffer=nl.shared_hbm) + + # Scratch buffers in private HBM for Q, K, V, attn_out + q_buf = nl.ndarray((S, Hd), dtype=nl.bfloat16, buffer=nl.private_hbm) + k_buf = nl.ndarray((S, Hd), dtype=nl.bfloat16, buffer=nl.private_hbm) + v_buf = nl.ndarray((S, Hd), dtype=nl.bfloat16, buffer=nl.private_hbm) + attn_out_buf = nl.ndarray((S, Hd), dtype=nl.bfloat16, buffer=nl.private_hbm) + # Scratch for post-attention residual (need original x for residual add) + post_attn_buf = nl.ndarray((S, D_MODEL), dtype=nl.bfloat16, buffer=nl.private_hbm) + + # Load LayerNorm weights once (shared across all seq tiles) + ln1_w = nl.ndarray((P_MAX, D_MODEL), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy(dst=ln1_w, src=attn_ln_w) + ln1_b = nl.ndarray((P_MAX, D_MODEL), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy(dst=ln1_b, src=attn_ln_b) + + # ================================================================ + # Phase 1-2: LayerNorm + QKV projection for each sequence tile + # QKV weight is [3*D_MODEL, D_MODEL] = [3840, 1280]. 
+ # For each seq tile: x_normed [P_MAX, 1280] @ W_qkv^T [1280, 3840] + # Output: 3 * 10 = 30 output chunks of [P_MAX, P_MAX] + # We split into Q[P_MAX, Hd], K[P_MAX, Hd], V[P_MAX, Hd] + # where Hd = n_heads * head_dim = 1280 + # ================================================================ + for s_tile in nl.sequential_range(n_seq_tiles): + s_start = s_tile * P_MAX + + # Load input tile [P_MAX, D_MODEL] + x_tile = nl.ndarray((P_MAX, D_MODEL), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy(dst=x_tile, src=x_in[s_start : s_start + P_MAX, 0:D_MODEL]) + + # LayerNorm + x_normed = _layer_norm_tile(x_tile, ln1_w, ln1_b, D_MODEL, eps) + + # Split normed input into N_HIDDEN chunks and transpose each + # x_normed is [P_MAX, 1280] = [P_MAX, 10*128] + # We need x_chunks_t: tuple of 10 transposed [P_MAX, P_MAX] tiles + xc_t_0 = _transpose_to_sbuf( + nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + ) + # Can't do that -- need to copy first, then transpose. + # Let me do it properly: + + # Extract and transpose each hidden chunk + xc_0 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_0, src=x_normed[0:P_MAX, 0:P_MAX]) + xc_t_0 = _transpose_to_sbuf(xc_0) + + xc_1 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_1, src=x_normed[0:P_MAX, P_MAX : 2 * P_MAX]) + xc_t_1 = _transpose_to_sbuf(xc_1) + + xc_2 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_2, src=x_normed[0:P_MAX, 2 * P_MAX : 3 * P_MAX]) + xc_t_2 = _transpose_to_sbuf(xc_2) + + xc_3 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_3, src=x_normed[0:P_MAX, 3 * P_MAX : 4 * P_MAX]) + xc_t_3 = _transpose_to_sbuf(xc_3) + + xc_4 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_4, src=x_normed[0:P_MAX, 4 * P_MAX : 5 * P_MAX]) + xc_t_4 = _transpose_to_sbuf(xc_4) + + xc_5 = nl.ndarray((P_MAX, P_MAX), 
dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_5, src=x_normed[0:P_MAX, 5 * P_MAX : 6 * P_MAX]) + xc_t_5 = _transpose_to_sbuf(xc_5) + + xc_6 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_6, src=x_normed[0:P_MAX, 6 * P_MAX : 7 * P_MAX]) + xc_t_6 = _transpose_to_sbuf(xc_6) + + xc_7 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_7, src=x_normed[0:P_MAX, 7 * P_MAX : 8 * P_MAX]) + xc_t_7 = _transpose_to_sbuf(xc_7) + + xc_8 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_8, src=x_normed[0:P_MAX, 8 * P_MAX : 9 * P_MAX]) + xc_t_8 = _transpose_to_sbuf(xc_8) + + xc_9 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=xc_9, src=x_normed[0:P_MAX, 9 * P_MAX : 10 * P_MAX]) + xc_t_9 = _transpose_to_sbuf(xc_9) + + xc_t = ( + xc_t_0, + xc_t_1, + xc_t_2, + xc_t_3, + xc_t_4, + xc_t_5, + xc_t_6, + xc_t_7, + xc_t_8, + xc_t_9, + ) + + # QKV matmul: for each of 30 output chunks (3*10), accumulate across + # 10 input chunks. Output goes to Q/K/V buffers in private_hbm. 
+ # qkv_w is [3*D_MODEL, D_MODEL] = [3840, 1280] + # Layout: rows 0..1279 = Q weights, 1280..2559 = K, 2560..3839 = V + # Each section: 10 output chunks of [P_MAX, P_MAX] + # For output chunk o (0..9 for Q, 10..19 for K, 20..29 for V): + # acc = sum over i=0..9 of: x_chunks_t[i] @ W[o*128:(o+1)*128, i*128:(i+1)*128]^T + + # Q projection: 10 output chunks + for o in nl.static_range(N_HIDDEN): + q_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=q_acc, value=0.0) + for i in nl.static_range(N_HIDDEN): + w_t = _prepare_weight( + qkv_w[o * P_MAX : (o + 1) * P_MAX, i * P_MAX : (i + 1) * P_MAX] + ) + p = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=p, stationary=xc_t[i], moving=w_t) + ps = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ps, src=p) + nisa.tensor_tensor(dst=q_acc, data1=q_acc, data2=ps, op=nl.add) + # Add bias and store + q_out = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_out, src=q_acc) + # Bias: load [P_MAX, P_MAX] slice (pre-tiled on host) + q_bias = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy(dst=q_bias, src=qkv_b[0:P_MAX, o * P_MAX : (o + 1) * P_MAX]) + q_biased = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_tensor(dst=q_biased, data1=q_out, data2=q_bias, op=nl.add) + nisa.dma_copy( + dst=q_buf[s_start : s_start + P_MAX, o * P_MAX : (o + 1) * P_MAX], + src=q_biased, + ) + + # K projection: 10 output chunks (rows D_MODEL..2*D_MODEL of qkv_w) + for o in nl.static_range(N_HIDDEN): + k_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=k_acc, value=0.0) + w_row_off = D_MODEL + o * P_MAX + for i in nl.static_range(N_HIDDEN): + w_t = _prepare_weight( + qkv_w[w_row_off : w_row_off + P_MAX, i * P_MAX : (i + 1) * P_MAX] + ) + p = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=p, 
stationary=xc_t[i], moving=w_t) + ps = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ps, src=p) + nisa.tensor_tensor(dst=k_acc, data1=k_acc, data2=ps, op=nl.add) + k_out = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_out, src=k_acc) + k_bias = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_bias, + src=qkv_b[0:P_MAX, D_MODEL + o * P_MAX : D_MODEL + (o + 1) * P_MAX], + ) + k_biased = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_tensor(dst=k_biased, data1=k_out, data2=k_bias, op=nl.add) + nisa.dma_copy( + dst=k_buf[s_start : s_start + P_MAX, o * P_MAX : (o + 1) * P_MAX], + src=k_biased, + ) + + # V projection: 10 output chunks (rows 2*D_MODEL..3*D_MODEL of qkv_w) + for o in nl.static_range(N_HIDDEN): + v_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=v_acc, value=0.0) + w_row_off = 2 * D_MODEL + o * P_MAX + for i in nl.static_range(N_HIDDEN): + w_t = _prepare_weight( + qkv_w[w_row_off : w_row_off + P_MAX, i * P_MAX : (i + 1) * P_MAX] + ) + p = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=p, stationary=xc_t[i], moving=w_t) + ps = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ps, src=p) + nisa.tensor_tensor(dst=v_acc, data1=v_acc, data2=ps, op=nl.add) + v_out = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=v_out, src=v_acc) + v_bias = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_bias, + src=qkv_b[ + 0:P_MAX, 2 * D_MODEL + o * P_MAX : 2 * D_MODEL + (o + 1) * P_MAX + ], + ) + v_biased = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_biased, data1=v_out, data2=v_bias, op=nl.add) + nisa.dma_copy( + dst=v_buf[s_start : s_start + P_MAX, o * P_MAX : (o + 1) * P_MAX], + src=v_biased, + ) + + # 
================================================================ + # Phase 3: Bidirectional flash attention + # Q/K/V are in private_hbm: [S, Hd] where Hd = n_heads * head_dim + # Process each query tile independently (parallel over seq tiles) + # ================================================================ + for j_tile in nl.static_range(N_SEQ): + _attention_for_seq_tile( + q_buf, + k_buf, + v_buf, + attn_out_buf, + j_tile, + n_seq_tiles, + n_heads, + head_dim, + scale, + ) + + # ================================================================ + # Phase 4: Output projection + residual add + # attn_out is [S, Hd=1280], out_w is [D_MODEL, D_MODEL] = [1280, 1280] + # For each seq tile: attn_out[P_MAX, 1280] @ out_w^T -> [P_MAX, 1280] + # Then add bias and residual (original x_in) + # Result stored to post_attn_buf for Phase 5-8 + # ================================================================ + for s_tile in nl.sequential_range(n_seq_tiles): + s_start = s_tile * P_MAX + + # Load attn output tile, split into chunks, transpose each + ao_c_0 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy(dst=ao_c_0, src=attn_out_buf[s_start : s_start + P_MAX, 0:P_MAX]) + ao_t_0 = _transpose_to_sbuf(ao_c_0) + + ao_c_1 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=ao_c_1, src=attn_out_buf[s_start : s_start + P_MAX, P_MAX : 2 * P_MAX] + ) + ao_t_1 = _transpose_to_sbuf(ao_c_1) + + ao_c_2 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=ao_c_2, + src=attn_out_buf[s_start : s_start + P_MAX, 2 * P_MAX : 3 * P_MAX], + ) + ao_t_2 = _transpose_to_sbuf(ao_c_2) + + ao_c_3 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=ao_c_3, + src=attn_out_buf[s_start : s_start + P_MAX, 3 * P_MAX : 4 * P_MAX], + ) + ao_t_3 = _transpose_to_sbuf(ao_c_3) + + ao_c_4 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=ao_c_4, + 
src=attn_out_buf[s_start : s_start + P_MAX, 4 * P_MAX : 5 * P_MAX], + ) + ao_t_4 = _transpose_to_sbuf(ao_c_4) + + ao_c_5 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=ao_c_5, + src=attn_out_buf[s_start : s_start + P_MAX, 5 * P_MAX : 6 * P_MAX], + ) + ao_t_5 = _transpose_to_sbuf(ao_c_5) + + ao_c_6 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=ao_c_6, + src=attn_out_buf[s_start : s_start + P_MAX, 6 * P_MAX : 7 * P_MAX], + ) + ao_t_6 = _transpose_to_sbuf(ao_c_6) + + ao_c_7 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=ao_c_7, + src=attn_out_buf[s_start : s_start + P_MAX, 7 * P_MAX : 8 * P_MAX], + ) + ao_t_7 = _transpose_to_sbuf(ao_c_7) + + ao_c_8 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=ao_c_8, + src=attn_out_buf[s_start : s_start + P_MAX, 8 * P_MAX : 9 * P_MAX], + ) + ao_t_8 = _transpose_to_sbuf(ao_c_8) + + ao_c_9 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=ao_c_9, + src=attn_out_buf[s_start : s_start + P_MAX, 9 * P_MAX : 10 * P_MAX], + ) + ao_t_9 = _transpose_to_sbuf(ao_c_9) + + ao_t = ( + ao_t_0, + ao_t_1, + ao_t_2, + ao_t_3, + ao_t_4, + ao_t_5, + ao_t_6, + ao_t_7, + ao_t_8, + ao_t_9, + ) + + # Tiled matmul: for each output chunk, accumulate across 10 input chunks + for o in nl.static_range(N_HIDDEN): + out_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=out_acc, value=0.0) + for i in nl.static_range(N_HIDDEN): + w_t = _prepare_weight( + out_w[o * P_MAX : (o + 1) * P_MAX, i * P_MAX : (i + 1) * P_MAX] + ) + p = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=p, stationary=ao_t[i], moving=w_t) + ps = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ps, src=p) + nisa.tensor_tensor(dst=out_acc, data1=out_acc, data2=ps, op=nl.add) + + # Cast, add bias, add 
residual + proj_chunk = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=proj_chunk, src=out_acc) + ob = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy(dst=ob, src=out_b[0:P_MAX, o * P_MAX : (o + 1) * P_MAX]) + proj_biased = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_tensor(dst=proj_biased, data1=proj_chunk, data2=ob, op=nl.add) + x_orig = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=x_orig, + src=x_in[s_start : s_start + P_MAX, o * P_MAX : (o + 1) * P_MAX], + ) + x_res = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_tensor(dst=x_res, data1=x_orig, data2=proj_biased, op=nl.add) + nisa.dma_copy( + dst=post_attn_buf[ + s_start : s_start + P_MAX, o * P_MAX : (o + 1) * P_MAX + ], + src=x_res, + ) + + # ================================================================ + # Phase 5-8: LayerNorm + MLP (FC1 -> GELU -> FC2) + residual + # post_attn_buf has the post-attention residual output. 
+ # FC1: [MLP_DIM, D_MODEL] = [5120, 1280], 40 out x 10 in chunks + # FC2: [D_MODEL, MLP_DIM] = [1280, 5120], 10 out x 40 in chunks + # ================================================================ + # Load MLP LayerNorm weights + ln2_w = nl.ndarray((P_MAX, D_MODEL), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy(dst=ln2_w, src=mlp_ln_w) + ln2_b = nl.ndarray((P_MAX, D_MODEL), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy(dst=ln2_b, src=mlp_ln_b) + + for s_tile in nl.sequential_range(n_seq_tiles): + s_start = s_tile * P_MAX + + # Load post-attention tile + pa_tile = nl.ndarray((P_MAX, D_MODEL), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=pa_tile, + src=post_attn_buf[s_start : s_start + P_MAX, 0:D_MODEL], + ) + + # LayerNorm + x_normed = _layer_norm_tile(pa_tile, ln2_w, ln2_b, D_MODEL, eps) + + # Split and transpose normed input (10 chunks) + mc_0 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_0, src=x_normed[0:P_MAX, 0:P_MAX]) + mc_t_0 = _transpose_to_sbuf(mc_0) + + mc_1 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_1, src=x_normed[0:P_MAX, P_MAX : 2 * P_MAX]) + mc_t_1 = _transpose_to_sbuf(mc_1) + + mc_2 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_2, src=x_normed[0:P_MAX, 2 * P_MAX : 3 * P_MAX]) + mc_t_2 = _transpose_to_sbuf(mc_2) + + mc_3 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_3, src=x_normed[0:P_MAX, 3 * P_MAX : 4 * P_MAX]) + mc_t_3 = _transpose_to_sbuf(mc_3) + + mc_4 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_4, src=x_normed[0:P_MAX, 4 * P_MAX : 5 * P_MAX]) + mc_t_4 = _transpose_to_sbuf(mc_4) + + mc_5 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_5, src=x_normed[0:P_MAX, 5 * P_MAX : 6 * P_MAX]) + mc_t_5 = _transpose_to_sbuf(mc_5) + + mc_6 = nl.ndarray((P_MAX, P_MAX), 
dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_6, src=x_normed[0:P_MAX, 6 * P_MAX : 7 * P_MAX]) + mc_t_6 = _transpose_to_sbuf(mc_6) + + mc_7 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_7, src=x_normed[0:P_MAX, 7 * P_MAX : 8 * P_MAX]) + mc_t_7 = _transpose_to_sbuf(mc_7) + + mc_8 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_8, src=x_normed[0:P_MAX, 8 * P_MAX : 9 * P_MAX]) + mc_t_8 = _transpose_to_sbuf(mc_8) + + mc_9 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=mc_9, src=x_normed[0:P_MAX, 9 * P_MAX : 10 * P_MAX]) + mc_t_9 = _transpose_to_sbuf(mc_9) + + mc_t = ( + mc_t_0, + mc_t_1, + mc_t_2, + mc_t_3, + mc_t_4, + mc_t_5, + mc_t_6, + mc_t_7, + mc_t_8, + mc_t_9, + ) + + # FC1 + GELU + FC2 fused with accumulation into output chunks + # FC2 accumulators: 10 output chunks (accumulated across 40 hidden chunks) + fc2_acc_0 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=fc2_acc_0, value=0.0) + fc2_acc_1 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=fc2_acc_1, value=0.0) + fc2_acc_2 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=fc2_acc_2, value=0.0) + fc2_acc_3 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=fc2_acc_3, value=0.0) + fc2_acc_4 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=fc2_acc_4, value=0.0) + fc2_acc_5 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=fc2_acc_5, value=0.0) + fc2_acc_6 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=fc2_acc_6, value=0.0) + fc2_acc_7 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=fc2_acc_7, value=0.0) + fc2_acc_8 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=fc2_acc_8, value=0.0) + 
fc2_acc_9 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+        nisa.memset(dst=fc2_acc_9, value=0.0)
+
+        fc2_acc = (
+            fc2_acc_0,
+            fc2_acc_1,
+            fc2_acc_2,
+            fc2_acc_3,
+            fc2_acc_4,
+            fc2_acc_5,
+            fc2_acc_6,
+            fc2_acc_7,
+            fc2_acc_8,
+            fc2_acc_9,
+        )
+
+        # For each hidden chunk h (0..39):
+        #   1. FC1: accumulate across 10 input chunks -> [P_MAX, P_MAX]
+        #   2. Add FC1 bias
+        #   3. GELU activation
+        #   4. FC2: accumulate this hidden chunk into each of 10 output chunks
+        for h in nl.static_range(N_MLP):
+            # FC1 for hidden chunk h: accumulate across 10 input chunks
+            fc1_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+            nisa.memset(dst=fc1_acc, value=0.0)
+            for i in nl.static_range(N_HIDDEN):
+                w_t = _prepare_weight(
+                    fc1_w[h * P_MAX : (h + 1) * P_MAX, i * P_MAX : (i + 1) * P_MAX]
+                )
+                p = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
+                nisa.nc_matmul(dst=p, stationary=mc_t[i], moving=w_t)
+                ps = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
+                nisa.tensor_copy(dst=ps, src=p)
+                nisa.tensor_tensor(dst=fc1_acc, data1=fc1_acc, data2=ps, op=nl.add)
+
+            # Add FC1 bias
+            fc1_bf16 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf)
+            nisa.tensor_copy(dst=fc1_bf16, src=fc1_acc)
+            fb1 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf)
+            nisa.dma_copy(dst=fb1, src=fc1_b[0:P_MAX, h * P_MAX : (h + 1) * P_MAX])
+            fc1_biased = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf)
+            nisa.tensor_tensor(dst=fc1_biased, data1=fc1_bf16, data2=fb1, op=nl.add)
+
+            # GELU via the sigmoid ("quick GELU") approximation:
+            #   gelu(x) ~= x * sigmoid(1.702 * x)
+            # NOTE(review): this is the sigmoid approximation, *not* PyTorch's
+            # "tanh" approximation; both track exact GELU closely, but confirm
+            # the accuracy budget tolerates the difference for this model.
+            scaled_x = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf)
+            nisa.tensor_scalar(
+                dst=scaled_x,
+                data=fc1_biased,
+                op0=nl.multiply,
+                operand0=1.702,
+                engine=nisa.vector_engine,
+            )
+            sig_x = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf)
+            nisa.activation(
+                dst=sig_x, op=nl.sigmoid,
data=scaled_x, bias=None, scale=1.0 + ) + gelu_out = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gelu_out, data1=fc1_biased, data2=sig_x, op=nl.multiply + ) + + # FC2: accumulate this hidden chunk's contribution to each output chunk + gelu_t = _transpose_to_sbuf(gelu_out) + for o in nl.static_range(N_HIDDEN): + w_t = _prepare_weight( + fc2_w[o * P_MAX : (o + 1) * P_MAX, h * P_MAX : (h + 1) * P_MAX] + ) + p = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=p, stationary=gelu_t, moving=w_t) + ps = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ps, src=p) + nisa.tensor_tensor( + dst=fc2_acc[o], data1=fc2_acc[o], data2=ps, op=nl.add + ) + + # Add FC2 bias, residual, and store to output + for o in nl.static_range(N_HIDDEN): + fc2_chunk = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(dst=fc2_chunk, src=fc2_acc[o]) + fb2 = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy(dst=fb2, src=fc2_b[0:P_MAX, o * P_MAX : (o + 1) * P_MAX]) + fc2_biased = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_tensor(dst=fc2_biased, data1=fc2_chunk, data2=fb2, op=nl.add) + + # Residual: add post-attention value + pa_orig = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.dma_copy( + dst=pa_orig, + src=post_attn_buf[ + s_start : s_start + P_MAX, o * P_MAX : (o + 1) * P_MAX + ], + ) + x_final = nl.ndarray((P_MAX, P_MAX), dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_tensor(dst=x_final, data1=pa_orig, data2=fc2_biased, op=nl.add) + nisa.dma_copy( + dst=x_out[s_start : s_start + P_MAX, o * P_MAX : (o + 1) * P_MAX], + src=x_final, + ) + + return x_out diff --git a/contrib/models/whisper-large-v3-turbo/test/__init__.py b/contrib/models/whisper-large-v3-turbo/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/contrib/models/whisper-large-v3-turbo/test/integration/__init__.py b/contrib/models/whisper-large-v3-turbo/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/whisper-large-v3-turbo/test/integration/test_model.py b/contrib/models/whisper-large-v3-turbo/test/integration/test_model.py new file mode 100644 index 00000000..649e3c73 --- /dev/null +++ b/contrib/models/whisper-large-v3-turbo/test/integration/test_model.py @@ -0,0 +1,193 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Integration test for Whisper Large V3 Turbo on Neuron. + +This test validates the Whisper encoder-decoder model by: +1. Compiling the model (encoder + decoder) +2. Loading from compiled checkpoint +3. Transcribing a reference audio file +4. Validating transcription accuracy + +Prerequisites: + pip install openai-whisper pytest + +Usage: + # Run with pytest + pytest test/integration/test_model.py -v + + # Run directly + python test/integration/test_model.py +""" + +import os +import sys +import time + +import pytest +import torch + +# Add the src directory to the path +sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src"))) + +from neuronx_distributed_inference.models.config import NeuronConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +from modeling_whisper import WhisperInferenceConfig, NeuronApplicationWhisper + +# Configuration +MODEL_PATH = os.environ.get( + "WHISPER_MODEL_PATH", "/home/ubuntu/models/whisper-large-v3-turbo/" +) +COMPILED_MODEL_PATH = os.environ.get( + "WHISPER_COMPILED_PATH", "/home/ubuntu/compiled_models/whisper-large-v3-turbo/" +) +AUDIO_FILE = os.environ.get( + "WHISPER_AUDIO_FILE", + os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "..", + "examples", + "audio-sample.mp3", + ), +) +DTYPE = torch.bfloat16 +BATCH_SIZE = 1 +TP_DEGREE = 1 + + +def _get_config(): 
+ """Create NeuronConfig and WhisperInferenceConfig.""" + neuron_config = NeuronConfig( + batch_size=BATCH_SIZE, + torch_dtype=DTYPE, + tp_degree=TP_DEGREE, + ) + inference_config = WhisperInferenceConfig( + neuron_config, + load_config=load_pretrained_config(MODEL_PATH), + ) + return inference_config + + +@pytest.fixture(scope="module") +def compiled_model(): + """Compile and load the Whisper model (module-scoped for reuse across tests).""" + config = _get_config() + + # Compile if needed + if not os.path.exists(COMPILED_MODEL_PATH): + print(f"\nCompiling Whisper model to {COMPILED_MODEL_PATH}...") + model = NeuronApplicationWhisper(MODEL_PATH, config=config) + model.compile(COMPILED_MODEL_PATH) + + # Load compiled model + print(f"\nLoading compiled Whisper model from {COMPILED_MODEL_PATH}...") + model = NeuronApplicationWhisper(COMPILED_MODEL_PATH, config=config) + model.load(COMPILED_MODEL_PATH) + return model + + +def test_model_loads(compiled_model): + """Smoke test: model loads successfully.""" + assert compiled_model is not None + assert compiled_model.encoder is not None + assert compiled_model.decoder is not None + + +def test_model_transcribes(compiled_model): + """Test that the model produces a non-empty transcription.""" + assert os.path.exists(AUDIO_FILE), ( + f"Audio file not found: {AUDIO_FILE}. " + f"Set WHISPER_AUDIO_FILE environment variable to point to a valid audio file." 
+ ) + result = compiled_model.transcribe(AUDIO_FILE) + text = result["text"].strip() + print(f"\nTranscription: {text}") + assert len(text) > 0, "Transcription should not be empty" + + +def test_transcription_latency(compiled_model): + """Measure transcription latency with warmup.""" + assert os.path.exists(AUDIO_FILE), f"Audio file not found: {AUDIO_FILE}" + + # Warmup + compiled_model.transcribe(AUDIO_FILE) + + # Measure + n_runs = 3 + latencies = [] + for _ in range(n_runs): + start = time.perf_counter() + compiled_model.transcribe(AUDIO_FILE) + latencies.append(time.perf_counter() - start) + + avg_latency = sum(latencies) / len(latencies) + print( + f"\nAverage transcription latency ({n_runs} runs): {avg_latency * 1000:.1f}ms" + ) + # Basic sanity: should complete within 10 seconds for any reasonable audio + assert avg_latency < 10.0, f"Transcription too slow: {avg_latency:.1f}s" + + +def test_transcription_deterministic(compiled_model): + """Test that repeated transcriptions produce the same result.""" + assert os.path.exists(AUDIO_FILE), f"Audio file not found: {AUDIO_FILE}" + + result1 = compiled_model.transcribe(AUDIO_FILE) + result2 = compiled_model.transcribe(AUDIO_FILE) + assert result1["text"] == result2["text"], ( + f"Non-deterministic transcription:\n Run 1: {result1['text']}\n Run 2: {result2['text']}" + ) + + +if __name__ == "__main__": + print("=" * 60) + print("Whisper Large V3 Turbo - Integration Test") + print("=" * 60) + print(f"Model path: {MODEL_PATH}") + print(f"Compiled path: {COMPILED_MODEL_PATH}") + print(f"Audio file: {AUDIO_FILE}") + print(f"Dtype: {DTYPE}") + print(f"Batch size: {BATCH_SIZE}") + print(f"TP degree: {TP_DEGREE}") + print() + + config = _get_config() + + # Compile + if not os.path.exists(COMPILED_MODEL_PATH): + print("Compiling model...") + model = NeuronApplicationWhisper(MODEL_PATH, config=config) + model.compile(COMPILED_MODEL_PATH) + print("Compilation complete.\n") + + # Load + print("Loading compiled model...") + 
model = NeuronApplicationWhisper(COMPILED_MODEL_PATH, config=config) + model.load(COMPILED_MODEL_PATH) + print("Model loaded.\n") + + # Transcribe + if os.path.exists(AUDIO_FILE): + print(f"Transcribing: {AUDIO_FILE}") + start = time.perf_counter() + result = model.transcribe(AUDIO_FILE, verbose=True) + elapsed = time.perf_counter() - start + print(f"\nTranscription: {result['text']}") + print(f"Latency: {elapsed * 1000:.1f}ms") + + # Determinism check + result2 = model.transcribe(AUDIO_FILE) + if result["text"] == result2["text"]: + print("Determinism: PASS (identical output)") + else: + print("Determinism: FAIL (different outputs)") + else: + print(f"WARNING: Audio file not found: {AUDIO_FILE}") + print("Set WHISPER_AUDIO_FILE to run transcription tests.") + + print("\nAll tests passed.") diff --git a/contrib/models/whisper-large-v3-turbo/test/unit/__init__.py b/contrib/models/whisper-large-v3-turbo/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b From d71acc5e6ffca4829525606f92a13d91301017f2 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Thu, 26 Mar 2026 23:22:38 -0400 Subject: [PATCH 2/4] Fix audio file path resolution in integration test --- .../models/whisper-large-v3-turbo/test/integration/test_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/contrib/models/whisper-large-v3-turbo/test/integration/test_model.py b/contrib/models/whisper-large-v3-turbo/test/integration/test_model.py index 649e3c73..1d83abca 100644 --- a/contrib/models/whisper-large-v3-turbo/test/integration/test_model.py +++ b/contrib/models/whisper-large-v3-turbo/test/integration/test_model.py @@ -50,6 +50,7 @@ "..", "..", "..", + "..", "examples", "audio-sample.mp3", ), From dad49ff61083ff433a709e11a3f3d6b5f7b13114 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Fri, 27 Mar 2026 00:16:41 -0400 Subject: [PATCH 3/4] Add ffmpeg as prerequisite in README --- contrib/models/whisper-large-v3-turbo/README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) 
diff --git a/contrib/models/whisper-large-v3-turbo/README.md b/contrib/models/whisper-large-v3-turbo/README.md index ef22f7ea..77272cfe 100644 --- a/contrib/models/whisper-large-v3-turbo/README.md +++ b/contrib/models/whisper-large-v3-turbo/README.md @@ -125,6 +125,9 @@ print(result["text"]) ### Prerequisites ```bash +# ffmpeg is required by openai-whisper for audio decoding +sudo apt-get install -y ffmpeg + pip install openai-whisper pytest ``` @@ -154,6 +157,7 @@ The integration test: ## Dependencies +- `ffmpeg` (system package, required by openai-whisper for audio decoding) - `openai-whisper` (provides base `Whisper` class and decoding loop) - `transformers` (for `WhisperModel.from_pretrained` weight loading and `sinusoids`) - `neuronx-distributed-inference` (NxDI base classes, model wrapper, config) @@ -165,4 +169,4 @@ Jim Burtoft (jimburtoft) ## Last Updated -2026-03-26 +2026-03-27 From 019e42e2c1f5c86ea322f0c1c99de5b8633dc6c4 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Fri, 3 Apr 2026 18:56:46 -0400 Subject: [PATCH 4/4] Update README: replace BS=8 metrics with validated BS=4 results --- .../models/whisper-large-v3-turbo/README.md | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/contrib/models/whisper-large-v3-turbo/README.md b/contrib/models/whisper-large-v3-turbo/README.md index 77272cfe..854024d5 100644 --- a/contrib/models/whisper-large-v3-turbo/README.md +++ b/contrib/models/whisper-large-v3-turbo/README.md @@ -44,21 +44,16 @@ This is an encoder-decoder model with separate encoder and decoder compilation. 
| 30.0s | 462.9ms | 64.8x | | 90.0s | 1102.2ms | 81.7x | -### Batched (BS=8, trn2.3xlarge, LNC=2, bfloat16) +### Batched (BS=4, trn2.3xlarge, LNC=2, float16) -| Audio Duration | Batch Latency | Per-Sample | Throughput | -|---------------|--------------|------------|------------| -| 5.0s | 630.2ms | 78.8ms | 12.69 audio-sec/wall-sec | -| 30.0s | 675.5ms | 84.4ms | 11.84 audio-sec/wall-sec | -| 90.0s | 675.0ms | 84.4ms | 11.85 audio-sec/wall-sec | +Measured on 7.3s audio (single 30s segment): -### Data Parallel (DP=4 x BS=8, trn2.3xlarge, LNC=2, bfloat16) +| Config | Batch Latency | Audio/s | Audio-sec/s/core | +|--------|:---:|:---:|:---:| +| BS=4, 1 core | 487ms | 8.2 | 59.9 | +| BS=4, DP=4 | 512ms/core | 31.2 | 56.9 | -| Audio Duration | Aggregate Throughput | -|---------------|---------------------| -| 5.0s | **46.65 audio-sec/wall-sec** | -| 30.0s | **43.75 audio-sec/wall-sec** | -| 90.0s | **43.27 audio-sec/wall-sec** | +BS=4 achieves **1.6x** the hardware utilization (audio-sec/s/core) of BS=1. BS=8 causes a decoder latency regression and is not recommended. ## Usage @@ -117,7 +112,7 @@ print(result["text"]) **Notes**: - TP=1 is recommended. Whisper (809M params) fits on a single NeuronCore. - Higher TP degrees are supported for head-sharding but provide no benefit for this model size. -- For maximum throughput on trn2.3xlarge, use DP=4 x BS=8 with LNC=2 (4 independent model instances). +- For maximum throughput on trn2.3xlarge, use DP=4 x BS=4 with LNC=2 (4 independent model instances). - Each batch size requires separate compilation (BS is baked into the traced graph). ## Testing