diff --git a/contrib/models/MiMo-V2-Flash/README.md b/contrib/models/MiMo-V2-Flash/README.md new file mode 100644 index 00000000..f45b1a74 --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/README.md @@ -0,0 +1,211 @@ +# Contrib Model: MiMo-V2-Flash + +NeuronX Distributed Inference implementation of [XiaomiMiMo/MiMo-V2-Flash](https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash). + +## Model Information + +- **HuggingFace ID:** `XiaomiMiMo/MiMo-V2-Flash` +- **Model Type:** Decoder-only MoE transformer with hybrid attention +- **Architecture:** Custom MoE with full + sliding window attention +- **License:** Check HuggingFace model card + +## Architecture Details + +| Parameter | Value | +|-----------|-------| +| Hidden Size | 4096 | +| Layers | 48 | +| Attention Heads | 64 Q | +| KV Heads (full attn) | 4 | +| KV Heads (sliding window) | 8 | +| Q/K Head Dim | 192 | +| V Head Dim | 128 | +| Experts | 256 (top-8 routing) | +| Expert Intermediate | 1536 | +| Vocab Size | 151,936 | +| RoPE | Partial (34% of dims), theta=5M | +| Sliding Window | 32,768 | +| Max Position | 262,144 | +| Total Params | ~143B (FP8) / ~286B (BF16) | + +Key features: +- **Hybrid Attention**: 9 full attention layers (0, 5, 11, 17, 23, 29, 35, 41, 47) + 39 sliding window layers +- **Asymmetric Head Dims**: Q/K use 192, V uses 128 (fused_qkv not supported) +- **Attention Sink Bias**: Learnable per-head bias in sliding window layers +- **Sigmoid Router**: For MoE expert selection +- **Expert Parallelism**: Supports EP=64 for prefill with hybrid sharding (EP=1 for token generation) + +## Prerequisites + +- **Instance**: trn2.48xlarge (32 NeuronCores, logical_nc_config=2 -> 64 logical cores) +- **Weights**: BF16 format (convert from FP8 using `conversion_script/preprocess_mimo_v2_fp8.py`) + +## FP8 to BF16 Conversion + +The original model uses block-wise FP8 quantization incompatible with Neuron FP8. 
Convert to BF16:
+
+```bash
+python src/neuronx_distributed_inference/models/mimo_v2/conversion_script/preprocess_mimo_v2_fp8.py \
+    --input-dir /path/to/MiMo-V2-Flash \
+    --output-dir /path/to/MiMo-V2-Flash-BF16
+```
+
+## Usage
+
+```python
+import torch
+from transformers import AutoTokenizer
+from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig
+from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter
+
+from src.modeling_mimo_v2 import NeuronMiMoV2ForCausalLM, MiMoV2InferenceConfig
+
+model_path = "/path/to/MiMo-V2-Flash-BF16/"
+compiled_path = "/path/to/compiled/"
+
+neuron_config = MoENeuronConfig(
+    tp_degree=64,
+    moe_tp_degree=1,
+    moe_ep_degree=64,
+    batch_size=1,
+    seq_len=512,
+    max_context_length=128,
+    torch_dtype=torch.bfloat16,
+    logical_nc_config=2,
+    sequence_parallel_enabled=True,
+    fused_qkv=False,  # Required: asymmetric Q/K vs V dims
+    on_device_sampling_config=OnDeviceSamplingConfig(
+        do_sample=True, temperature=0.6, top_k=20, top_p=0.95
+    ),
+    router_config={"act_fn": "sigmoid"},
+)
+
+config = MiMoV2InferenceConfig(
+    neuron_config, load_config=load_pretrained_config(model_path)
+)
+
+model = NeuronMiMoV2ForCausalLM(model_path, config)
+model.compile(compiled_path)
+model.load(compiled_path)
+
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+adapter = HuggingFaceGenerationAdapter(model, tokenizer)
+output = adapter.generate("Hello, how are you?", max_new_tokens=128)
+```
+
+## vLLM Integration
+
+MiMo-V2-Flash can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron). A patch is required to add MiMo architecture support.
+
+### Setup
+
+```bash
+# 1. Install vllm-neuron
+pip install vllm-neuron
+
+# 2. Apply the MiMo/MiniMax patch
+cd /path/to/vllm-neuron
+git apply /path/to/neuronx-distributed-inference/perf_test/vllm-neuron-mimo-minimax.patch
+pip install -e .
+``` + +### Serving + +```bash +python3 -m vllm.entrypoints.openai.api_server \ + --model /path/to/MiMo-V2-Flash-BF16 \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": true, + "glu_mlp": true, + "normalize_top_k_affinities": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' +``` + +### Key vLLM Patch Changes + +The patch (`perf_test/vllm-neuron-mimo-minimax.patch`) modifies vllm-neuron to: +- Map MiMo architecture to Qwen2 model loader (MiMo is Qwen2-based) +- Pass `hf_config` from vLLM to NxDI (required for `trust_remote_code` models) +- Replace `AutoModelForCausalLM.from_pretrained` with `snapshot_download` for model loading + +See `perf_test/1_bench_mimo_v2_flash.sh` for full benchmark configurations with BS=1/32/128. 
+ +## Performance + +### Standalone NxDI (trn2.48xlarge, BF16, TP=64, EP=64) + +| Batch Size | Throughput (tok/s) | +|------------|-------------------| +| 1 | 29.92 | +| 8 | 215.94 | +| 32 | 649.14 | + +### vLLM Serving (trn2.48xlarge, BF16, BS=32, TP=64/EP=64, CB) + +Input/output: 900/90 tokens (random dataset) + +| Concurrency | Throughput (tok/s) | TPOT (ms) | TTFT (ms) | +|-------------|-------------------|-----------|-----------| +| 1 | 27.98 | 33.65 | 222 | +| 16 | 224.57 | 64.95 | 570 | +| 32 | 302.61 | 90.23 | 1351 | + +> **Note:** Large MoE models like MiMo-V2-Flash require extended engine startup time (~47 min for compile+load). Set `VLLM_ENGINE_READY_TIMEOUT_S=3600` before launching the vLLM server. + +## Compatibility Matrix + +| Instance/Version | 2.22+ (PyTorch 2.9) | 2.21 and earlier | +|------------------|---------------------|------------------| +| Trn2 (trn2.48xlarge) | Tested | Not tested | +| Trn1 | Not supported (requires 64 cores) | Not supported | +| Inf2 | Not supported | Not supported | + +## Testing + +```bash +pytest contrib/models/MiMo-V2-Flash/test/integration/test_model.py -v +``` + +## Key Implementation Notes + +1. **Hybrid Attention**: `hybrid_layer_pattern` list determines full vs sliding window per layer. +2. **CONVERT_TO_MHA**: When TP > num_kv_heads (4), K/V are replicated to match Q heads (64). +3. **Attention Sink Bias**: Adds learnable sink column to attention weights in sliding window layers. +4. **EP Hybrid Sharding**: EP is used during prefill only; token generation uses EP=1 unless batch_size >= 32. +5. **FP8 Conversion**: Original uses OCP block-wise FP8, requires conversion to BF16 or Neuron-compatible FP8 format. 
+ +## Example Checkpoints + +* [XiaomiMiMo/MiMo-V2-Flash](https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash) + +## Maintainer + +Henan Wan (whn09) + +**Last Updated:** 2026-04-13 diff --git a/contrib/models/MiMo-V2-Flash/src/__init__.py b/contrib/models/MiMo-V2-Flash/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py new file mode 100644 index 00000000..8b221249 --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/src/modeling_mimo_v2.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""MiMo-V2-Flash model for NXD inference - Contrib wrapper.""" + +from typing import List + +from neuronx_distributed_inference.models.config import InferenceConfig, MoENeuronConfig +from neuronx_distributed_inference.models.mimo_v2.modeling_mimo_v2 import ( + NeuronMiMoV2ForCausalLM as BaseNeuronMiMoV2ForCausalLM, + MiMoV2InferenceConfig as BaseMiMoV2InferenceConfig, + convert_mimo_v2_hf_to_neuron_state_dict, +) + + +class MiMoV2InferenceConfig(BaseMiMoV2InferenceConfig): + """Configuration class for MiMo-V2-Flash inference on NeuronX.""" + pass + + +class NeuronMiMoV2ForCausalLM(BaseNeuronMiMoV2ForCausalLM): + """MiMo-V2-Flash Causal Language Model for NeuronX inference. 
+ + Architecture: + - 48 decoder layers with Mixture of 256 Experts (top-8) + - Hybrid attention: full (4 KV heads) + sliding window (8 KV heads) + - Asymmetric head dims: Q/K=192, V=128 + - Partial RoPE (34%), attention sink bias + - Sigmoid router + """ + + @classmethod + def get_config_cls(cls): + return MiMoV2InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: dict, config) -> dict: + return convert_mimo_v2_hf_to_neuron_state_dict(state_dict, config) diff --git a/contrib/models/MiMo-V2-Flash/test/__init__.py b/contrib/models/MiMo-V2-Flash/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2-Flash/test/integration/__init__.py b/contrib/models/MiMo-V2-Flash/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiMo-V2-Flash/test/integration/test_model.py b/contrib/models/MiMo-V2-Flash/test/integration/test_model.py new file mode 100644 index 00000000..bcbc368e --- /dev/null +++ b/contrib/models/MiMo-V2-Flash/test/integration/test_model.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +"""Integration tests for MiMo-V2-Flash NeuronX implementation.""" + +import pytest +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + + +def test_config_import(): + """Test that config class can be imported.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig, NeuronMiMoV2ForCausalLM + assert MiMoV2InferenceConfig is not None + assert NeuronMiMoV2ForCausalLM is not None + print("PASS: Config and model classes imported successfully") + + +def test_required_attributes(): + """Test that required attributes are defined.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig + # Check get_required_attributes without instantiation (requires many params) + required = MiMoV2InferenceConfig.get_required_attributes(MiMoV2InferenceConfig) + assert "hidden_size" in required + assert "n_routed_experts" in 
required + assert "num_experts_per_tok" in required + assert "hybrid_layer_pattern" in required + assert "v_head_dim" in required + assert "swa_head_dim" in required + print(f"PASS: {len(required)} required attributes defined") + + +def test_neuron_config_cls(): + """Test that MoENeuronConfig is returned.""" + from modeling_mimo_v2 import MiMoV2InferenceConfig + from neuronx_distributed_inference.models.config import MoENeuronConfig + assert MiMoV2InferenceConfig.get_neuron_config_cls() == MoENeuronConfig + print("PASS: MoENeuronConfig returned") + + +def test_state_dict_converter(): + """Test that state dict converter function exists.""" + from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM + assert hasattr(NeuronMiMoV2ForCausalLM, "convert_hf_to_neuron_state_dict") + print("PASS: State dict converter exists") + + +if __name__ == "__main__": + test_config_import() + test_required_attributes() + test_neuron_config_cls() + test_state_dict_converter() + print("\nAll tests passed!") diff --git a/contrib/models/MiMo-V2-Flash/test/unit/__init__.py b/contrib/models/MiMo-V2-Flash/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiniMax-M2/README.md b/contrib/models/MiniMax-M2/README.md new file mode 100644 index 00000000..f6945482 --- /dev/null +++ b/contrib/models/MiniMax-M2/README.md @@ -0,0 +1,194 @@ +# Contrib Model: MiniMax M2 + +NeuronX Distributed Inference implementation of [MiniMax/MiniMax-M2](https://huggingface.co/MiniMax/MiniMax-M2). 
+
+## Model Information
+
+- **HuggingFace ID:** `MiniMax/MiniMax-M2`
+- **Model Type:** Decoder-only MoE transformer
+- **Architecture:** Custom MoE with sigmoid routing and e_score_correction_bias
+- **License:** Check HuggingFace model card
+
+## Architecture Details
+
+| Parameter | Value |
+|-----------|-------|
+| Hidden Size | 3072 |
+| Layers | 62 |
+| Attention Heads | 48 Q / 8 KV (GQA) |
+| Head Dim | 128 |
+| Experts | 256 (top-8 routing) |
+| Expert Intermediate | 1536 |
+| MLP Intermediate | 8192 |
+| Vocab Size | 200,064 |
+| RoPE | Partial (50% of head_dim), theta=5M |
+| Max Position | 196,608 |
+
+Key features:
+- **QK Norm**: Applied before reshape on full projection output
+- **Partial RoPE**: Only first 64 of 128 dims use rotary encoding
+- **Sigmoid Router**: With learnable e_score_correction_bias for expert selection
+- **fused_qkv**: Supported for efficient Q/K/V projection
+
+## Prerequisites
+
+- **Instance**: trn2.48xlarge (32 NeuronCores, logical_nc_config=2 -> 64 logical cores)
+- **Weights**: BF16 format (convert from FP8 original if needed)
+
+## Usage
+
+```python
+import torch
+from transformers import AutoTokenizer
+from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig
+from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config, HuggingFaceGenerationAdapter
+
+from src.modeling_minimax_m2 import NeuronMiniMaxM2ForCausalLM, MiniMaxM2InferenceConfig
+
+model_path = "/path/to/MiniMax-M2-BF16/"
+compiled_path = "/path/to/compiled/"
+
+neuron_config = MoENeuronConfig(
+    tp_degree=64,
+    moe_tp_degree=64,
+    moe_ep_degree=1,
+    batch_size=1,
+    seq_len=512,
+    max_context_length=256,
+    torch_dtype=torch.bfloat16,
+    logical_nc_config=2,
+    sequence_parallel_enabled=True,
+    fused_qkv=True,
+    on_device_sampling_config=OnDeviceSamplingConfig(
+        do_sample=True, temperature=0.6, top_k=20, top_p=0.95
+    ),
+    router_config={"act_fn": "sigmoid"},
+)
+
+config = MiniMaxM2InferenceConfig(
neuron_config, load_config=load_pretrained_config(model_path) +) + +model = NeuronMiniMaxM2ForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained(model_path) +adapter = HuggingFaceGenerationAdapter(model, tokenizer) +output = adapter.generate("Hello, how are you?", max_new_tokens=128) +``` + +## vLLM Integration + +MiniMax-M2 can be served via [vllm-neuron](https://github.com/aws-neuron/vllm-neuron). A patch is required to add MiniMax architecture support. + +### Setup + +```bash +# 1. Install vllm-neuron +pip install vllm-neuron + +# 2. Apply the MiMo/MiniMax patch +cd /path/to/vllm-neuron +git apply /path/to/neuronx-distributed-inference/perf_test/vllm-neuron-mimo-minimax.patch +pip install -e . +``` + +### Serving + +```bash +python3 -m vllm.entrypoints.openai.api_server \ + --model /path/to/MiniMax-M2-BF16 \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 256 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 64, + "logical_nc_config": 2, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": true, + "glu_mlp": true, + "moe_mask_padded_tokens": true, + "disable_numeric_cc_token": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}, + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 256, + "ctx_batch_size": 1, + "tkg_batch_size": 256, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "fused_qkv": false, + "enable_bucketing": true, + "normalize_top_k_affinities": true, + "use_index_calc_kernel": true, + "blockwise_matmul_config": { + "use_shard_on_intermediate_dynamic_while": true, + "skip_dma_token": true + }, + "scratchpad_page_size": 1024 + } + }' +``` + +### Key vLLM Patch Changes + +The patch (`perf_test/vllm-neuron-mimo-minimax.patch`) modifies vllm-neuron to: +- Pass `hf_config` 
from vLLM to NxDI (required for `trust_remote_code` models) +- Replace `AutoModelForCausalLM.from_pretrained` with `snapshot_download` for model loading + +See `perf_test/2_bench_minimax_m2.sh` for full benchmark configurations with BS=1/256. + +## Performance + +### vLLM Serving — Config 1 (trn2.48xlarge, BF16, BS=1, TP=64/EP=1, non-CB) + +Input/output: 900/90 tokens (random dataset) + +| Concurrency | Throughput (tok/s) | TPOT (ms) | TTFT (ms) | +|-------------|-------------------|-----------|-----------| +| 1 | 39.28 | 13.56 | 1088 | + +### vLLM Serving — Config 2 (trn2.48xlarge, BF16, BS=256, TP=64/EP=64, CB) + +Input/output: 900/90 tokens (random dataset) + +| Concurrency | Throughput (tok/s) | TPOT (ms) | TTFT (ms) | +|-------------|-------------------|-----------|-----------| +| 1 | 5.76 | 173.83 | 165 | +| 16 | 54.69 | 287.09 | 513 | +| 32 | 75.85 | 408.66 | 1066 | +| 128 | 106.72 | 1158.08 | 3950 | +| 256 | 128.94 | 1860.69 | 11263 | + +> **Note:** Large MoE models like MiniMax-M2 require extended engine startup time. Set `VLLM_ENGINE_READY_TIMEOUT_S=3600` before launching the vLLM server. 
+ +## Compatibility Matrix + +| Instance/Version | 2.22+ (PyTorch 2.9) | 2.21 and earlier | +|------------------|---------------------|------------------| +| Trn2 (trn2.48xlarge) | Tested | Not tested | +| Trn1 | Not supported (requires 64 cores) | Not supported | +| Inf2 | Not supported | Not supported | + +## Testing + +```bash +pytest contrib/models/MiniMax-M2/test/integration/test_model.py -v +``` + +## Example Checkpoints + +* [MiniMax/MiniMax-M2](https://huggingface.co/MiniMax/MiniMax-M2) +* [MiniMax/MiniMax-M2-unquantized](https://huggingface.co/MiniMax/MiniMax-M2-unquantized) (BF16) + +## Maintainer + +Henan Wan (whn09) + +**Last Updated:** 2026-04-13 diff --git a/contrib/models/MiniMax-M2/src/__init__.py b/contrib/models/MiniMax-M2/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiniMax-M2/src/modeling_minimax_m2.py b/contrib/models/MiniMax-M2/src/modeling_minimax_m2.py new file mode 100644 index 00000000..1fd5d31c --- /dev/null +++ b/contrib/models/MiniMax-M2/src/modeling_minimax_m2.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""MiniMax M2 model for NXD inference - Contrib wrapper.""" + +from neuronx_distributed_inference.models.minimax_m2.modeling_minimax_m2 import ( + NeuronMiniMaxM2ForCausalLM, + MiniMaxM2InferenceConfig, + convert_minimax_m2_hf_to_neuron_state_dict, +) + +__all__ = [ + "MiniMaxM2InferenceConfig", + "NeuronMiniMaxM2ForCausalLM", +] diff --git a/contrib/models/MiniMax-M2/test/__init__.py b/contrib/models/MiniMax-M2/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiniMax-M2/test/integration/__init__.py b/contrib/models/MiniMax-M2/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/MiniMax-M2/test/integration/test_model.py b/contrib/models/MiniMax-M2/test/integration/test_model.py new file mode 100644 index 00000000..a5cc87ea --- /dev/null +++ b/contrib/models/MiniMax-M2/test/integration/test_model.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +"""Integration tests for MiniMax M2 NeuronX implementation.""" + +import pytest +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + + +def test_config_import(): + """Test that config class can be imported.""" + from modeling_minimax_m2 import MiniMaxM2InferenceConfig, NeuronMiniMaxM2ForCausalLM + assert MiniMaxM2InferenceConfig is not None + assert NeuronMiniMaxM2ForCausalLM is not None + print("PASS: Config and model classes imported successfully") + + +def test_required_attributes(): + """Test that required attributes are defined.""" + from modeling_minimax_m2 import MiniMaxM2InferenceConfig + from neuronx_distributed_inference.models.config import MoENeuronConfig + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + from transformers import AutoConfig + import torch + + neuron_config = MoENeuronConfig( + tp_degree=64, + batch_size=1, + seq_len=512, + torch_dtype=torch.bfloat16, + on_cpu=True, + ) + # Use 
the bundled config.json to provide model-specific attributes + repo_root = Path(__file__).resolve().parent.parent.parent.parent.parent.parent + config_path = repo_root / "src" / "neuronx_distributed_inference" / "models" / "minimax_m2" + hf_config = AutoConfig.from_pretrained(str(config_path), trust_remote_code=True) + config = MiniMaxM2InferenceConfig(neuron_config, load_config=load_pretrained_config(hf_config=hf_config)) + required = config.get_required_attributes() + assert "hidden_size" in required + assert "num_local_experts" in required + assert "num_experts_per_tok" in required + print(f"PASS: {len(required)} required attributes defined") + + +def test_neuron_config_cls(): + """Test that MoENeuronConfig is returned.""" + from modeling_minimax_m2 import MiniMaxM2InferenceConfig + from neuronx_distributed_inference.models.config import MoENeuronConfig + assert MiniMaxM2InferenceConfig.get_neuron_config_cls() == MoENeuronConfig + print("PASS: MoENeuronConfig returned") + + +if __name__ == "__main__": + test_config_import() + test_required_attributes() + test_neuron_config_cls() + print("\nAll tests passed!") diff --git a/contrib/models/MiniMax-M2/test/unit/__init__.py b/contrib/models/MiniMax-M2/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/perf_test/0_setup.sh b/perf_test/0_setup.sh new file mode 100644 index 00000000..01a9ca34 --- /dev/null +++ b/perf_test/0_setup.sh @@ -0,0 +1,63 @@ +#!/bin/bash +set -e + +echo "==========================================" +echo "Setup: vllm-neuron + model weights" +echo "==========================================" + +# --- 1. Install vllm-neuron from fork with MiMo support --- +echo "" +echo "[1/3] Installing vllm-neuron (feature/mimo-support branch)..." + +# Use the NxDI venv as base +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Clone and install vllm-neuron +if [ ! 
-d /tmp/vllm-neuron ]; then + git clone --branch feature/mimo-support https://github.com/whn09/vllm-neuron.git /tmp/vllm-neuron +fi +cd /tmp/vllm-neuron +pip install --extra-index-url=https://pip.repos.neuron.amazonaws.com -e . +pip install s5cmd + +# Verify installation +python3 -c "import vllm_neuron; print('vllm-neuron installed:', vllm_neuron.__file__)" +vllm --version 2>/dev/null || echo "vllm CLI check done" + +# --- 2. Download MiMo-V2-Flash BF16 weights --- +echo "" +echo "[2/3] Downloading MiMo-V2-Flash BF16 weights..." + +MIMO_PATH="/opt/dlami/nvme/models/MiMo-V2-Flash-BF16" +if [ -d "$MIMO_PATH" ] && [ "$(ls $MIMO_PATH/*.safetensors 2>/dev/null | wc -l)" -gt 0 ]; then + echo " MiMo weights already exist at $MIMO_PATH, skipping download" +else + echo " Downloading from s3://datalab/xiaomi/models/MiMo-V2-Flash-BF16/ ..." + mkdir -p "$MIMO_PATH" + s5cmd cp "s3://datalab/xiaomi/models/MiMo-V2-Flash-BF16/**" "$MIMO_PATH/" + echo " Download complete: $(du -sh $MIMO_PATH | cut -f1)" +fi + +# --- 3. Verify MiniMax-M2 BF16 weights --- +echo "" +echo "[3/3] Verifying MiniMax-M2 BF16 weights..." + +MINIMAX_PATH="/opt/dlami/nvme/models/MiniMax-M2-BF16" +if [ -d "$MINIMAX_PATH" ] && [ "$(ls $MINIMAX_PATH/*.safetensors 2>/dev/null | wc -l)" -gt 0 ]; then + echo " MiniMax weights exist at $MINIMAX_PATH: $(du -sh $MINIMAX_PATH | cut -f1)" +else + echo " Downloading from s3://datalab/minimax/model_hf/MiniMax-M2-BF16/ ..." + mkdir -p "$MINIMAX_PATH" + s5cmd cp "s3://datalab/minimax/model_hf/MiniMax-M2-BF16/**" "$MINIMAX_PATH/" + echo " Download complete: $(du -sh $MINIMAX_PATH | cut -f1)" +fi + +# --- Summary --- +echo "" +echo "==========================================" +echo "Setup complete!" 
+echo "==========================================" +echo " vllm-neuron: /tmp/vllm-neuron (feature/mimo-support)" +echo " MiMo weights: $MIMO_PATH" +echo " MiniMax weights: $MINIMAX_PATH" +echo " Disk usage: $(df -h /opt/dlami/nvme | tail -1 | awk '{print $3, "used /", $2, "(" $5 ")"}')" diff --git a/perf_test/1_bench_mimo_v2_flash.sh b/perf_test/1_bench_mimo_v2_flash.sh new file mode 100644 index 00000000..0f4fdd4d --- /dev/null +++ b/perf_test/1_bench_mimo_v2_flash.sh @@ -0,0 +1,239 @@ +#!/bin/bash +set -e + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="/opt/dlami/nvme/models/MiMo-V2-Flash-BF16" +PORT=8000 +RESULTS_DIR="/tmp/bench_results/mimo_v2_flash" +mkdir -p "$RESULTS_DIR" + +# Common neuron config shared across all MiMo configs +COMMON_MIMO_CONFIG='"tp_degree": 64, + "logical_nc_config": 2, + "fused_qkv": false, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": true, + "qkv_kernel_enabled": false, + "qkv_nki_kernel_enabled": false, + "qkv_cte_nki_kernel_fuse_rope": false, + "attn_kernel_enabled": false, + "strided_context_parallel_kernel_enabled": false, + "glu_mlp": true, + "normalize_top_k_affinities": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}' + +# Helper: wait for vLLM server to be ready +wait_for_server() { + echo " Waiting for vLLM server to be ready..." + for i in $(seq 1 120); do + if curl -s http://localhost:$PORT/health > /dev/null 2>&1; then + echo " Server ready! 
(${i}s)" + return 0 + fi + sleep 5 + done + echo " ERROR: Server did not start within 600s" + return 1 +} + +# Helper: run benchmark +run_bench() { + local config_name=$1 + local concurrency=$2 + local num_prompts=$3 + + echo " Benchmark: concurrency=$concurrency, prompts=$num_prompts" + vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$num_prompts" \ + --random-input-len 900 \ + --random-output-len 90 \ + --random-range-ratio 0.03 \ + --max-concurrency "$concurrency" \ + 2>&1 | tee "$RESULTS_DIR/${config_name}_c${concurrency}.txt" + echo "" +} + +# Helper: stop server +stop_server() { + echo " Stopping vLLM server..." + pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 5 +} + +# Helper: quick sanity check +sanity_check() { + echo " Running sanity check..." + curl -s http://localhost:$PORT/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role": "user", "content": "What is 1+1? 
Answer briefly."}], + "model": "'"$MODEL_PATH"'", + "max_tokens": 64, + "temperature": 0.0, + "stream": false + }' | python3 -c "import sys,json; r=json.load(sys.stdin); print(' Sanity:', r['choices'][0]['message']['content'][:100])" 2>/dev/null || echo " Sanity check: could not parse response" +} + +echo "==========================================" +echo "MiMo-V2-Flash Performance Benchmark" +echo "==========================================" +echo "Model: $MODEL_PATH" +echo "Results: $RESULTS_DIR" +echo "" + +############################################################################### +# Config 1: BS=1, TP=64/EP=1, non-CB (baseline latency) +############################################################################### +CONFIG_NAME="bs1_tp64_ep1" +echo "--- Config 1: BS=1, TP=64/EP=1, non-CB (baseline) ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 1 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MIMO_CONFIG"', + "moe_tp_degree": 64, + "moe_ep_degree": 1, + "batch_size": 1, + "ctx_batch_size": 1, + "tkg_batch_size": 1, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": false, + "enable_bucketing": false, + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +stop_server + +############################################################################### +# Config 2: BS=32, TP=1/EP=64, CB + optimizations +############################################################################### +CONFIG_NAME="bs32_tp1_ep64_opt" +echo "--- Config 2: BS=32, TP=1/EP=64, CB + optimizations ---" + +python3 -m vllm.entrypoints.openai.api_server \ + 
--model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 32 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MIMO_CONFIG"', + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 32, + "ctx_batch_size": 1, + "tkg_batch_size": 32, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + }, + "use_index_calc_kernel": true, + "moe_mask_padded_tokens": true, + "blockwise_matmul_config": { + "use_shard_on_intermediate_dynamic_while": true, + "skip_dma_token": true + }, + "disable_numeric_cc_token": true, + "scratchpad_page_size": 1024 + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 16 128 +run_bench "$CONFIG_NAME" 32 128 +stop_server + +############################################################################### +# Config 3: BS=128, TP=1/EP=64, CB + optimizations +############################################################################### +CONFIG_NAME="bs128_tp1_ep64_opt" +echo "--- Config 3: BS=128, TP=1/EP=64, CB + optimizations ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 128 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MIMO_CONFIG"', + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 128, + "ctx_batch_size": 1, + "tkg_batch_size": 128, + "max_context_length": 1024, + "seq_len": 1024, + 
"is_continuous_batching": true, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + }, + "use_index_calc_kernel": true, + "moe_mask_padded_tokens": true, + "blockwise_matmul_config": { + "use_shard_on_intermediate_dynamic_while": true, + "skip_dma_token": true + }, + "disable_numeric_cc_token": true, + "scratchpad_page_size": 1024 + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 16 128 +run_bench "$CONFIG_NAME" 32 128 +run_bench "$CONFIG_NAME" 128 512 +stop_server + +echo "==========================================" +echo "MiMo-V2-Flash benchmarks complete!" +echo "Results saved to: $RESULTS_DIR" +echo "==========================================" +ls -la "$RESULTS_DIR" diff --git a/perf_test/2_bench_minimax_m2.sh b/perf_test/2_bench_minimax_m2.sh new file mode 100644 index 00000000..e979951c --- /dev/null +++ b/perf_test/2_bench_minimax_m2.sh @@ -0,0 +1,195 @@ +#!/bin/bash +set -e + +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +MODEL_PATH="/opt/dlami/nvme/models/MiniMax-M2-BF16" +PORT=8000 +RESULTS_DIR="/tmp/bench_results/minimax_m2" +mkdir -p "$RESULTS_DIR" + +# Common neuron config shared across all MiniMax configs +COMMON_MINIMAX_CONFIG='"tp_degree": 64, + "logical_nc_config": 2, + "flash_decoding_enabled": false, + "sequence_parallel_enabled": true, + "qkv_kernel_enabled": false, + "qkv_nki_kernel_enabled": false, + "attn_kernel_enabled": false, + "glu_mlp": true, + "moe_mask_padded_tokens": true, + "disable_numeric_cc_token": true, + "router_config": {"act_fn": "sigmoid", "dtype": "float32"}' + +# Helper: wait for vLLM server to be ready +wait_for_server() { + echo " Waiting for vLLM server to be ready..." 
+ for i in $(seq 1 120); do + if curl -s http://localhost:$PORT/health > /dev/null 2>&1; then + echo " Server ready! (${i}s)" + return 0 + fi + sleep 5 + done + echo " ERROR: Server did not start within 600s" + return 1 +} + +# Helper: run benchmark +run_bench() { + local config_name=$1 + local concurrency=$2 + local num_prompts=$3 + + echo " Benchmark: concurrency=$concurrency, prompts=$num_prompts" + vllm bench serve \ + --backend vllm \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --endpoint /v1/completions \ + --dataset-name random \ + --num-prompts "$num_prompts" \ + --random-input-len 900 \ + --random-output-len 90 \ + --random-range-ratio 0.03 \ + --max-concurrency "$concurrency" \ + 2>&1 | tee "$RESULTS_DIR/${config_name}_c${concurrency}.txt" + echo "" +} + +# Helper: stop server +stop_server() { + echo " Stopping vLLM server..." + pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true + sleep 5 +} + +# Helper: quick sanity check +sanity_check() { + echo " Running sanity check..." + curl -s http://localhost:$PORT/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "messages": [{"role": "user", "content": "What is 1+1? 
Answer briefly."}], + "model": "'"$MODEL_PATH"'", + "max_tokens": 64, + "temperature": 0.0, + "stream": false + }' | python3 -c "import sys,json; r=json.load(sys.stdin); print(' Sanity:', r['choices'][0]['message']['content'][:100])" 2>/dev/null || echo " Sanity check: could not parse response" +} + +echo "==========================================" +echo "MiniMax-M2 Performance Benchmark" +echo "==========================================" +echo "Model: $MODEL_PATH" +echo "Results: $RESULTS_DIR" +echo "" + +############################################################################### +# Config 1: BS=1, TP=64/EP=1, non-CB (baseline latency) +# NOTE: fused_qkv=true, use_shard_on_intermediate=false (avoids 10.7x padding) +############################################################################### +CONFIG_NAME="bs1_tp64_ep1" +echo "--- Config 1: BS=1, TP=64/EP=1, non-CB (baseline) ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 1 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MINIMAX_CONFIG"', + "moe_tp_degree": 64, + "moe_ep_degree": 1, + "batch_size": 1, + "ctx_batch_size": 1, + "tkg_batch_size": 1, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": false, + "fused_qkv": true, + "enable_bucketing": false, + "async_mode": false, + "use_index_calc_kernel": false, + "normalize_top_k_affinities": true, + "on_device_sampling_config": { + "do_sample": true, "temperature": 0.6, "top_k": 20, "top_p": 0.95 + }, + "blockwise_matmul_config": { + "use_shard_on_intermediate_dynamic_while": false, + "skip_dma_token": true + } + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +stop_server + 
+############################################################################### +# Config 2: BS=256, TP=1/EP=64, CB + optimizations +# NOTE: With EP=64, I_TP=1536/1=1536, 1536%256=0, so shard_on_intermediate is safe +############################################################################### +CONFIG_NAME="bs256_tp1_ep64_opt" +echo "--- Config 2: BS=256, TP=1/EP=64, CB + optimizations ---" + +python3 -m vllm.entrypoints.openai.api_server \ + --model "$MODEL_PATH" \ + --tokenizer "$MODEL_PATH" \ + --tensor-parallel-size 64 \ + --max-model-len 1024 \ + --max-num-seqs 256 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --port $PORT \ + --trust_remote_code \ + --additional-config '{ + "override_neuron_config": { + '"$COMMON_MINIMAX_CONFIG"', + "moe_tp_degree": 1, + "moe_ep_degree": 64, + "batch_size": 256, + "ctx_batch_size": 1, + "tkg_batch_size": 256, + "max_context_length": 1024, + "seq_len": 1024, + "is_continuous_batching": true, + "fused_qkv": false, + "enable_bucketing": true, + "context_encoding_buckets": [1024], + "token_generation_buckets": [1024], + "async_mode": false, + "normalize_top_k_affinities": true, + "strided_context_parallel_kernel_enabled": false, + "qkv_cte_nki_kernel_fuse_rope": false, + "on_device_sampling_config": null, + "use_index_calc_kernel": true, + "blockwise_matmul_config": { + "use_shard_on_intermediate_dynamic_while": true, + "skip_dma_token": true + }, + "scratchpad_page_size": 1024 + } + }' & + +wait_for_server +sanity_check +run_bench "$CONFIG_NAME" 1 16 +run_bench "$CONFIG_NAME" 16 128 +run_bench "$CONFIG_NAME" 32 128 +run_bench "$CONFIG_NAME" 128 512 +run_bench "$CONFIG_NAME" 256 512 +stop_server + +echo "==========================================" +echo "MiniMax-M2 benchmarks complete!" 
+echo "Results saved to: $RESULTS_DIR" +echo "==========================================" +ls -la "$RESULTS_DIR" diff --git a/perf_test/README.md b/perf_test/README.md new file mode 100644 index 00000000..e141f652 --- /dev/null +++ b/perf_test/README.md @@ -0,0 +1,57 @@ +# Performance Test Plan for PR #119 (MiMo-V2-Flash & MiniMax-M2) + +## Overview + +Use vllm-neuron to benchmark both models on trn2.48xlarge with various batch sizes and parallelism configs. + +## Prerequisites + +- **Instance**: trn2.48xlarge (32 NeuronCores, 2TB RAM, 1.7TB NVMe) +- **vllm-neuron**: Fork with MiMo support (https://github.com/whn09/vllm-neuron/tree/feature/mimo-support) +- **NxDI**: PR #119 branch installed at `/tmp/nxdi-fork-main/` +- **Model weights** (BF16, from S3): + - MiMo-V2-Flash: `s3://datalab/xiaomi/models/MiMo-V2-Flash-BF16/` + - MiniMax-M2: `s3://datalab/minimax/model_hf/MiniMax-M2-BF16/` (already downloaded) + +## Test Configurations + +### MiMo-V2-Flash + +| Config | BS | TP | EP | CB | Optimizations | Benchmark Concurrency | +|--------|----|----|----|----|---------------|----------------------| +| 1 | 1 | 64 | 1 | No | baseline | 1 | +| 2 | 32 | 1 | 64 | Yes | index_calc, blockwise, scratchpad | 1, 16, 32 | +| 3 | 128 | 1 | 64 | Yes | index_calc, blockwise, scratchpad | 1, 16, 32, 128 | + +### MiniMax-M2 + +| Config | BS | TP | EP | CB | Optimizations | Benchmark Concurrency | +|--------|----|----|----|----|---------------|----------------------| +| 1 | 1 | 64 | 1 | No | baseline, fused_qkv=true | 1 | +| 2 | 256 | 1 | 64 | Yes | index_calc, blockwise, scratchpad | 1, 16, 32, 128, 256 | + +### Benchmark Parameters +- Dataset: random +- Input length: 900 tokens +- Output length: 90 tokens +- Range ratio: 0.03 +- Prompts per run: 16 (concurrency=1), 128 (concurrency 16/32), 512 (concurrency 128/256) + +## Scripts + +1. `0_setup.sh` - Install vllm-neuron, download model weights +2. `1_bench_mimo_v2_flash.sh` - MiMo-V2-Flash benchmark (all configs) +3. 
`2_bench_minimax_m2.sh` - MiniMax-M2 benchmark (all configs) + +## Execution + +```bash +# Step 1: Setup (one-time) +bash 0_setup.sh + +# Step 2: Run MiMo benchmarks +bash 1_bench_mimo_v2_flash.sh 2>&1 | tee /tmp/mimo_bench_results.log + +# Step 3: Run MiniMax benchmarks +bash 2_bench_minimax_m2.sh 2>&1 | tee /tmp/minimax_bench_results.log +``` diff --git a/perf_test/vllm-neuron-mimo-minimax.patch b/perf_test/vllm-neuron-mimo-minimax.patch new file mode 100644 index 00000000..cb8c0421 --- /dev/null +++ b/perf_test/vllm-neuron-mimo-minimax.patch @@ -0,0 +1,129 @@ +diff --git a/vllm_neuron/worker/neuronx_distributed_model_loader.py b/vllm_neuron/worker/neuronx_distributed_model_loader.py +index d2099eb..e246249 100644 +--- a/vllm_neuron/worker/neuronx_distributed_model_loader.py ++++ b/vllm_neuron/worker/neuronx_distributed_model_loader.py +@@ -41,7 +41,7 @@ from neuronx_distributed_inference.models.config import ( # yapf: disable + from neuronx_distributed_inference.modules.lora_serving import LoraServingConfig + from neuronx_distributed_inference.utils.constants import MODEL_TYPES + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +-from transformers import AutoModelForCausalLM, PretrainedConfig ++from transformers import PretrainedConfig + from vllm.config import ( + CacheConfig, + ModelConfig, +@@ -186,8 +186,14 @@ class NeuronModelBase(nn.Module): + + neuron_config = neuronx_model_cls.get_neuron_config_cls()(**neuron_config_dict) + ++ # Use pre-loaded hf_config if available (loaded by vLLM with trust_remote_code=True) ++ hf_config = kwargs.get("hf_config") ++ if hf_config is not None: ++ load_config_fn = load_pretrained_config(hf_config=hf_config) ++ else: ++ load_config_fn = load_pretrained_config(model_name_or_path) + config = kwargs.get("config") or neuronx_model_cls.get_config_cls()( +- neuron_config, load_config=load_pretrained_config(model_name_or_path) ++ neuron_config, load_config=load_config_fn + ) + + # If fused 
speculation is enabled, attach the draft model config.
+@@ -254,11 +260,10 @@ class NeuronModelBase(nn.Module):
+                 "Using pre-compiled artifacts, override_neuron_config will be ignored"
+             )
+ 
+-    def _save_pretrained_model(self, model_name: str):
+-        hf_model = AutoModelForCausalLM.from_pretrained(model_name)
+-        saved_path = os.path.join("local-models", model_name)
+-        hf_model.save_pretrained(saved_path)
+-        return saved_path
++    def _get_model_path(self, model_name: str):
++        """Get local path for model, using HuggingFace cache if available."""
++        from huggingface_hub import snapshot_download
++        return snapshot_download(repo_id=model_name)
+ 
+     def _compile_and_load_model(
+         self, model_path: str, neuronx_model_cls, config, compiled_path: str
+@@ -565,7 +570,7 @@ class NeuronCausalLM(NeuronModelBase):
+ 
+         if not success:
+             if not os.path.exists(model_name_or_path):
+-                model_name_or_path = self._save_pretrained_model(model_name_or_path)
++                model_name_or_path = self._get_model_path(model_name_or_path)
+             self._compile_and_load_model(
+                 model_name_or_path, neuronx_model_cls, config, compiled_model_path
+             )
+@@ -611,10 +616,15 @@ class NeuronMultiModalCausalLM(NeuronCausalLM):
+             **text_neuron_config
+         )
+ 
++        hf_config = kwargs.get("hf_config")
++        if hf_config is not None:
++            load_config_fn = load_pretrained_config(hf_config=hf_config)
++        else:
++            load_config_fn = load_pretrained_config(model_name_or_path)
+         config = neuronx_model_cls.get_config_cls()(
+             text_neuron_config=text_neuron_config,
+             vision_neuron_config=vision_neuron_config,
+-            load_config=load_pretrained_config(model_name_or_path),
++            load_config=load_config_fn,
+         )
+ 
+         success, compiled_model_path, _ = self._load_weights_common(
+@@ -623,7 +633,7 @@ class NeuronMultiModalCausalLM(NeuronCausalLM):
+ 
+         if not success:
+             if not os.path.exists(model_name_or_path):
+-                model_name_or_path = self._save_pretrained_model(model_name_or_path)
++                model_name_or_path = 
self._get_model_path(model_name_or_path) + + self._compile_and_load_model( + model_name_or_path, neuronx_model_cls, config, compiled_model_path +@@ -758,14 +768,6 @@ class NeuronPixtralForCausalLM(NeuronMultiModalCausalLM): + + + class NeuronQwen2VLForCausalLM(NeuronMultiModalCausalLM): +- # overwrite _save_pretrained_model as Qwen2VL is not in AutoModelForCausalLM +- def _save_pretrained_model(self, model_name: str): +- from transformers import Qwen2VLForConditionalGeneration +- +- hf_model = Qwen2VLForConditionalGeneration.from_pretrained(model_name) +- saved_path = os.path.join("local-models", model_name) +- hf_model.save_pretrained(saved_path) +- return saved_path + + def execute_model(self, model_input): + """Helper to run model with defaults for missing multimodal inputs.""" +@@ -819,13 +821,7 @@ class NeuronQwen2VLForCausalLM(NeuronMultiModalCausalLM): + + + class NeuronQwen3VLForCausalLM(NeuronQwen2VLForCausalLM): +- def _save_pretrained_model(self, model_name: str): +- from transformers import Qwen3VLForConditionalGeneration +- +- hf_model = Qwen3VLForConditionalGeneration.from_pretrained(model_name) +- saved_path = os.path.join("local-models", model_name) +- hf_model.save_pretrained(saved_path) +- return saved_path ++ pass + + + class NeuronLlama4ForCausalLM(NeuronMultiModalCausalLM): +@@ -964,6 +960,10 @@ def _get_neuron_model_cls(architecture: str): + if model == "qwen3moe": + model = "qwen3_moe" + ++ # MiMo is based on Qwen2 architecture ++ if model == "mimo": ++ model = "qwen2" ++ + if model == "qwen2vl": + model = "qwen2_vl" + +@@ -1050,6 +1050,7 @@ def get_neuron_model( + neuron_config=neuron_config, + override_neuron_config=override_neuron_config, + speculative_config=speculative_config, ++ hf_config=model_config.hf_config, + ) + model.neuron_config = model.model.config.neuron_config + model.architecture = architecture diff --git a/src/neuronx_distributed_inference/models/mimo_v2/__init__.py 
b/src/neuronx_distributed_inference/models/mimo_v2/__init__.py new file mode 100644 index 00000000..935f8f82 --- /dev/null +++ b/src/neuronx_distributed_inference/models/mimo_v2/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from neuronx_distributed_inference.models.mimo_v2.modeling_mimo_v2 import ( + MiMoV2InferenceConfig, + NeuronMiMoV2ForCausalLM, +) + +__all__ = [ + "MiMoV2InferenceConfig", + "NeuronMiMoV2ForCausalLM", +] diff --git a/src/neuronx_distributed_inference/models/mimo_v2/conversion_script/preprocess_mimo_v2_fp8.py b/src/neuronx_distributed_inference/models/mimo_v2/conversion_script/preprocess_mimo_v2_fp8.py new file mode 100644 index 00000000..e96258f0 --- /dev/null +++ b/src/neuronx_distributed_inference/models/mimo_v2/conversion_script/preprocess_mimo_v2_fp8.py @@ -0,0 +1,630 @@ +""" +Preprocess MiMo-V2-Flash FP8 checkpoint for Neuron inference. + +The HuggingFace FP8 checkpoint cannot be directly used for inference on Neuron. +This script preprocesses the checkpoint to make it compatible. + +Steps: +1. Rescale FP8 weights from OCP format (range ±448) to Neuron format (range ±240) +2. Convert weight_scale_inv to .scale format (reciprocal + rescaling) +3. Fuse gate/up projections for MoE experts +4. Handle K/V weight and scale replication for CONVERT_TO_MHA mode +5. 
Save to preprocessed checkpoint directory + +Usage: + python preprocess_mimo_v2_fp8.py \ + --hf_model_path /path/to/MiMo-V2-Flash \ + --save_path /path/to/preprocessed_mimo_v2_fp8 \ + --tp_degree 32 \ + --convert_to_mha +""" + +import argparse +import gc +import json +import os +from typing import Dict, Any, List, Optional + +import torch + +from neuronx_distributed_inference.modules.checkpoint import ( + load_state_dict, + save_state_dict_safetensors, +) + + +# FP8 range difference between OCP (HuggingFace) and Neuron (IEEE-754) +# OCP FP8 E4M3/e4m3fn: range ±448 +# Neuron FP8 E4M3 (IEEE-754): range ±240 +FP8_SCALING_FACTOR = 448.0 / 240.0 + +# Neuron FP8 E4M3 max value +NEURON_FP8_MAX = 240.0 + + +def convert_bf16_to_fp8_per_row(weight: torch.Tensor): + """ + Convert BF16 weight to FP8 with per-row (per-channel) scales for Neuron. + + This is used for weights like o_proj that are BF16 in the original checkpoint. + The Neuron framework expects per-row scaling for these layers. + + Args: + weight: BF16 weight tensor [out_features, in_features] + + Returns: + Tuple of (fp8_weight, scale) + - fp8_weight: Weight quantized to FP8 (float8_e4m3fn) + - scale: Per-row scale tensor [out_features, 1] + """ + out_features, in_features = weight.shape + + # Compute per-row max absolute values + weight_float = weight.float() + row_max_abs = weight_float.abs().max(dim=1, keepdim=True)[0] + + # Compute scales (avoid division by zero) + scales = row_max_abs / NEURON_FP8_MAX + scales = torch.clamp(scales, min=1e-10) + + # Quantize + quantized = (weight_float / scales).to(torch.float8_e4m3fn) + + return quantized, scales.to(torch.float32) + + +def convert_bf16_to_fp8_blockwise( + weight: torch.Tensor, + block_size: List[int] = [128, 128], +): + """ + Convert BF16 weight to FP8 with block-wise scales for Neuron. + + Some weights in MiMo-V2-Flash (like o_proj) are in BF16, not FP8. + This function quantizes them to FP8 with appropriate block-wise scales. 
+ + Args: + weight: BF16 weight tensor [out_features, in_features] + block_size: Block size for quantization [128, 128] + + Returns: + Tuple of (fp8_weight, scale) + - fp8_weight: Weight quantized to FP8 (float8_e4m3fn) + - scale: Block-wise scale tensor [scale_h, scale_w] + """ + h, w = weight.shape + block_h, block_w = block_size + + # Calculate scale grid dimensions + scale_h = (h + block_h - 1) // block_h + scale_w = (w + block_w - 1) // block_w + + # Initialize output tensors + fp8_weight = torch.zeros_like(weight, dtype=torch.float8_e4m3fn) + scale = torch.zeros(scale_h, scale_w, dtype=torch.float32) + + # Process each block + for i in range(scale_h): + for j in range(scale_w): + # Block boundaries + h_start = i * block_h + h_end = min((i + 1) * block_h, h) + w_start = j * block_w + w_end = min((j + 1) * block_w, w) + + # Extract block + block = weight[h_start:h_end, w_start:w_end].float() + + # Compute scale: max_abs / FP8_MAX + max_abs = block.abs().max().item() + if max_abs == 0: + block_scale = 1.0 + else: + block_scale = max_abs / NEURON_FP8_MAX + + # Quantize block + quantized_block = (block / block_scale).to(torch.float8_e4m3fn) + + # Store results + fp8_weight[h_start:h_end, w_start:w_end] = quantized_block + scale[i, j] = block_scale + + return fp8_weight, scale + + +def rescale_fp8_to_per_row(weight: torch.Tensor, scale: torch.Tensor): + """ + Rescale FP8 weight from OCP format to Neuron format with per-row scaling. + + The original HuggingFace checkpoint uses block-wise FP8 quantization. + The Neuron framework expects per-row (per-channel) scaling. + This function converts block-wise to per-row scaling. 
+ + Args: + weight: FP8 weight tensor (float8_e4m3fn) [out_features, in_features] + scale: Block-wise scale tensor (weight_scale_inv) [scale_h, scale_w] + + Returns: + Tuple of (rescaled_weight, neuron_scale) + - rescaled_weight: FP8 weight compatible with Neuron + - neuron_scale: Per-row scale [out_features, 1] + """ + out_features, in_features = weight.shape + scale_h, scale_w = scale.shape + + # Block size inferred from scale dimensions + block_h = (out_features + scale_h - 1) // scale_h + block_w = (in_features + scale_w - 1) // scale_w + + # First dequantize using block-wise scales + # HF convention: original = fp8_weight * weight_scale_inv + weight_float = weight.float() + dequantized = torch.zeros(out_features, in_features, dtype=torch.float32) + + for i in range(scale_h): + for j in range(scale_w): + h_start = i * block_h + h_end = min((i + 1) * block_h, out_features) + w_start = j * block_w + w_end = min((j + 1) * block_w, in_features) + + block_scale = scale[i, j].item() + dequantized[h_start:h_end, w_start:w_end] = ( + weight_float[h_start:h_end, w_start:w_end] * block_scale + ) + + # Now requantize with per-row scaling for Neuron + # Compute per-row max absolute values + row_max_abs = dequantized.abs().max(dim=1, keepdim=True)[0] + + # Compute scales (avoid division by zero) + # Need to fit in Neuron FP8 range (±240) + scales = row_max_abs / NEURON_FP8_MAX + scales = torch.clamp(scales, min=1e-10) + + # Quantize to FP8 + quantized = (dequantized / scales).to(torch.float8_e4m3fn) + + return quantized, scales.to(torch.float32) + + +def rescale_fp8_weight_blockwise(weight: torch.Tensor, scale: torch.Tensor): + """ + Rescale FP8 weight from OCP format to Neuron format, keeping block-wise scaling. + + This is kept for MoE experts which may use block-wise scaling. 
+ + Args: + weight: FP8 weight tensor (float8_e4m3fn) + scale: Scale tensor (float32 or bfloat16), this is weight_scale_inv (1/scale) + + Returns: + Tuple of (rescaled_weight, neuron_scale) + - rescaled_weight: FP8 weight compatible with Neuron + - neuron_scale: Scale in Neuron format (direct scale, not reciprocal) + """ + # Convert weight to BF16 for rescaling + weight_bf16 = weight.bfloat16() + + # Divide by scaling factor to fit in Neuron's smaller range + rescaled_weight_bf16 = weight_bf16 / FP8_SCALING_FACTOR + + # Convert back to FP8 + rescaled_weight = rescaled_weight_bf16.to(torch.float8_e4m3fn) + + # After our rescaling: + # rescaled_weight = fp8_weight / FP8_SCALING_FACTOR + # We need: original = rescaled_weight * new_scale + # So: original = (fp8_weight / FP8_SCALING_FACTOR) * new_scale = fp8_weight * weight_scale_inv + # Therefore: new_scale = weight_scale_inv * FP8_SCALING_FACTOR + + neuron_scale = scale.float() * FP8_SCALING_FACTOR + + return rescaled_weight, neuron_scale.to(torch.float32) + + +def replicate_for_convert_to_mha( + weight: torch.Tensor, + scale: Optional[torch.Tensor], + num_kv_heads: int, + num_attention_heads: int, + head_dim: int, +): + """ + Replicate K/V weights and per-row scales for CONVERT_TO_MHA mode. + + When TP > num_kv_heads, we need to replicate K/V heads to match Q heads. + This uses repeat_interleave to create the correct GQA pattern. 
+ + Args: + weight: FP8 K or V weight [num_kv_heads * head_dim, hidden_size] + scale: Per-row scale tensor [num_kv_heads * head_dim, 1] + num_kv_heads: Original number of KV heads + num_attention_heads: Target number of attention heads + head_dim: Dimension per head + + Returns: + Tuple of (replicated_weight, replicated_scale) + """ + if num_kv_heads >= num_attention_heads: + return weight, scale + + repeat_factor = num_attention_heads // num_kv_heads + + # Reshape weight to [num_kv_heads, head_dim, hidden_size] + weight_reshaped = weight.view(num_kv_heads, head_dim, -1) + + # Replicate using repeat_interleave (correct GQA pattern) + # This creates [h0, h0, ..., h1, h1, ...] pattern + weight_replicated = weight_reshaped.repeat_interleave(repeat_factor, dim=0) + + # Reshape back to [num_attention_heads * head_dim, hidden_size] + weight_replicated = weight_replicated.view(-1, weight_replicated.shape[-1]) + + if scale is None: + return weight_replicated, None + + # Replicate per-row scales + # Scale shape: [num_kv_heads * head_dim, 1] + # Reshape to [num_kv_heads, head_dim, 1] + scale_reshaped = scale.view(num_kv_heads, head_dim, -1) + + # Replicate scales + scale_replicated = scale_reshaped.repeat_interleave(repeat_factor, dim=0) + + # Reshape back to [num_attention_heads * head_dim, 1] + scale_replicated = scale_replicated.view(-1, scale_replicated.shape[-1]) + + return weight_replicated, scale_replicated + + +def process_mimo_v2_checkpoint( + hf_model_path: str, + save_path: str, + tp_degree: int = 32, + convert_to_mha: bool = True, +): + """ + Process MiMo-V2-Flash checkpoint for Neuron FP8 inference. 
+ + Args: + hf_model_path: Path to HuggingFace MiMo-V2-Flash checkpoint + save_path: Path to save preprocessed checkpoint + tp_degree: Tensor parallelism degree + convert_to_mha: Whether to replicate K/V for CONVERT_TO_MHA mode + """ + print(f"Loading checkpoint from: {hf_model_path}", flush=True) + state_dict = load_state_dict(hf_model_path) + + # Load config + config_path = os.path.join(hf_model_path, "config.json") + with open(config_path, "r") as f: + config = json.load(f) + + # Extract model dimensions + num_layers = config["num_hidden_layers"] + hidden_size = config["hidden_size"] + num_attention_heads = config["num_attention_heads"] + num_kv_heads = config["num_key_value_heads"] # Full attention: 4 + swa_num_attention_heads = config["swa_num_attention_heads"] # Sliding window: 32 + swa_num_kv_heads = config["swa_num_key_value_heads"] # Sliding window: 8 + head_dim = config["head_dim"] # Q/K head dim: 192 + v_head_dim = config["v_head_dim"] # V head dim: 128 + swa_head_dim = config.get("swa_head_dim", head_dim) + swa_v_head_dim = config.get("swa_v_head_dim", v_head_dim) + + # Get hybrid layer pattern + hybrid_layer_pattern = config.get("hybrid_layer_pattern", [0] * num_layers) + + # MoE configuration + num_experts = config["n_routed_experts"] # 256 + moe_intermediate_size = config["moe_intermediate_size"] + moe_layer_freq = config.get("moe_layer_freq", [1] * num_layers) + + # Block size for quantization + quant_config = config.get("quantization_config", {}) + block_size = quant_config.get("weight_block_size", [128, 128]) + + print(f"\nModel configuration:", flush=True) + print(f" num_layers: {num_layers}", flush=True) + print(f" hidden_size: {hidden_size}", flush=True) + print(f" num_attention_heads: {num_attention_heads}", flush=True) + print(f" num_kv_heads (full): {num_kv_heads}", flush=True) + print(f" swa_num_kv_heads (sliding): {swa_num_kv_heads}", flush=True) + print(f" head_dim (Q/K): {head_dim}", flush=True) + print(f" v_head_dim: {v_head_dim}", 
flush=True) + print(f" num_experts: {num_experts}", flush=True) + print(f" moe_intermediate_size: {moe_intermediate_size}", flush=True) + print(f" block_size: {block_size}", flush=True) + print(f" tp_degree: {tp_degree}", flush=True) + print(f" convert_to_mha: {convert_to_mha}", flush=True) + + state_dict_keys = set(state_dict.keys()) + new_state_dict = {} + + # Process each layer + for layer_idx in range(num_layers): + print(f"\nProcessing layer {layer_idx}...", end="", flush=True) + + prefix = f"model.layers.{layer_idx}." + is_sliding_window = hybrid_layer_pattern[layer_idx] == 1 + is_moe_layer = moe_layer_freq[layer_idx] == 1 + + # Get layer-specific parameters + if is_sliding_window: + layer_num_heads = swa_num_attention_heads + layer_num_kv_heads = swa_num_kv_heads + layer_head_dim = swa_head_dim + layer_v_head_dim = swa_v_head_dim + else: + layer_num_heads = num_attention_heads + layer_num_kv_heads = num_kv_heads + layer_head_dim = head_dim + layer_v_head_dim = v_head_dim + + attn_type = "sliding_window" if is_sliding_window else "full" + print(f" ({attn_type}, kv_heads={layer_num_kv_heads})", end="", flush=True) + + # Process attention weights + for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]: + weight_key = f"{prefix}self_attn.{proj}.weight" + scale_key = f"{prefix}self_attn.{proj}.weight_scale_inv" + + if weight_key not in state_dict_keys: + continue + + weight = state_dict[weight_key] + scale = state_dict.get(scale_key) + + # Handle FP8 weights - convert to per-row scaling for Neuron + # Neuron framework expects per-row (per-channel) scaling for attention layers + if weight.dtype == torch.float8_e4m3fn and scale is not None: + weight, scale = rescale_fp8_to_per_row(weight, scale) + # Handle BF16 weights (convert to FP8 with per-row scales) + elif weight.dtype == torch.bfloat16: + weight, scale = convert_bf16_to_fp8_per_row(weight) + + # NOTE: Do NOT apply CONVERT_TO_MHA replication here. + # The Neuron framework handles K/V replication internally. 
+ # Pre-replicating would cause double-replication. + + # Save with Neuron naming convention + new_weight_key = f"layers.{layer_idx}.self_attn.{proj}.weight" + new_state_dict[new_weight_key] = weight + + if scale is not None: + new_scale_key = f"layers.{layer_idx}.self_attn.{proj}.scale" + new_state_dict[new_scale_key] = scale + + # Process layer norms (no FP8) + for norm in ["input_layernorm", "post_attention_layernorm"]: + weight_key = f"{prefix}{norm}.weight" + if weight_key in state_dict_keys: + new_key = f"layers.{layer_idx}.{norm}.weight" + new_state_dict[new_key] = state_dict[weight_key] + + # Process MoE router + router_key = f"{prefix}mlp.gate.weight" + if router_key in state_dict_keys: + new_key = f"layers.{layer_idx}.mlp.router.linear_router.weight" + new_state_dict[new_key] = state_dict[router_key] + + # Process MoE experts + if is_moe_layer: + # Prepare fused gate_up and down projections + gate_weights = [] + gate_scales = [] + up_weights = [] + up_scales = [] + down_weights = [] + down_scales = [] + + for expert_idx in range(num_experts): + expert_prefix = f"{prefix}mlp.experts.{expert_idx}." 
+ + # Gate projection + gate_w_key = f"{expert_prefix}gate_proj.weight" + gate_s_key = f"{expert_prefix}gate_proj.weight_scale_inv" + + if gate_w_key in state_dict_keys: + gate_w = state_dict[gate_w_key] + gate_s = state_dict.get(gate_s_key) + + if gate_w.dtype == torch.float8_e4m3fn and gate_s is not None: + gate_w, gate_s = rescale_fp8_weight_blockwise(gate_w, gate_s) + elif gate_w.dtype == torch.bfloat16: + gate_w, gate_s = convert_bf16_to_fp8_blockwise(gate_w, block_size) + + gate_weights.append(gate_w.T) # Transpose for fusion + if gate_s is not None: + gate_scales.append(gate_s) + + # Up projection + up_w_key = f"{expert_prefix}up_proj.weight" + up_s_key = f"{expert_prefix}up_proj.weight_scale_inv" + + if up_w_key in state_dict_keys: + up_w = state_dict[up_w_key] + up_s = state_dict.get(up_s_key) + + if up_w.dtype == torch.float8_e4m3fn and up_s is not None: + up_w, up_s = rescale_fp8_weight_blockwise(up_w, up_s) + elif up_w.dtype == torch.bfloat16: + up_w, up_s = convert_bf16_to_fp8_blockwise(up_w, block_size) + + up_weights.append(up_w.T) # Transpose for fusion + if up_s is not None: + up_scales.append(up_s) + + # Down projection + down_w_key = f"{expert_prefix}down_proj.weight" + down_s_key = f"{expert_prefix}down_proj.weight_scale_inv" + + if down_w_key in state_dict_keys: + down_w = state_dict[down_w_key] + down_s = state_dict.get(down_s_key) + + if down_w.dtype == torch.float8_e4m3fn and down_s is not None: + down_w, down_s = rescale_fp8_weight_blockwise(down_w, down_s) + elif down_w.dtype == torch.bfloat16: + down_w, down_s = convert_bf16_to_fp8_blockwise(down_w, block_size) + + down_weights.append(down_w.T) # Transpose for fusion + if down_s is not None: + down_scales.append(down_s) + + # Fuse gate and up projections + if gate_weights and up_weights: + # Stack experts: [num_experts, hidden_size, intermediate_size] + gate_stacked = torch.stack(gate_weights, dim=0) + up_stacked = torch.stack(up_weights, dim=0) + + # Concatenate gate and up: 
[num_experts, hidden_size, 2 * intermediate_size] + gate_up_fused = torch.cat([gate_stacked, up_stacked], dim=2) + + new_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + new_state_dict[new_key] = gate_up_fused + + # Fuse scales if present + if gate_scales and up_scales: + # Scales shape after transpose: [scale_h, scale_w] + # After stacking: [num_experts, scale_h, scale_w] + gate_s_stacked = torch.stack(gate_scales, dim=0) + up_s_stacked = torch.stack(up_scales, dim=0) + + # Concatenate scales along last dim + gate_up_scale = torch.cat([gate_s_stacked, up_s_stacked], dim=-1) + + new_scale_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + new_state_dict[new_scale_key] = gate_up_scale + + # Down projection + if down_weights: + # Stack: [num_experts, intermediate_size, hidden_size] + down_stacked = torch.stack(down_weights, dim=0) + + new_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.weight" + new_state_dict[new_key] = down_stacked + + if down_scales: + down_s_stacked = torch.stack(down_scales, dim=0) + new_scale_key = f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.scale" + new_state_dict[new_scale_key] = down_s_stacked + else: + # Non-MoE layer: regular MLP with gate_proj, up_proj, down_proj + for proj in ["gate_proj", "up_proj", "down_proj"]: + weight_key = f"{prefix}mlp.{proj}.weight" + scale_key = f"{prefix}mlp.{proj}.weight_scale_inv" + + if weight_key not in state_dict_keys: + continue + + weight = state_dict[weight_key] + scale = state_dict.get(scale_key) + + # Handle FP8 weights - convert to per-row scaling for Neuron + if weight.dtype == torch.float8_e4m3fn and scale is not None: + weight, scale = rescale_fp8_to_per_row(weight, scale) + # Handle BF16 weights (convert to FP8 with per-row scales) + elif weight.dtype == torch.bfloat16: + weight, scale = convert_bf16_to_fp8_per_row(weight) + + # Save with Neuron naming convention + new_weight_key = f"layers.{layer_idx}.mlp.{proj}.weight" + 
            new_state_dict[new_weight_key] = weight

            if scale is not None:
                # Keep the block-wise FP8 scale alongside its projection weight.
                new_scale_key = f"layers.{layer_idx}.mlp.{proj}.scale"
                new_state_dict[new_scale_key] = scale

        # Free intermediate tensors between layers to bound peak host memory.
        gc.collect()
        print(" done", flush=True)

    # Process embeddings and final layer norm
    print("\nProcessing embeddings and final norm...", flush=True)

    if "model.embed_tokens.weight" in state_dict_keys:
        new_state_dict["embed_tokens.weight"] = state_dict["model.embed_tokens.weight"]

    if "model.norm.weight" in state_dict_keys:
        new_state_dict["norm.weight"] = state_dict["model.norm.weight"]

    if "lm_head.weight" in state_dict_keys:
        new_state_dict["lm_head.weight"] = state_dict["lm_head.weight"]
    elif "model.embed_tokens.weight" in state_dict_keys:
        # Tied embeddings: reuse the input embedding matrix as the LM head.
        new_state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]

    # Save preprocessed checkpoint
    print(f"\nSaving preprocessed checkpoint to: {save_path}", flush=True)
    os.makedirs(save_path, exist_ok=True)

    save_state_dict_safetensors(new_state_dict, save_path)

    # Copy config.json so the output directory is a self-contained checkpoint.
    import shutil
    shutil.copy(config_path, os.path.join(save_path, "config.json"))

    # Copy tokenizer files (best effort: skip any that are absent).
    for tokenizer_file in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
        src_path = os.path.join(hf_model_path, tokenizer_file)
        if os.path.exists(src_path):
            shutil.copy(src_path, os.path.join(save_path, tokenizer_file))

    print(f"\nPreprocessing complete!", flush=True)
    print(f" Total parameters: {len(new_state_dict)}", flush=True)

    # Print FP8 weight count for a quick sanity check of the conversion.
    fp8_count = sum(1 for v in new_state_dict.values() if v.dtype == torch.float8_e4m3fn)
    scale_count = sum(1 for k in new_state_dict.keys() if k.endswith(".scale"))
    print(f" FP8 weights: {fp8_count}", flush=True)
    print(f" Scale parameters: {scale_count}", flush=True)


def main():
    """Parse CLI arguments and run MiMo-V2-Flash checkpoint preprocessing."""
    parser = argparse.ArgumentParser(
        description="Preprocess MiMo-V2-Flash FP8 checkpoint for Neuron inference"
    )
    parser.add_argument(
        "--hf_model_path",
        type=str,
        required=True,
        help="Path to HuggingFace MiMo-V2-Flash checkpoint",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="Path to save preprocessed checkpoint",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=32,
        help="Tensor parallelism degree (default: 32)",
    )
    # NOTE(review): with action="store_true" AND default=True, passing
    # --convert_to_mha is a no-op; only --no_convert_to_mha below changes the
    # value. Presumably intentional (flag kept for symmetry) — confirm.
    parser.add_argument(
        "--convert_to_mha",
        action="store_true",
        default=True,
        help="Replicate K/V for CONVERT_TO_MHA mode (default: True)",
    )
    parser.add_argument(
        "--no_convert_to_mha",
        action="store_false",
        dest="convert_to_mha",
        help="Disable K/V replication",
    )

    args = parser.parse_args()

    process_mimo_v2_checkpoint(
        hf_model_path=args.hf_model_path,
        save_path=args.save_path,
        tp_degree=args.tp_degree,
        convert_to_mha=args.convert_to_mha,
    )


if __name__ == "__main__":
    main()
diff --git a/src/neuronx_distributed_inference/models/mimo_v2/modeling_mimo_v2.py b/src/neuronx_distributed_inference/models/mimo_v2/modeling_mimo_v2.py
new file mode 100644
index 00000000..4ea33fb0
--- /dev/null
+++ b/src/neuronx_distributed_inference/models/mimo_v2/modeling_mimo_v2.py
@@ -0,0 +1,1333 @@
# coding=utf-8
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# This implementation is based on the MiMo-V2-Flash model from Xiaomi.
# Reference: https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash

"""MiMo-V2-Flash model for NXD inference."""

import gc
import math
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from neuronx_distributed.parallel_layers import parallel_state
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from neuronx_distributed.parallel_layers.mappings import (
    gather_from_sequence_parallel_region,
    gather_from_tensor_model_parallel_region_with_dim,
)
from neuronx_distributed.utils import cpu_mode

from neuronx_distributed_inference.utils.distributed import (
    split_along_dim,
    get_cp_rank,
)
from neuronx_distributed_inference.modules.attention.attention_process_groups import (
    get_context_parallel_attention_cp_group,
)

from neuronx_distributed_inference.models.config import (
    InferenceConfig,
    MoENeuronConfig,
)
from neuronx_distributed_inference.models.model_base import (
    NeuronBaseForCausalLM,
    NeuronBaseModel,
)
from neuronx_distributed_inference.models.model_wrapper import (
    CONTEXT_ENCODING_MODEL_TAG,
    TOKEN_GENERATION_MODEL_TAG,
)
from neuronx_distributed_inference.modules.attention.attention_base import (
    NeuronAttentionBase,
)
from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding
from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm
from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module

# The attention kernel moved between neuronxcc releases; try the private
# location first and fall back to the public one.
try:
    from neuronxcc.nki._private_kernels.attention import attention_isa_kernel
except ImportError:
    from neuronxcc.nki.kernels.attention import attention_isa_kernel

from torch_neuronx.xla_impl.ops import nki_jit

# JIT-wrapped NKI flash-attention kernel entry point.
_flash_fwd_call = nki_jit()(attention_isa_kernel)


def get_rmsnorm_cls():
    """Get appropriate RMSNorm class based on execution environment.

    Returns the pure-PyTorch implementation when running on CPU and the
    Neuron custom-call implementation otherwise.
    """
    return MiMoV2RMSNorm if cpu_mode() else CustomRMSNorm


class MiMoV2RMSNorm(nn.Module):
    """RMSNorm implementation for CPU mode.

    Variance is computed in float32 for numerical stability and the result
    is cast back to the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class MiMoV2RotaryEmbedding(nn.Module):
    """Rotary Position Embedding for MiMo-V2-Flash.

    Supports partial rotary embedding where only a fraction of dimensions
    use rotary position encoding.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 262144,
        base: float = 5000000.0,
        partial_rotary_factor: float = 1.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.partial_rotary_factor = partial_rotary_factor

        # Calculate the actual dimension used for rotary embedding
        self.rope_dim = int(dim * partial_rotary_factor)
        # Ensure rope_dim is even (RoPE rotates dimension pairs)
        self.rope_dim = self.rope_dim - (self.rope_dim % 2)

        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.rope_dim, 2, dtype=torch.float32) / self.rope_dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, position_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute rotary embeddings.

        Args:
            x: Input tensor of shape (batch_size, seq_len, hidden_size)
            position_ids: Position indices of shape (batch_size, seq_len)

        Returns:
            Tuple of (cos, sin) tensors for rotary embedding
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
            position_ids.shape[0], -1, 1
        )
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) else "cpu"
        # Disable autocast so the angle computation stays in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to query and key tensors.

    `unsqueeze_dim=1` broadcasts cos/sin across the head dimension of
    (batch, heads, seq, head_dim) tensors. `position_ids` is unused here;
    positions are already baked into cos/sin by the rotary embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class MiMoV2InferenceConfig(InferenceConfig):
    """Configuration class for MiMo-V2-Flash inference on Neuron."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # MoE configuration
        self.num_local_experts = self.n_routed_experts
        self.n_shared_experts = 0  # MiMo-V2-Flash has no shared experts

        # Set intermediate_size for MoE layers
        self.intermediate_size = self.moe_intermediate_size

        # Check and pad intermediate size if needed
        self.maybe_pad_intermediate()

        # Router configuration: router runs in float32 with sigmoid scoring.
        self.neuron_config.router_config.dtype = torch.float32
        self.neuron_config.router_config.act_fn = "sigmoid"  # MiMo uses sigmoid

        # Disable numeric CC token as workaround
        self.neuron_config.disable_numeric_cc_token = True

        # MiMo normalizes top-k affinities
        self.neuron_config.normalize_top_k_affinities = True

        # Parse hybrid layer pattern
        self._parse_hybrid_pattern()

    def _parse_hybrid_pattern(self):
        """Parse hybrid layer pattern to determine attention types.

        Fills `layer_attention_types` (entry value 1 -> sliding_window) and
        `layer_uses_moe` per layer; defaults to all-full-attention / all-MoE
        when the corresponding config fields are absent or empty.
        """
        if hasattr(self, 'hybrid_layer_pattern') and self.hybrid_layer_pattern:
            self.layer_attention_types = [
                "sliding_window" if p == 1 else "full"
                for p in self.hybrid_layer_pattern
            ]
        else:
            self.layer_attention_types = ["full"] * self.num_hidden_layers

        # Parse MoE layer frequency
        if hasattr(self, 'moe_layer_freq') and self.moe_layer_freq:
            self.layer_uses_moe = [bool(f) for f in self.moe_layer_freq]
        else:
            self.layer_uses_moe = [True] * self.num_hidden_layers

    def maybe_pad_intermediate(self):
        """Pad intermediate size if required for efficient computation.

        Rounds the per-TP-rank intermediate size up to a multiple of
        SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP when the blockwise matmul
        shard-on-intermediate path is enabled; records the amount of padding
        in `moe_intermediate_pad_size`.
        """
        from neuronx_distributed_inference.models.config import (
            SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP,
        )

        moe_tp_degree = self.neuron_config.moe_tp_degree
        I_TP = self.moe_intermediate_size // moe_tp_degree

        if getattr(
            self.neuron_config.blockwise_matmul_config,
            "use_shard_on_intermediate_dynamic_while",
            False,
        ):
            if I_TP % SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP != 0:
                padded_size = (
                    math.ceil(I_TP / SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP)
                    * SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP
                    * moe_tp_degree
                )
                self.moe_intermediate_pad_size = max(
                    padded_size - self.moe_intermediate_size, 0
                )
                self.moe_intermediate_size = padded_size

    def get_required_attributes(self) -> List[str]:
        """Attributes that must be present on the HF config for this model."""
        return [
            "attention_bias",
            "head_dim",
            "hidden_act",
            "hidden_size",
            "hybrid_layer_pattern",
            "layernorm_epsilon",
            "max_position_embeddings",
            "moe_intermediate_size",
            "moe_layer_freq",
            "n_routed_experts",
            "norm_topk_prob",
            "num_attention_heads",
            "num_experts_per_tok",
            "num_hidden_layers",
            "num_key_value_heads",
            "partial_rotary_factor",
            "rope_theta",
            "scoring_func",
            "sliding_window",
            "swa_head_dim",
            "swa_num_attention_heads",
            "swa_num_key_value_heads",
            "swa_rope_theta",
            "swa_v_head_dim",
            "tie_word_embeddings",
            "v_head_dim",
            "vocab_size",
        ]

    @classmethod
    def get_neuron_config_cls(cls) -> Type[MoENeuronConfig]:
        """Neuron config class used by this model (MoE variant)."""
        return MoENeuronConfig


class NeuronMiMoV2Attention(NeuronAttentionBase):
    """MiMo-V2-Flash Attention implementation supporting hybrid attention patterns.

    Supports both full attention and sliding window attention with different
    head dimensions for Q/K vs V.
    """

    def __init__(
        self,
        config: MiMoV2InferenceConfig,
        layer_idx: int,
        is_sliding_window: bool = False,
    ):
        self.layer_idx = layer_idx
        self.is_sliding_window = is_sliding_window

        # Select parameters based on attention type: sliding-window layers use
        # the swa_* config fields; full-attention layers use the regular ones.
        if is_sliding_window:
            self.attn_head_dim = config.swa_head_dim
            self.attn_v_head_dim = config.swa_v_head_dim
            self.attn_num_heads = config.swa_num_attention_heads
            self.attn_num_kv_heads = config.swa_num_key_value_heads
            rope_theta = getattr(config, 'swa_rope_theta', 10000.0)
            self.sliding_window_size = config.sliding_window
        else:
            self.attn_head_dim = config.head_dim
            self.attn_v_head_dim = config.v_head_dim
            self.attn_num_heads = config.num_attention_heads
            self.attn_num_kv_heads = config.num_key_value_heads
            rope_theta = config.rope_theta
            self.sliding_window_size = None

        # Calculate partial rotary dimensions: only the first rope_dim channels
        # of each Q/K head get rotary encoding; the rest (nope_dim) do not.
        self.partial_rotary_factor = config.partial_rotary_factor
        self.rope_dim = int(self.attn_head_dim * self.partial_rotary_factor)
        self.rope_dim = self.rope_dim - (self.rope_dim % 2)  # Ensure even
        self.nope_dim = self.attn_head_dim - self.rope_dim

        # Create rotary embedding
        rotary_emb = MiMoV2RotaryEmbedding(
            dim=self.attn_head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=rope_theta,
            partial_rotary_factor=self.partial_rotary_factor,
        )

        # Initialize base attention
        # NOTE: We pass v_head_dim to base class, but MiMo uses asymmetric Q/K (192) vs V (128).
        # We override init_gqa_properties() to prevent the base class from creating
        # incompatible projection layers (which cause crashes when CP > 1).
        super().__init__(
            config=config,
            hidden_size=config.hidden_size,
            num_attention_heads=self.attn_num_heads,
            num_key_value_heads=self.attn_num_kv_heads,
            head_dim=self.attn_v_head_dim,  # Use v_head_dim for base class
            rotary_emb=rotary_emb,
            rms_norm_eps=config.layernorm_epsilon,
            use_qk_norm=False,
        )

        # Initialize MiMo-specific projections with correct dimensions
        self._init_projections(config)

        # Scaling factor (1/sqrt(Q/K head dim))
        self.scaling = self.attn_head_dim ** -0.5
        # NOTE: The config may have 'attention_value_scale' (e.g., 0.707), but the HF model
        # (modeling_mimo_v2_flash.py) does NOT use this value. The HF model only uses
        # head_dim ** -0.5 for attention scaling, which is already applied via self.scaling.
        # We must NOT apply attention_value_scale here, as it would cause divergence from HF.
        self.value_scale = 1.0

        # Store cache KV heads for cache compatibility
        # With CONVERT_TO_MHA, all layers have num_attention_heads KV heads
        # Otherwise, use max of full and sliding window kv heads
        tp_degree = config.neuron_config.tp_degree
        if self.use_gqa_convert_to_mha:
            # CONVERT_TO_MHA: cache stores num_attention_heads (same as Q heads)
            self.cache_num_kv_heads = self.attn_num_heads
            self.local_cache_kv_heads = self.local_num_heads
        else:
            # Standard GQA: cache uses max of full and sliding window kv heads
            self.cache_num_kv_heads = max(
                config.num_key_value_heads,
                getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads)
            )
            self.local_cache_kv_heads = max(1, self.cache_num_kv_heads // tp_degree)

    def init_gqa_properties(self):
        """Override base class to prevent creating incompatible QKV projections.

        MiMo-V2-Flash has asymmetric Q/K head_dim (192) vs V head_dim (128),
        which is incompatible with the base class's GroupQueryAttention_QKV.
        MiMo uses its own custom projections via _init_projections() instead.

        When CP > 1, the base class would create cte_qkv_proj/tkg_qkv_proj with
        wrong head_dim=128, causing compilation crashes. This no-op prevents that.
        """
        pass

    def _init_projections(self, config: MiMoV2InferenceConfig):
        """Initialize projection layers with correct dimensions.

        When CONVERT_TO_MHA is needed (tp_degree > num_kv_heads), K/V projections
        are sized for num_attention_heads (not original num_kv_heads). The checkpoint
        weights are replicated in preshard_hook before loading.
        """
        dtype = config.neuron_config.torch_dtype
        tp_degree = config.neuron_config.tp_degree

        # Check if we need GQA CONVERT_TO_MHA (when tp_degree > num_kv_heads)
        self.use_gqa_convert_to_mha = tp_degree > self.attn_num_kv_heads

        # Store source heads for preshard_hook
        self._src_num_kv_heads = self.attn_num_kv_heads
        self._kv_replication_factor = self.attn_num_heads // self.attn_num_kv_heads if self.use_gqa_convert_to_mha else 1

        if self.use_gqa_convert_to_mha:
            # CONVERT_TO_MHA: K and V use num_attention_heads for proper TP splitting
            k_num_heads = self.attn_num_heads
            v_num_heads = self.attn_num_heads
        else:
            k_num_heads = self.attn_num_kv_heads
            v_num_heads = self.attn_num_kv_heads

        # Q/K use head_dim, V uses v_head_dim
        q_hidden_size = self.attn_num_heads * self.attn_head_dim
        k_hidden_size = k_num_heads * self.attn_head_dim
        v_hidden_size = v_num_heads * self.attn_v_head_dim
        o_hidden_size = self.attn_num_heads * self.attn_v_head_dim

        if parallel_state.model_parallel_is_initialized():
            tp_group = parallel_state.get_tensor_model_parallel_group()

            # Q projection
            self.q_proj = ColumnParallelLinear(
                config.hidden_size,
                q_hidden_size,
                bias=config.attention_bias,
                gather_output=False,
                dtype=dtype,
                tensor_model_parallel_group=tp_group,
            )

            # K projection
            self.k_proj = ColumnParallelLinear(
                config.hidden_size,
                k_hidden_size,
                bias=config.attention_bias,
                gather_output=False,
                dtype=dtype,
                tensor_model_parallel_group=tp_group,
            )

            # V projection
            self.v_proj = ColumnParallelLinear(
                config.hidden_size,
                v_hidden_size,
                bias=config.attention_bias,
                gather_output=False,
                dtype=dtype,
                tensor_model_parallel_group=tp_group,
            )

            # Output projection - with sequence parallel to scatter output
            self.o_proj = RowParallelLinear(
                o_hidden_size,
                config.hidden_size,
                bias=False,
                input_is_parallel=True,
                dtype=dtype,
                tensor_model_parallel_group=tp_group,
                sequence_parallel_enabled=self.sequence_parallel_enabled,
                sequence_dimension=1 if self.sequence_parallel_enabled else None,
            )

            # Calculate local dimensions after TP split
            self.local_num_heads = self.attn_num_heads // tp_degree
            if self.use_gqa_convert_to_mha:
                # With CONVERT_TO_MHA, local KV heads = local Q heads
                self.local_num_kv_heads = self.local_num_heads
            else:
                self.local_num_kv_heads = max(1, self.attn_num_kv_heads // tp_degree)
        else:
            # Non-parallel (CPU/debug) path: plain nn.Linear, no sharding.
            self.q_proj = nn.Linear(config.hidden_size, q_hidden_size, bias=config.attention_bias)
            self.k_proj = nn.Linear(config.hidden_size, k_hidden_size, bias=config.attention_bias)
            self.v_proj = nn.Linear(config.hidden_size, v_hidden_size, bias=config.attention_bias)
            self.o_proj = nn.Linear(o_hidden_size, config.hidden_size, bias=False)

            self.local_num_heads = self.attn_num_heads
            self.local_num_kv_heads = k_num_heads

        # Override base class attributes that were computed with wrong head_dim
        # The base class init_gqa_properties() uses head_dim=v_head_dim which is wrong for Q/K
        # We need to override these to ensure correct computation
        self.num_heads = self.local_num_heads
        self.num_key_value_heads = self.local_num_kv_heads
        self.num_key_value_groups = self.local_num_heads // self.local_num_kv_heads
        self.head_dim = self.attn_head_dim  # Override to use actual Q/K head_dim (192)

        # Remove qkv_proj from base class if exists (we use separate q_proj, k_proj, v_proj)
        if hasattr(self, 'qkv_proj'):
            self.qkv_proj = None

        # Attention sink bias for attention layers (following HF implementation)
        # This is a learnable parameter that allows attention to "sink" to an extra position
        add_full_attention_sink_bias = getattr(config, 'add_full_attention_sink_bias', False)
        add_swa_attention_sink_bias = getattr(config, 'add_swa_attention_sink_bias', True)

        # Determine if this layer uses sink bias based on config
        self._use_sink_bias = (add_full_attention_sink_bias and not self.is_sliding_window) or \
            (add_swa_attention_sink_bias and self.is_sliding_window)

        if self._use_sink_bias:
            # Shape: [num_attention_heads] - will be split across TP ranks
            # The weight is loaded from checkpoint with shape [num_attention_heads]
            # and will be sliced to [local_num_heads] during forward
            self.attention_sink_bias = nn.Parameter(
                torch.zeros(self.attn_num_heads, dtype=dtype), requires_grad=False
            )
        else:
            self.attention_sink_bias = None

    def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
        """Pre-shard hook to replicate K/V weights for CONVERT_TO_MHA.

        NOTE: This method is NOT currently called because NeuronMiMoV2Attention
        is not a BaseGroupQueryAttention subclass. K/V weight replication is
        instead done in convert_mimo_v2_hf_to_neuron_state_dict().

        This method is kept for reference and potential future use.
        """
        # This hook is not called - see note above
        return False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[torch.Tensor] = None,
        cos_cache: Optional[torch.Tensor] = None,
        sin_cache: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Forward pass for MiMo-V2-Flash attention with Context Parallelism support.

        Returns (attn_output, (key_cache, value_cache), cos_cache, sin_cache).
        The cached V is padded from v_head_dim to head_dim so both K and V fit
        one cache layout; it is sliced back during token generation.
        """

        # Context Parallelism: only active during context encoding (no past_key_value)
        is_context_parallel = past_key_value is None and self.cp_degree > 1
        cp_rank = None

        if is_context_parallel:
            cp_rank = get_cp_rank(
                self.rank_util.get_rank(), self.tp_degree,
                self.cp_degree, self.neuron_config.switch_cc,
            )
            # Split attention_mask (dim=2 = Q rows) and position_ids (dim=1 = seq)
            attention_mask = split_along_dim(
                attention_mask, dim=2, rank=cp_rank, num_partitions=self.cp_degree
            )
            # Keep full position_ids for RoPE computation on full-length K/V
            local_position_ids = split_along_dim(
                position_ids, dim=1, rank=cp_rank, num_partitions=self.cp_degree
            )

        # Handle sequence parallel
        if self.sequence_parallel_enabled and parallel_state.model_parallel_is_initialized():
            hidden_states = gather_from_sequence_parallel_region(
                hidden_states,
                self.sequence_dimension,
                process_group=parallel_state.get_tensor_model_parallel_group(),
            )

        # Context Parallelism without sequence parallel: split hidden_states
        if is_context_parallel and not self.sequence_parallel_enabled:
            hidden_states = split_along_dim(
                hidden_states, dim=1, rank=cp_rank, num_partitions=self.cp_degree
            )

        bsz, q_len, _ = hidden_states.size()

        # Determine if this is token generation (past_key_value is not None)
        is_token_gen = past_key_value is not None

        # Project Q, K, V
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape for multi-head attention: [bsz, num_heads, seq_len, head_dim]
        query_states = query_states.view(bsz, q_len, self.local_num_heads, self.attn_head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.local_num_kv_heads, self.attn_head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.local_num_kv_heads, self.attn_v_head_dim).transpose(1, 2)

        # Split into rope and non-rope parts (partial rotary embedding)
        query_rope = query_states[..., :self.rope_dim]
        query_nope = query_states[..., self.rope_dim:]
        key_rope = key_states[..., :self.rope_dim]
        key_nope = key_states[..., self.rope_dim:]

        # Compute rotary embeddings
        # IMPORTANT: Always compute for this layer because different layer types
        # (full vs sliding window) use different rope_theta values.
        # Full attention: rope_theta = 5000000
        # Sliding window: rope_theta = 10000
        # We cannot reuse cached cos/sin from other layers!
        #
        # For CP with sequence_parallel: Q/K/V have full S, use full position_ids for RoPE.
        # For CP without sequence_parallel: Q/K/V have S/CP, use local_position_ids for RoPE
        # (local_position_ids contain the correct global positions for this CP rank).
        if is_context_parallel and not self.sequence_parallel_enabled:
            rope_position_ids = local_position_ids
        else:
            rope_position_ids = position_ids
        cos_cache, sin_cache = self.rotary_emb(value_states, rope_position_ids)

        # Apply rotary position embedding to rope parts only
        query_rope, key_rope = apply_rotary_pos_emb(
            query_rope, key_rope, cos_cache, sin_cache, rope_position_ids
        )

        # Concatenate rope and non-rope parts
        query_states = torch.cat([query_rope, query_nope], dim=-1)
        key_states = torch.cat([key_rope, key_nope], dim=-1)

        # Context Parallelism: split Q and save local KV for cache
        if is_context_parallel:
            if self.sequence_parallel_enabled:
                # Q/K/V have full S. Split Q to local portion, save local KV for cache.
                # Use split_along_dim (torch.index_select) instead of Python slicing
                # because XLA tracing doesn't support dynamic tensor indices in slice notation.
                query_states = split_along_dim(query_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree)
                key_states_for_cache = split_along_dim(key_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree)
                value_states_for_cache = split_along_dim(value_states, dim=2, rank=cp_rank, num_partitions=self.cp_degree)
                q_len = q_len // self.cp_degree
                # K/V stay at full S for attention computation
            else:
                # Q/K/V have S/CP. Save local KV for cache, then all-gather K/V.
                key_states_for_cache = key_states
                value_states_for_cache = value_states
                key_states = gather_from_tensor_model_parallel_region_with_dim(
                    key_states, gather_dim=2,
                    process_group=get_context_parallel_attention_cp_group(),
                )
                value_states = gather_from_tensor_model_parallel_region_with_dim(
                    value_states, gather_dim=2,
                    process_group=get_context_parallel_attention_cp_group(),
                )
                # Q stays at S/CP
        else:
            # Store key/value states BEFORE GQA repeat for KV cache
            key_states_for_cache = key_states
            value_states_for_cache = value_states

        # WORKAROUND 1: Pad V from v_head_dim (128) to head_dim (192) for KV cache compatibility
        if self.attn_v_head_dim < self.attn_head_dim:
            pad_size = self.attn_head_dim - self.attn_v_head_dim
            value_states_for_cache = F.pad(value_states_for_cache, (0, pad_size), value=0.0)

        # WORKAROUND 2: Pad KV heads if layer has fewer than cache expects
        # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode)
        if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads:
            # Pad KV heads by repeating
            repeat_factor = self.local_cache_kv_heads // self.local_num_kv_heads
            key_states_for_cache = key_states_for_cache.repeat(1, repeat_factor, 1, 1)
            value_states_for_cache = value_states_for_cache.repeat(1, repeat_factor, 1, 1)

        # Repeat KV heads for GQA (only needed without CONVERT_TO_MHA)
        # With CONVERT_TO_MHA, K/V already have num_attention_heads
        num_key_value_groups = self.local_num_heads // self.local_num_kv_heads
        if num_key_value_groups > 1:
            key_states = key_states.repeat_interleave(num_key_value_groups, dim=1)
            value_states = value_states.repeat_interleave(num_key_value_groups, dim=1)

        if is_token_gen:
            # Token generation: use decomposed attention with prior (cached) and active (current) KV
            # past_key_value[0] = cached K, shape [bsz, cache_kv_heads, kv_seq_len, head_dim]
            # past_key_value[1] = cached V, shape [bsz, cache_kv_heads, kv_seq_len, head_dim] (padded)
            K_prior = past_key_value[0]
            V_prior = past_key_value[1]

            # WORKAROUND 1: Slice KV heads if cache has more than layer needs
            # Only needed when NOT using CONVERT_TO_MHA (standard GQA mode)
            # With CONVERT_TO_MHA, cache and layer have same num_kv_heads
            if not self.use_gqa_convert_to_mha and self.local_num_kv_heads < self.local_cache_kv_heads:
                # Cache has repeated heads, just take the first local_num_kv_heads
                K_prior = K_prior[:, :self.local_num_kv_heads, :, :]
                V_prior = V_prior[:, :self.local_num_kv_heads, :, :]

            # WORKAROUND 2: Slice V_prior back to v_head_dim (128) from head_dim (192)
            if self.attn_v_head_dim < self.attn_head_dim:
                V_prior = V_prior[..., :self.attn_v_head_dim]

            # Repeat cached KV for GQA (only needed without CONVERT_TO_MHA)
            # With CONVERT_TO_MHA, cached K/V already have num_attention_heads
            if num_key_value_groups > 1:
                K_prior = K_prior.repeat_interleave(num_key_value_groups, dim=1)
                V_prior = V_prior.repeat_interleave(num_key_value_groups, dim=1)

            # Compute attention on prior (cached) KV
            # K_prior shape: [bsz, num_heads, kv_seq_len, head_dim]
            prior_scores = torch.matmul(query_states, K_prior.transpose(-2, -1)) * self.scaling

            # Apply attention mask to prior scores
            if attention_mask is not None:
                # Convert boolean mask to additive mask if needed
                if attention_mask.dtype == torch.bool:
                    prior_scores = prior_scores.masked_fill(~attention_mask, float('-inf'))
                else:
                    prior_scores = prior_scores + attention_mask

            # Apply sliding window mask for SWA layers
            if self.is_sliding_window and self.sliding_window_size is not None and position_ids is not None:
                kv_seq_len = prior_scores.size(-1)
                current_pos = position_ids[0, 0]
                pos_indices = torch.arange(kv_seq_len, device=prior_scores.device)
                sliding_mask = pos_indices >= (current_pos - self.sliding_window_size + 1)
                sliding_mask = sliding_mask[None, None, None, :]
                prior_scores = prior_scores.masked_fill(~sliding_mask, float('-inf'))

            prior_scores = prior_scores.to(torch.float32)

            # Compute attention on active (current) KV
            active_scores = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling
            active_scores = active_scores.to(torch.float32)

            # Combined softmax over prior and active scores
            all_scores = torch.cat([prior_scores, active_scores], dim=-1)

            # Add attention sink bias (following HF implementation)
            # This must be applied to token generation as well!
            use_sink = self._use_sink_bias and self.attention_sink_bias is not None
            if use_sink:
                tp_rank = parallel_state.get_tensor_model_parallel_rank() if parallel_state.model_parallel_is_initialized() else 0
                local_sink = self.attention_sink_bias[tp_rank * self.local_num_heads:(tp_rank + 1) * self.local_num_heads]
                sink_bias = local_sink.reshape(1, -1, 1, 1).expand(bsz, -1, q_len, 1)
                all_scores = torch.cat([all_scores, sink_bias], dim=-1)

            # Numerical stability: subtract max before softmax
            all_scores = all_scores - all_scores.max(dim=-1, keepdim=True).values
            attn_weights = F.softmax(all_scores, dim=-1, dtype=torch.float32)

            # Drop the sink column after softmax
            if use_sink:
                attn_weights = attn_weights[..., :-1]

            # Split attention weights back
            prior_weights = attn_weights[..., :-q_len].to(V_prior.dtype)
            active_weights = attn_weights[..., -q_len:].to(value_states.dtype)

            # Compute attention outputs
            attn_prior = torch.matmul(prior_weights, V_prior)
            attn_active = torch.matmul(active_weights, value_states)
            attn_output = attn_prior + attn_active
        else:
            # Context encoding: standard attention
            # With CP: Q is local [B, H, S/CP, D], K/V are full [B, H, S, D]
            # Without CP: Q/K/V all have same seq_len
            attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.scaling

            # Apply attention mask (additive mask: 0 = attend, -inf = mask out)
            # The framework creates boolean masks, so we need to convert them
            # With CP: attention_mask is already split to [B, 1, S/CP, S] (local Q rows, full K cols)
            if attention_mask is not None:
                # Convert boolean mask to additive mask if needed
                if attention_mask.dtype == torch.bool:
                    # True = attend (0), False = mask (-inf)
                    additive_mask = torch.zeros_like(attn_weights)
                    additive_mask = additive_mask.masked_fill(~attention_mask, float('-inf'))
                    attn_weights = attn_weights + additive_mask
                else:
                    # Already additive mask
                    attn_weights = attn_weights + attention_mask

            # Apply sliding window mask for SWA layers
            if self.is_sliding_window and self.sliding_window_size is not None:
                kv_seq_len = attn_weights.size(-1)
                if is_context_parallel:
                    # With CP: Q has local seq len, K has full seq len.
                    # Use local_position_ids for correct global Q positions.
                    row_idx = local_position_ids[0].unsqueeze(1).to(attn_weights.device)
                else:
                    row_idx = torch.arange(kv_seq_len, device=attn_weights.device).unsqueeze(1)
                col_idx = torch.arange(kv_seq_len, device=attn_weights.device).unsqueeze(0)
                # Causal: col <= row, and within window: col >= row - window_size + 1
                sliding_mask = (col_idx <= row_idx) & (col_idx >= row_idx - self.sliding_window_size + 1)
                sliding_mask = sliding_mask[None, None, :, :]
                # Convert to additive mask
                attn_weights = attn_weights.masked_fill(~sliding_mask, float('-inf'))

            # Add attention sink bias (following HF implementation)
            # This adds an extra "sink" column to attention weights
            use_sink = self._use_sink_bias and self.attention_sink_bias is not None
            if use_sink:
                # Get local portion of sink bias for this TP rank
                tp_rank = parallel_state.get_tensor_model_parallel_rank() if parallel_state.model_parallel_is_initialized() else 0
                local_sink = self.attention_sink_bias[tp_rank * self.local_num_heads:(tp_rank + 1) * self.local_num_heads]
                # Reshape and expand: [local_num_heads] -> [bsz, local_num_heads, q_len, 1]
                sink_bias = local_sink.reshape(1, -1, 1, 1).expand(bsz, -1, q_len, 1)
                attn_weights = torch.cat([attn_weights, sink_bias], dim=-1)

            # Numerical stability: subtract max before softmax (like HF implementation)
            attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values

            # Softmax
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32)

            # Drop the sink column after softmax
            if use_sink:
                attn_weights = attn_weights[..., :-1]

            attn_weights = attn_weights.to(value_states.dtype)

            # Apply attention to values
            attn_output = torch.matmul(attn_weights, value_states)

        # Apply value scale if specified (value_scale is fixed to 1.0 — see __init__)
        if self.value_scale != 1.0:
            attn_output = attn_output * self.value_scale

        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.local_num_heads * self.attn_v_head_dim)

        # Context Parallelism: gather output across CP ranks BEFORE o_proj.
        # With SP enabled, o_proj scatters along seq dim. The input must have full S
        # (not S/CP), otherwise the SP-scattered output won't match the residual.
        # Without SP, gather after o_proj to restore full seq_len for residual.
        if is_context_parallel:
            attn_output = gather_from_tensor_model_parallel_region_with_dim(
                attn_output, gather_dim=1,
                process_group=get_context_parallel_attention_cp_group(),
            )

        attn_output = self.o_proj(attn_output)

        # Prepare KV cache output - return as tuple for KV cache manager
        # Return LOCAL key/value states for cache (each CP rank stores its portion)
        new_key_value = (key_states_for_cache, value_states_for_cache)

        return attn_output, new_key_value, cos_cache, sin_cache


class MiMoV2MLP(nn.Module):
    """Standard MLP for non-MoE layers in MiMo-V2-Flash."""

    def __init__(self, config: MiMoV2InferenceConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Use the dense intermediate size for non-MoE layers
        # NOTE(review): the fallback `intermediate_size * 8` assumes the dense
        # FFN is 8x the expert intermediate size when dense_intermediate_size
        # is absent from the config — confirm against the HF config.
        self.intermediate_size = getattr(config, 'dense_intermediate_size', config.intermediate_size * 8)

        dtype = config.neuron_config.torch_dtype

        if parallel_state.model_parallel_is_initialized():
            tp_group = parallel_state.get_tensor_model_parallel_group()

            self.gate_proj = ColumnParallelLinear(
                self.hidden_size,
                self.intermediate_size,
                bias=False,
                gather_output=False,
                dtype=dtype,
                tensor_model_parallel_group=tp_group,
            )
            self.up_proj = ColumnParallelLinear(
                self.hidden_size,
                self.intermediate_size,
                bias=False,
                gather_output=False,
                dtype=dtype,
                tensor_model_parallel_group=tp_group,
            )
            self.down_proj = RowParallelLinear(
                self.intermediate_size,
                self.hidden_size,
                bias=False,
                input_is_parallel=True,
                dtype=dtype,
                tensor_model_parallel_group=tp_group,
            )
        else:
            # Non-parallel (CPU/debug) path
            self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
            self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        self.act_fn = F.silu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """SwiGLU feed-forward: down(silu(gate(x)) * up(x))."""
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class NeuronMiMoV2DecoderLayer(nn.Module):
    """MiMo-V2-Flash Decoder Layer with hybrid attention and conditional MoE."""

    def __init__(self, config: MiMoV2InferenceConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        # Determine attention type for this layer
        is_sliding_window = config.layer_attention_types[layer_idx] == "sliding_window"
        self.attention_type = "sliding_window" if is_sliding_window else "full"

        # Create attention module
        self.self_attn = NeuronMiMoV2Attention(
            config=config,
            layer_idx=layer_idx,
            is_sliding_window=is_sliding_window,
        )

        # Determine if this layer uses MoE
        self.uses_moe = config.layer_uses_moe[layer_idx]

        # Create MLP/MoE module
        if self.uses_moe:
            self.mlp = initialize_moe_module(config=config)
        else:
            self.mlp = MiMoV2MLP(config)

        # Layer norms
        self.input_layernorm = get_rmsnorm_cls()(
            config.hidden_size,
            eps=config.layernorm_epsilon,
        )
        self.post_attention_layernorm = get_rmsnorm_cls()(
            config.hidden_size,
            eps=config.layernorm_epsilon,
        )

        # Config flags
        self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        cos_cache: Optional[torch.Tensor] = None,
        sin_cache: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, ...]:
        """Forward pass for decoder layer."""

        # Self attention with residual
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            cos_cache=cos_cache,
            sin_cache=sin_cache,
**kwargs, + ) + hidden_states = residual + hidden_states + + # MLP/MoE with residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.uses_moe: + hidden_states = self.mlp(hidden_states, padding_mask)[0] + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) + return outputs + + +class NeuronMiMoV2Model(NeuronBaseModel): + """MiMo-V2-Flash Model for NXD inference.""" + + def setup_attr_for_model(self, config: MiMoV2InferenceConfig): + self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + + # Check if we need GQA CONVERT_TO_MHA mode + # When tp_degree > num_kv_heads, we replicate K/V to match num_attention_heads + min_kv_heads = min( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + self.use_gqa_convert_to_mha = self.tp_degree > min_kv_heads + + if self.use_gqa_convert_to_mha: + # With CONVERT_TO_MHA, KV cache stores num_attention_heads (same as Q) + self.num_key_value_heads = config.num_attention_heads + else: + # Standard GQA: use the maximum num_kv_heads for KV cache + # (handles hybrid full/sliding window attention) + self.num_key_value_heads = max( + config.num_key_value_heads, + getattr(config, 'swa_num_key_value_heads', config.num_key_value_heads) + ) + + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + # MiMo has hybrid attention (full + sliding window) + # NOTE: Do NOT set self.sliding_window here because it affects KV cache size globally. + # MiMo handles sliding window per-layer in the attention module itself. 
+ # Setting has_mixed_attn = True enables proper mask creation without affecting cache size. + self.has_mixed_attn = True + + def init_model(self, config: MiMoV2InferenceConfig): + self.padding_idx = getattr(config, 'pad_token_id', None) + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + + self.layers = nn.ModuleList([ + NeuronMiMoV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + + self.norm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.layernorm_epsilon, + ) + + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ) + + +def _replicate_kv_weights_for_convert_to_mha( + tensor: torch.Tensor, + source_heads: int, + target_heads: int, + head_dim: int, +) -> torch.Tensor: + """Replicate K/V weights from source_heads to target_heads for CONVERT_TO_MHA. 
+ + Args: + tensor: Weight tensor of shape [source_heads * head_dim, hidden_size] + source_heads: Number of source KV heads + target_heads: Number of target heads (num_attention_heads) + head_dim: Head dimension + + Returns: + Replicated tensor of shape [target_heads * head_dim, hidden_size] + """ + if tensor is None or source_heads >= target_heads: + return tensor + + repeats = target_heads // source_heads + + # Reshape to [source_heads, head_dim, hidden_size] + original_shape = tensor.shape + tensor = tensor.view(source_heads, head_dim, -1) + + # Repeat along head dimension + tensor = tensor.repeat_interleave(repeats, dim=0) + + # Reshape back to [num_heads * head_dim, hidden_size] + tensor = tensor.view(-1, original_shape[-1]) + + return tensor + + +def convert_mimo_v2_hf_to_neuron_state_dict( + neuron_state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, +) -> Dict[str, Any]: + """Convert HuggingFace MiMo-V2-Flash weights to Neuron format. + + This handles: + 1. Router weight renaming + 2. Expert weight concatenation and transposition + 3. FP8 dequantization if needed + 4. 
K/V weight replication for CONVERT_TO_MHA mode + """ + + assert config.neuron_config.glu_mlp is True, "Only GLU MLP is supported" + + # Dequantize layers if needed + _maybe_dequantize_layer(neuron_state_dict, config) + + # Add rank utility tensors + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, config.neuron_config.tp_degree, dtype=torch.int32 + ) + + # Determine if CONVERT_TO_MHA is needed + tp_degree = config.neuron_config.tp_degree + num_attention_heads = config.num_attention_heads + + # MiMo-V2-Flash has different KV heads for full and sliding window attention + full_num_kv_heads = config.num_key_value_heads # 4 + swa_num_kv_heads = config.swa_num_key_value_heads # 8 + + # Check if we need to replicate K/V weights + full_use_convert_to_mha = tp_degree > full_num_kv_heads + swa_use_convert_to_mha = tp_degree > swa_num_kv_heads + + print(f"\n[DEBUG] CONVERT_TO_MHA status:") + print(f" tp_degree: {tp_degree}") + print(f" num_attention_heads: {num_attention_heads}") + print(f" full_num_kv_heads: {full_num_kv_heads}, use_convert_to_mha: {full_use_convert_to_mha}") + print(f" swa_num_kv_heads: {swa_num_kv_heads}, use_convert_to_mha: {swa_use_convert_to_mha}") + + for layer_idx in range(config.num_hidden_layers): + # Add rank utility for attention + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = torch.arange( + 0, config.neuron_config.tp_degree, dtype=torch.int32 + ) + + # Determine attention type for this layer + is_sliding_window = config.layer_attention_types[layer_idx] == "sliding_window" + + if is_sliding_window: + src_num_kv_heads = swa_num_kv_heads + use_convert_to_mha = swa_use_convert_to_mha + head_dim = config.swa_head_dim # 192 + v_head_dim = config.swa_v_head_dim # 128 + else: + src_num_kv_heads = full_num_kv_heads + use_convert_to_mha = full_use_convert_to_mha + head_dim = config.head_dim # 192 + v_head_dim = config.v_head_dim # 128 + + # Replicate K/V weights if CONVERT_TO_MHA is needed + if use_convert_to_mha: + 
k_proj_key = f"layers.{layer_idx}.self_attn.k_proj.weight" + v_proj_key = f"layers.{layer_idx}.self_attn.v_proj.weight" + + if k_proj_key in neuron_state_dict: + old_shape = neuron_state_dict[k_proj_key].shape + neuron_state_dict[k_proj_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[k_proj_key], + src_num_kv_heads, + num_attention_heads, + head_dim, + ) + print(f"[DEBUG] Layer {layer_idx} ({'SWA' if is_sliding_window else 'Full'}): Replicated K: {old_shape} -> {neuron_state_dict[k_proj_key].shape}") + + if v_proj_key in neuron_state_dict: + old_shape = neuron_state_dict[v_proj_key].shape + neuron_state_dict[v_proj_key] = _replicate_kv_weights_for_convert_to_mha( + neuron_state_dict[v_proj_key], + src_num_kv_heads, + num_attention_heads, + v_head_dim, + ) + print(f"[DEBUG] Layer {layer_idx} ({'SWA' if is_sliding_window else 'Full'}): Replicated V: {old_shape} -> {neuron_state_dict[v_proj_key].shape}") + + # Only convert MoE layers + if not config.layer_uses_moe[layer_idx]: + continue + + # Check if this layer has MoE weights + gate_key = f"layers.{layer_idx}.mlp.gate.weight" + if gate_key not in neuron_state_dict: + continue + + # Rename router weights + neuron_state_dict[f"layers.{layer_idx}.mlp.router.linear_router.weight"] = ( + neuron_state_dict[gate_key].detach().clone() + ) + del neuron_state_dict[gate_key] + + # Get dimensions from first expert + expert_0_gate = f"layers.{layer_idx}.mlp.experts.0.gate_proj.weight" + if expert_0_gate not in neuron_state_dict: + continue + + intermediate_size, hidden_size = neuron_state_dict[expert_0_gate].shape + device = neuron_state_dict[expert_0_gate].device + dtype = neuron_state_dict[expert_0_gate].dtype + + num_experts = config.n_routed_experts + + # Concatenate gate and up projections + gate_up_proj = torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size, + dtype=dtype, + device=device, + ) + + for e in range(num_experts): + gate_proj_weights = neuron_state_dict[ + 
f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + ].T.detach().clone() + up_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight" + ].T.detach().clone() + + gate_up_proj[e, :, :intermediate_size] = gate_proj_weights + gate_up_proj[e, :, intermediate_size:] = up_proj_weights + + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"] + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"] + + # Pad if needed + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + gate_up_proj = gate_up_proj.reshape(num_experts, hidden_size, 2, -1) + gate_up_proj = F.pad(gate_up_proj, (0, pad_size)) + gate_up_proj = gate_up_proj.reshape(num_experts, hidden_size, -1) + + neuron_state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + + # Convert down projections + down_proj = torch.empty( + num_experts, + intermediate_size, + hidden_size, + dtype=dtype, + device=device, + ) + + for e in range(num_experts): + down_proj_weights = neuron_state_dict[ + f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + ].T.detach().clone() + down_proj[e] = down_proj_weights + del neuron_state_dict[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"] + + # Pad if needed + if pad_size > 0: + down_proj = F.pad(down_proj, (0, 0, 0, pad_size)) + + neuron_state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + + gc.collect() + + return neuron_state_dict + + +def _maybe_dequantize_layer( + neuron_state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, +): + """Dequantize FP8 layers if present.""" + scale_layers = [] + + for layer_key in list(neuron_state_dict.keys()): + if "_scale_inv" in layer_key: + scales = neuron_state_dict[layer_key] + scale_layers.append(layer_key) + + fp8_layer_name = layer_key.replace("_scale_inv", "") + if fp8_layer_name not in neuron_state_dict: + continue + + fp8_layer = 
neuron_state_dict[fp8_layer_name] + + # Get block size from config if available + if hasattr(config, 'quantization_config') and config.quantization_config: + block_size = config.quantization_config.get("weight_block_size", [128, 128]) + else: + block_size = [128, 128] + + # Expand scales and dequantize + scales_expanded = scales.repeat_interleave(block_size[0], dim=0) + scales_expanded = scales_expanded.repeat_interleave(block_size[1], dim=1) + + # Ensure shapes match + if scales_expanded.shape != fp8_layer.shape: + scales_expanded = scales_expanded[:fp8_layer.shape[0], :fp8_layer.shape[1]] + + scaled_layer = fp8_layer.to(torch.float32) * scales_expanded.to(torch.float32) + neuron_state_dict[fp8_layer_name] = scaled_layer.to(config.neuron_config.torch_dtype) + + # Remove scale layers + for scale_layer in scale_layers: + del neuron_state_dict[scale_layer] + + +class NeuronMiMoV2ForCausalLM(NeuronBaseForCausalLM): + """MiMo-V2-Flash for Causal Language Modeling on Neuron.""" + + _model_cls = NeuronMiMoV2Model + + @staticmethod + def load_hf_model(model_path: str, **kwargs): + """Load HuggingFace model. 
+ + Note: MiMo-V2-Flash uses custom code, so we need trust_remote_code=True + """ + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + **kwargs, + ) + + @classmethod + def get_config_cls(cls) -> Type[MiMoV2InferenceConfig]: + return MiMoV2InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: Dict[str, Any], + config: MiMoV2InferenceConfig, + ) -> Dict[str, Any]: + return convert_mimo_v2_hf_to_neuron_state_dict(state_dict, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def get_compiler_args(self) -> str: + """Get compiler arguments optimized for MiMo-V2-Flash.""" + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + elif self.compile_tag == TOKEN_GENERATION_MODEL_TAG: + optimization_level = "-O3" if self.neuron_config.moe_ep_degree > 1 else "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + f"--enable-saturate-infinity " + f"--enable-mixed-precision-accumulation " + f"--model-type transformer " + f"{optimization_level}" + ) + + # Add CC overlap optimization + compiler_args += ( + " --tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2'" + ) + + compiler_args += " --auto-cast=none" + + # Enable vector-offset DGE + compiler_args += " --internal-enable-dge-levels vector_dynamic_offsets" + compiler_args += " --internal-hlo2tensorizer-options='--verify-hlo=true'" + + if self.neuron_config.scratchpad_page_size: + compiler_args += f" --hbm-scratchpad-page-size={self.neuron_config.scratchpad_page_size}" + + return compiler_args diff --git a/src/neuronx_distributed_inference/models/minimax_m2/__init__.py b/src/neuronx_distributed_inference/models/minimax_m2/__init__.py new 
file mode 100644 index 00000000..bbfb49f5 --- /dev/null +++ b/src/neuronx_distributed_inference/models/minimax_m2/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from neuronx_distributed_inference.models.minimax_m2.modeling_minimax_m2 import ( + MiniMaxM2InferenceConfig, + NeuronMiniMaxM2ForCausalLM, +) + +__all__ = [ + "MiniMaxM2InferenceConfig", + "NeuronMiniMaxM2ForCausalLM", +] diff --git a/src/neuronx_distributed_inference/models/minimax_m2/config.json b/src/neuronx_distributed_inference/models/minimax_m2/config.json new file mode 100644 index 00000000..237efe37 --- /dev/null +++ b/src/neuronx_distributed_inference/models/minimax_m2/config.json @@ -0,0 +1,112 @@ +{ + "architectures": [ + "MiniMaxM2ForCausalLM" + ], + "attention_dropout": 0.0, + "attn_type_list": [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1 + ], + "auto_map": { + "AutoConfig": "configuration_minimax_m2.MiniMaxM2Config", + "AutoModelForCausalLM": "modeling_minimax_m2.MiniMaxM2ForCausalLM" + }, + "bos_token_id": null, + "eos_token_id": null, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 1536, + "layernorm_full_attention_beta": 1.0, + "layernorm_linear_attention_beta": 1.0, + "layernorm_mlp_beta": 1.0, + "max_position_embeddings": 196608, + "mlp_intermediate_size": 8192, + "model_type": "minimax_m2", + "mtp_transformer_layers": 1, + "num_attention_heads": 48, + "num_experts_per_tok": 8, + "num_hidden_layers": 62, + "num_key_value_heads": 8, + "num_local_experts": 256, + "num_mtp_modules": 3, + "output_router_logits": false, + "qk_norm_type": "per_layer", + "rms_norm_eps": 
1e-06, + "rope_theta": 5000000, + "rotary_dim": 64, + "router_aux_loss_coef": 0.001, + "router_jitter_noise": 0.0, + "scoring_func": "sigmoid", + "shared_intermediate_size": 0, + "shared_moe_mode": "sigmoid", + "sliding_window": null, + "tie_word_embeddings": false, + "transformers_version": "4.57.1", + "use_cache": true, + "use_mtp": true, + "use_qk_norm": true, + "use_routing_bias": true, + "vocab_size": 200064 +} diff --git a/src/neuronx_distributed_inference/models/minimax_m2/configuration_minimax_m2.py b/src/neuronx_distributed_inference/models/minimax_m2/configuration_minimax_m2.py new file mode 100644 index 00000000..76ea7dcc --- /dev/null +++ b/src/neuronx_distributed_inference/models/minimax_m2/configuration_minimax_m2.py @@ -0,0 +1,201 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax_m2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from transformers.configuration_utils import PretrainedConfig + + + +class MiniMaxM2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MiniMaxM2Model`]. It is used to instantiate an + MiniMaxM2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the MiniMaxM2-7B-v0.1 or MiniMaxM2-7B-Instruct-v0.1. + + [minimax_m2ai/MiniMaxM2-8x7B](https://huggingface.co/minimax_m2ai/MiniMaxM2-8x7B) + [minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1](https://huggingface.co/minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniMaxM2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniMaxM2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. MiniMaxM2's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.0): + Amount of noise to add to the router. + + ```python + >>> from transformers import MiniMaxM2Model, MiniMaxM2Config + + >>> # Initializing a MiniMaxM2 7B style configuration + >>> configuration = MiniMaxM2Config() + + >>> # Initializing a model from the MiniMaxM2 7B style configuration + >>> model = MiniMaxM2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "minimax_m2" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts + "layers.*.block_sparse_moe.experts.*.w1": "colwise", + "layers.*.block_sparse_moe.experts.*.w2": "rowwise", + "layers.*.block_sparse_moe.experts.*.w3": "colwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + 
num_key_value_heads=8, + head_dim=None, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.head_dim = head_dim + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + + self.use_qk_norm = kwargs.pop("use_qk_norm", False) + self.rotary_dim = kwargs.pop("rotary_dim", self.head_dim) + self.partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1) + if self.head_dim is not None: + self.partial_rotary_factor = self.rotary_dim / self.head_dim + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["MiniMaxM2Config"] \ No newline at end of file diff --git 
a/src/neuronx_distributed_inference/models/minimax_m2/modeling_minimax_m2.py b/src/neuronx_distributed_inference/models/minimax_m2/modeling_minimax_m2.py new file mode 100644 index 00000000..db4ba001 --- /dev/null +++ b/src/neuronx_distributed_inference/models/minimax_m2/modeling_minimax_m2.py @@ -0,0 +1,1388 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +MiniMax-M2 model for NeuronX Distributed Inference. + +Architecture: 229B total, ~10B active. 62 decoder layers, 256 MoE experts (top-8), +sigmoid routing with e_score_correction_bias, partial RoPE (64/128 head dim), +QK normalization (RMSNorm before reshape), GQA 48Q/8KV heads, SwiGLU experts. + +Based on Henan's (whn09) implementation with SDK 2.28 improvements: +- Fused MoE NKI kernels (router_topk, moe_cte, moe_tkg) +- ModuleMarker wrappers for compiler optimization +- Fused QKV support +- Shard-on-intermediate padding for blockwise matmul +- RouterTopKWithBias preserving e_score_correction_bias for accuracy +""" + +import gc +import math +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from neuronx_distributed.modules.moe.routing import RouterTopK +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, +) +from neuronx_distributed.utils import cpu_mode + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MOE_TKG_MK_INTERMEDIATE_PER_TP, + MoENeuronConfig, + SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP, +) +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + 
TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) + +# nki-library attention block kernel (partial RoPE support) +try: + from nkilib.experimental.transformer.attention_block_tkg import attention_block_tkg + from nkilib.core.utils.common_types import ( + QuantizationType as NkilibQuantizationType, + ) + + _HAS_NKILIB_ATTN_BLOCK = True +except ImportError: + _HAS_NKILIB_ATTN_BLOCK = False +from neuronx_distributed_inference.modules.attention.gqa import GQA +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.moe_v2 import ( + initialize_moe_process_group, +) + +GQA_SHARDING_STRATEGY = GQA.REPLICATE_TO_TP_DEGREE + + +# --------------------------------------------------------------------------- +# Utility helpers +# --------------------------------------------------------------------------- + + +def get_rmsnorm_cls(): + """Return the appropriate RMSNorm class for the execution environment.""" + if cpu_mode(): + + class SimpleRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + return self.weight * hidden_states.to(input_dtype) + + return SimpleRMSNorm + return CustomRMSNorm + + +def get_modules_to_not_convert(neuron_config: MoENeuronConfig): + return getattr(neuron_config, "modules_to_not_convert", None) + + +# --------------------------------------------------------------------------- +# Fused QKV helpers +# 
--------------------------------------------------------------------------- + + +def _helper_concat_and_delete_qkv( + state_dict: Dict[str, Any], layer_num: int, attr: str +): + """Concatenate Q/K/V into fused Wqkv for a single attribute (weight or scale). + + The fused key uses the ``qkv_proj.Wqkv`` path because the NxDI model nests + the Wqkv linear layer under ``self_attn.qkv_proj`` (a GroupQueryAttention_QKV module). + """ + state_dict[f"layers.{layer_num}.self_attn.qkv_proj.Wqkv.{attr}"] = torch.cat( + [ + state_dict[f"layers.{layer_num}.self_attn.q_proj.{attr}"], + state_dict[f"layers.{layer_num}.self_attn.k_proj.{attr}"], + state_dict[f"layers.{layer_num}.self_attn.v_proj.{attr}"], + ], + ) + del state_dict[f"layers.{layer_num}.self_attn.q_proj.{attr}"] + del state_dict[f"layers.{layer_num}.self_attn.k_proj.{attr}"] + del state_dict[f"layers.{layer_num}.self_attn.v_proj.{attr}"] + + +def convert_state_dict_to_fused_qkv(state_dict: Dict[str, Any], cfg: InferenceConfig): + """Fuse separate Q/K/V weights into a single Wqkv tensor per layer.""" + mods_to_not_conv = get_modules_to_not_convert(cfg.neuron_config) or [] + for layer_idx in range(cfg.num_hidden_layers): + _helper_concat_and_delete_qkv(state_dict, layer_idx, "weight") + if ( + cfg.neuron_config.quantized_mlp_kernel_enabled + or cfg.neuron_config.quantized + ) and f"layers.{layer_idx}.self_attn" not in mods_to_not_conv: + _helper_concat_and_delete_qkv(state_dict, layer_idx, "scale") + gc.collect() + return state_dict + + +def maybe_dequantize_layer(neuron_state_dict: dict, config): + """Dequantize FP8 layers (weight_scale_inv) to the configured torch dtype.""" + scale_layers = [] + for layer_key in list(neuron_state_dict.keys()): + if "_scale_inv" in layer_key: + scales = neuron_state_dict[layer_key] + scale_layers.append(layer_key) + fp8_layer_name = layer_key.replace("_scale_inv", "") + fp8_layer = neuron_state_dict[fp8_layer_name] + block_size = config.quantization_config["weight_block_size"] + 
scales_expanded = scales.repeat_interleave( + block_size[0], dim=0 + ).repeat_interleave(block_size[1], dim=1) + scaled_layer = fp8_layer.to(torch.float32) * scales_expanded.to( + torch.float32 + ) + neuron_state_dict[fp8_layer_name] = scaled_layer.to( + config.neuron_config.torch_dtype + ) + for key in scale_layers: + del neuron_state_dict[key] + + +# --------------------------------------------------------------------------- +# MiniMax-M2 specific modules +# --------------------------------------------------------------------------- + + +class MiniMaxM2QKNorm(nn.Module): + """ + QK normalization for MiniMax-M2 using Neuron's fused RmsNorm custom call. + + MiniMax-M2 applies RMSNorm on the Q/K projection output before reshape. + This implementation uses the Neuron-native AwsNeuronRmsNorm custom call + (via RmsNorm.apply) which is validated for both context encoding and token + generation NEFFs. Hand-rolled PyTorch RMSNorm (pow/mean/rsqrt) compiles + into different HLO in CE vs TG and produces incorrect TG results. + + Normalization is computed per-rank (no all-reduce) on the flat projection + output [B, S, per_rank_dim]. The per-element weight is selected dynamically + by SPMD rank from a padded weight tensor. 
+ + Args: + hidden_size: Per-rank hidden dimension (num_heads_per_rank * head_dim) + eps: Epsilon for numerical stability + tp_degree: Tensor parallelism degree + padded_hidden_size: Total weight storage size (tp_degree * per_rank_size) + """ + + def __init__( + self, + hidden_size, + eps=1e-6, + tp_degree=1, + padded_hidden_size=None, + ): + super().__init__() + self.hidden_size = hidden_size + self.variance_epsilon = eps + self.tp_degree = tp_degree + self.padded_hidden_size = ( + padded_hidden_size + if padded_hidden_size is not None + else (hidden_size * tp_degree) + ) + # Weight stored at full padded size for SPMD rank-based selection + self.weight = nn.Parameter(torch.ones(self.padded_hidden_size)) + + def forward(self, hidden_states, rank_util=None): + """ + Apply Neuron-native RMSNorm on flat Q or K tensor (no all-reduce). + + Args: + hidden_states: [B, S, per_rank_dim] — flat projection output + rank_util: SPMDRank for dynamic weight slice selection + """ + from neuronx_distributed_inference.modules.custom_calls import RmsNorm + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + # Dynamically select weight slice by SPMD rank (XLA-compatible) + if rank_util is not None and self.tp_degree > 1: + weight_reshaped = self.weight.view(self.tp_degree, self.hidden_size) + rank_index = rank_util.rank[:1] + local_weight = torch.index_select(weight_reshaped, 0, rank_index).squeeze(0) + else: + local_weight = self.weight[: self.hidden_size] + + # Use Neuron-native fused RmsNorm (AwsNeuronRmsNorm custom call) + dim = len(hidden_states.shape) - 1 + result = RmsNorm.apply(hidden_states, local_weight, self.variance_epsilon, dim) + + return result.to(input_dtype) + + +class RouterTopKWithBias(RouterTopK): + """ + RouterTopK with e_score_correction_bias for MiniMax-M2 sigmoid routing. + + MiniMax-M2 applies sigmoid to router logits to obtain expert affinities, then + adds a learned per-expert bias before top-K selection. 
The bias influences which + experts are chosen but does NOT affect the affinity weights passed to experts. + + The bias MUST be an nn.Parameter (not a buffer) because: + - XLA tracing bakes register_buffer values as constants in the NEFF + - shard_children only processes nn.Parameter in supported modules + - replace_weights only loads tensors present in the traced model's separated weights + Using nn.Parameter ensures the bias is separated during tracing and loaded from + the checkpoint at inference time. + + Dropping the bias (as v3 does for XLA simplicity) causes ~75% wrong expert selection + because bias values (~8.0-9.5) dominate sigmoid scores (0-1). + """ + + def __init__(self, num_experts: int, *args, **kwargs): + super().__init__(num_experts=num_experts, *args, **kwargs) + # nn.Parameter so it gets separated from NEFF and loaded from checkpoint. + # requires_grad=False since this is inference-only. + # CRITICAL: Initialize with non-uniform values to prevent XLA graph optimization + # from eliminating the add-bias operation. Uniform values (zeros, ones) don't + # change relative ordering in topk, so XLA can prove the add is a no-op and + # eliminate it — removing the bias parameter from the HLO entirely and making it + # impossible to load the real bias values at inference time. + # Using arange produces distinct per-expert values that genuinely affect topk + # ordering, forcing the compiler to keep the bias as a runtime parameter. + # IMPORTANT: Initialize as bfloat16 to match the dtype that _cast_helper + # will produce from the checkpoint (FP32 → BF16). If the NEFF expects FP32 + # but the checkpoint provides BF16, the LayoutTransformation silently + # ignores the weight and leaves the trace-time values in place. 
+ self.e_score_correction_bias = nn.Parameter( + torch.arange(num_experts, dtype=torch.bfloat16), + requires_grad=False, + ) + + def forward(self, hidden_states): + router_logits = self.get_router_logits(hidden_states) + expert_affinities = self.apply_activation_fn(router_logits) + + # Add bias for expert selection only (MiniMax-M2 specific). + # sigmoid(logits) + bias determines WHICH experts are selected, + # but the un-biased sigmoid scores are used as affinity weights. + scores_for_choice = ( + expert_affinities.float() + self.e_score_correction_bias.unsqueeze(0) + ) + _, expert_index = torch.topk(scores_for_choice, self.top_k, dim=-1) + + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + expert_index = expert_index.detach().to(dtype=torch.long) + + return router_logits, expert_affinities, expert_index + + +# --------------------------------------------------------------------------- +# MoE initialization +# --------------------------------------------------------------------------- + + +def initialize_minimax_m2_moe_module( + config: InferenceConfig, rmsnorm=None, init_tkg_module=False +): + """ + Create the MoE module for MiniMax-M2 with e_score_correction_bias. + + Instead of wrapping the standard MoE, we inject a RouterTopKWithBias directly + as the router. This ensures the bias is an nn.Parameter that gets: + 1. Separated from the NEFF during XLA tracing (not baked as a constant) + 2. Loaded from the checkpoint via replace_weights at inference time + + The bias values (~8.0-9.5) dominate sigmoid scores (0-1) and are critical + for correct expert selection. Without them, ~75% of experts are wrong. 
+ """ + from neuronx_distributed.modules.moe.expert_mlps_v2 import ExpertMLPsV2 + from neuronx_distributed.modules.moe.model import MoE + from neuronx_distributed.modules.moe.moe_configs import RoutedExpertsMLPOpsConfig + from neuronx_distributed.parallel_layers import parallel_state + from neuronx_distributed.parallel_layers.parallel_state import ( + get_expert_model_parallel_size, + get_tensor_model_parallel_group, + get_world_group, + ) + + from neuronx_distributed_inference.modules.moe_v2 import ( + initialize_moe_process_group, + ) + + enabled_hybrid_sharding = config.neuron_config.hybrid_sharding_config is not None + ( + moe_tkg_tensor_model_parallel_group, + moe_tkg_expert_model_parallel_group, + moe_cte_tensor_model_parallel_group, + moe_cte_expert_model_parallel_group, + ) = initialize_moe_process_group(config, enabled_hybrid_sharding) + + # Use RouterTopKWithBias instead of standard RouterTopK + router = RouterTopKWithBias( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + dtype=config.neuron_config.router_config.dtype, + act_fn=config.neuron_config.router_config.act_fn, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + sequence_dimension=1, + bias=False, # no linear bias; we use e_score_correction_bias instead + apply_act_fn_over_topk=False, + store_transposed_weights=init_tkg_module, + ) + + hidden_size_actual = getattr(config, "original_hidden_size", None) + intermediate_size_actual = getattr(config, "original_intermediate_size", None) + + expert_mlps = ExpertMLPsV2( + routed_experts_mlp_config=RoutedExpertsMLPOpsConfig( + num_experts=config.num_local_experts, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_size_actual=hidden_size_actual, + intermediate_size_actual=intermediate_size_actual, + is_hidden_dim_shuffled=config.neuron_config.is_hidden_dim_shuffled, + 
is_intermediate_dim_shuffled=config.neuron_config.is_intermediate_dim_shuffled, + top_k=config.num_experts_per_tok, + hidden_act=config.hidden_act, + bias=False, + glu_mlp=config.neuron_config.glu_mlp, + glu_type=config.neuron_config.glu_type, + hidden_act_scaling_factor=config.neuron_config.hidden_act_scaling_factor, + hidden_act_bias=config.neuron_config.hidden_act_bias, + use_index_calc_kernel=config.neuron_config.use_index_calc_kernel, + gate_clamp_upper_limit=config.neuron_config.gate_clamp_upper_limit, + gate_clamp_lower_limit=config.neuron_config.gate_clamp_lower_limit, + up_clamp_upper_limit=config.neuron_config.up_clamp_upper_limit, + up_clamp_lower_limit=config.neuron_config.up_clamp_lower_limit, + early_expert_affinity_modulation=config.neuron_config.early_expert_affinity_modulation, + normalize_top_k_affinities=config.neuron_config.normalize_top_k_affinities, + enable_spmd_rank=config.neuron_config.blockwise_matmul_config.parallelize_token_to_block_mapping, + ), + blockwise_matmul_config=config.neuron_config.blockwise_matmul_config, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + dtype=config.neuron_config.torch_dtype, + is_prefill=config.neuron_config.is_prefill_stage, + enabled_hybrid_sharding=enabled_hybrid_sharding, + tensor_model_parallel_group=parallel_state.get_tensor_model_parallel_group(), + expert_model_parallel_group=parallel_state.get_expert_model_parallel_group(), + cte_tensor_model_parallel_group=moe_cte_tensor_model_parallel_group, + cte_expert_model_parallel_group=moe_cte_expert_model_parallel_group, + tkg_tensor_model_parallel_group=moe_tkg_tensor_model_parallel_group, + tkg_expert_model_parallel_group=moe_tkg_expert_model_parallel_group, + ) + + if init_tkg_module: + from neuronx_distributed.modules.moe.model import MoEFusedTKGConfig + + tkg_config = MoEFusedTKGConfig( + quantized=config.neuron_config.quantized, + moe_fused_kernel_enabled=config.neuron_config.moe_fused_nki_kernel_enabled, + 
router_topk_kernel_enabled=config.neuron_config.router_topk_nki_kernel_enabled, + expert_mlp_kernel_enabled=config.neuron_config.expert_mlp_nki_kernel_enabled, + shared_mlp_kernel_enabled=config.neuron_config.shared_mlp_nki_kernel_enabled, + norm_topk_prob=config.neuron_config.normalize_top_k_affinities, + is_mxfp4_compute=config.neuron_config.is_mxfp4_compute, + router_mm_dtype=config.neuron_config.router_config.dtype, + ) + else: + tkg_config = None + + moe = MoE( + router=router, + expert_mlps=expert_mlps, + shared_experts=None, # MiniMax-M2 has no shared experts + rmsnorm=rmsnorm, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + return_expert_index=config.neuron_config.return_expert_index, + return_router_logits=config.neuron_config.return_router_logits, + sequence_dimension=1, + init_tkg_module=init_tkg_module, + tkg_config=tkg_config, + ) + + moe.eval() + return moe + + +# --------------------------------------------------------------------------- +# Weight conversion +# --------------------------------------------------------------------------- + + +def convert_minimax_m2_hf_to_neuron_state_dict( + neuron_state_dict: Dict[str, Any], + config: "MiniMaxM2InferenceConfig", +) -> Dict[str, Any]: + """ + Convert a HuggingFace MiniMax-M2 checkpoint to the NxDI-compatible format. + + Key transformations: + 1. Stack per-expert w1/w3 into gate_up_proj, w2 into down_proj + 2. Rename router gate -> router.linear_router (or router.e_score_correction_bias) + 3. Pad QK norm weights to match TP sharding (interleaved for Q, replicated for K) + 4. Optionally pad intermediate_size for shard-on-I blockwise matmul + 5. 
Optionally fuse QKV into Wqkv + """ + from neuronx_distributed_inference.modules.attention.gqa import ( + GQA, + _maybe_pad_interleaved, + get_shardable_head_counts, + ) + + assert config.neuron_config.glu_mlp is True, ( + "MiniMax-M2 requires glu_mlp=True (SwiGLU)" + ) + + # Dequantize FP8 weights if present + maybe_dequantize_layer(neuron_state_dict, config) + + with torch.no_grad(): + tp_degree = config.neuron_config.tp_degree + head_dim = config.head_dim + has_qk_norm = getattr(config, "use_qk_norm", True) + + # Rank utility tensor for SPMD operations (int32 for NKI compatibility) + rank_tensor = torch.arange(0, tp_degree, dtype=torch.int32) + neuron_state_dict["rank_util.rank"] = rank_tensor + + # Pre-compute sharded head counts for QK norm padding + sharding_strategy = GQA.REPLICATE_TO_TP_DEGREE + padded_num_attention_heads, padded_num_kv_heads = get_shardable_head_counts( + tp_degree, + config.num_attention_heads, + config.num_key_value_heads, + sharding_strategy, + ) + + gc_interval = 64 # GC every N experts to control memory + + for layer_idx in range(config.num_hidden_layers): + # Per-layer rank tensor for attention SPMD + neuron_state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = ( + rank_tensor.clone() + ) + + # --- QK norm weight padding --- + if has_qk_norm: + # Q norm: interleaved padding (48 -> padded heads) + q_norm_key = f"layers.{layer_idx}.self_attn.q_norm.weight" + if q_norm_key in neuron_state_dict: + q_norm_full = neuron_state_dict[q_norm_key] + source_group_size = ( + config.num_attention_heads // config.num_key_value_heads + ) + q_norm_padded = _maybe_pad_interleaved( + q_norm_full.unsqueeze(0), + pad_dim=1, + source_heads=config.num_attention_heads, + target_heads=padded_num_attention_heads, + source_group_size=source_group_size, + ).squeeze(0) + neuron_state_dict[q_norm_key] = q_norm_padded + + # K norm: replicate from original KV heads to padded KV heads + k_norm_key = f"layers.{layer_idx}.self_attn.k_norm.weight" + if 
k_norm_key in neuron_state_dict: + k_norm_full = neuron_state_dict[k_norm_key] + k_norm_reshaped = k_norm_full.reshape( + config.num_key_value_heads, head_dim + ) + repeats = padded_num_kv_heads // config.num_key_value_heads + k_norm_replicated = k_norm_reshaped.repeat_interleave( + repeats, dim=0 + ) + neuron_state_dict[k_norm_key] = k_norm_replicated.reshape(-1) + + # --- Router weights --- + gate_key = f"layers.{layer_idx}.block_sparse_moe.gate.weight" + router_key = ( + f"layers.{layer_idx}.block_sparse_moe.router.linear_router.weight" + ) + neuron_state_dict[router_key] = neuron_state_dict.pop(gate_key) + + # e_score_correction_bias: map to RouterTopKWithBias.e_score_correction_bias + # This is an nn.Parameter in the router, so it will be separated from the + # NEFF during tracing and loaded via replace_weights at inference time. + bias_src_key = ( + f"layers.{layer_idx}.block_sparse_moe.e_score_correction_bias" + ) + bias_dst_key = ( + f"layers.{layer_idx}.block_sparse_moe.router.e_score_correction_bias" + ) + if bias_src_key in neuron_state_dict: + neuron_state_dict[bias_dst_key] = neuron_state_dict.pop(bias_src_key) + + # --- Expert weight stacking --- + w1_key = f"layers.{layer_idx}.block_sparse_moe.experts.0.w1.weight" + intermediate_size, hidden_size = neuron_state_dict[w1_key].shape + device = neuron_state_dict[w1_key].device + dtype = neuron_state_dict[w1_key].dtype + + # Stack gate (w1) + up (w3) into gate_up_proj: [E, H, 2*I] + gate_up_proj = torch.empty( + config.num_local_experts, + hidden_size, + 2 * intermediate_size, + dtype=dtype, + device=device, + ) + for expert_idx in range(config.num_local_experts): + ew1 = f"layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight" + ew3 = f"layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w3.weight" + + gate_up_slice = torch.narrow(gate_up_proj, 0, expert_idx, 1) + torch.narrow(gate_up_slice, 2, 0, intermediate_size).copy_( + neuron_state_dict[ew1].T + ) + torch.narrow( + gate_up_slice, 
2, intermediate_size, intermediate_size + ).copy_(neuron_state_dict[ew3].T) + del neuron_state_dict[ew1], neuron_state_dict[ew3] + if (expert_idx + 1) % gc_interval == 0: + gc.collect() + + # Pad gate_up_proj intermediate dimension if needed for shard-on-I + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + gate_up_proj = gate_up_proj.reshape( + config.num_local_experts, hidden_size, 2, -1 + ) + gate_up_proj = torch.nn.functional.pad(gate_up_proj, (0, pad_size)) + gate_up_proj = gate_up_proj.reshape( + config.num_local_experts, hidden_size, -1 + ) + + neuron_state_dict[ + f"layers.{layer_idx}.block_sparse_moe.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up_proj + + # Stack down (w2) into down_proj: [E, I, H] + down_proj = torch.empty( + config.num_local_experts, + intermediate_size, + hidden_size, + dtype=dtype, + device=device, + ) + for expert_idx in range(config.num_local_experts): + ew2 = f"layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight" + torch.narrow(down_proj, 0, expert_idx, 1).copy_( + neuron_state_dict[ew2].T + ) + del neuron_state_dict[ew2] + if (expert_idx + 1) % gc_interval == 0: + gc.collect() + + if pad_size > 0: + down_proj = torch.nn.functional.pad(down_proj, (0, 0, 0, pad_size)) + + neuron_state_dict[ + f"layers.{layer_idx}.block_sparse_moe.expert_mlps.mlp_op.down_proj.weight" + ] = down_proj + + gc.collect() + + # Fuse QKV if configured (must run BEFORE the rename below, since + # convert_state_dict_to_fused_qkv expects layers.X.self_attn.q_proj.weight) + if config.neuron_config.fused_qkv: + neuron_state_dict = convert_state_dict_to_fused_qkv( + neuron_state_dict, config + ) + + # --- Attention projection key renaming --- + # The NxDI traced model uses nested module names for attention projections: + # self_attn.qkv_proj.q_proj.weight (not self_attn.q_proj.weight) + # self_attn.qkv_proj.k_proj.weight (not self_attn.k_proj.weight) + # self_attn.qkv_proj.v_proj.weight (not 
self_attn.v_proj.weight) + # self_attn.o_proj.o_proj.weight (not self_attn.o_proj.weight) + # The preshard hook in RowParallelLinear handles the o_proj rename + # (o_proj.weight -> o_proj.o_proj.weight), so we only rename Q/K/V here. + # When fused_qkv=True, Q/K/V are already merged into Wqkv above. + for layer_idx in range(config.num_hidden_layers): + prefix = f"layers.{layer_idx}.self_attn" + # Q/K/V projections -> nested under qkv_proj + for proj in ("q_proj", "k_proj", "v_proj"): + old_key = f"{prefix}.{proj}.weight" + new_key = f"{prefix}.qkv_proj.{proj}.weight" + if old_key in neuron_state_dict: + neuron_state_dict[new_key] = neuron_state_dict.pop(old_key) + + return neuron_state_dict + + +# --------------------------------------------------------------------------- +# Inference config +# --------------------------------------------------------------------------- + + +class MiniMaxM2InferenceConfig(InferenceConfig): + """ + Inference configuration for MiniMax-M2. + + Extends InferenceConfig with MoE-specific setup: + - Sigmoid routing with FP32 router precision + - Intermediate-size padding for shard-on-I blockwise matmul + - Fused MoE NKI kernel enablement + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # MiniMax-M2 has no shared experts + self.n_shared_experts = 0 + + # Store MoE intermediate size before any padding + self.moe_intermediate_size = self.intermediate_size + + # Pad intermediate for shard-on-I compatibility + self.moe_intermediate_pad_size = 0 + self._maybe_pad_intermediate() + + # Enable fused MoE NKI kernels where dimensions allow + self._enable_moe_fused_nki_kernel() + + # Router config: MiniMax-M2 uses sigmoid routing with FP32 precision + self.neuron_config.router_config.dtype = torch.float32 + self.neuron_config.router_config.act_fn = "sigmoid" + + # MiniMax-M2 normalizes top-K affinities + self.neuron_config.normalize_top_k_affinities = True + + # Disable numeric CC token for MoE stability + 
self.neuron_config.disable_numeric_cc_token = True + + def _maybe_pad_intermediate(self): + """Pad intermediate_size so shard-on-I blockwise matmul kernels tile correctly.""" + moe_tp_degree = self.neuron_config.moe_tp_degree + i_tp = self.intermediate_size // moe_tp_degree + if getattr( + self.neuron_config.blockwise_matmul_config, + "use_shard_on_intermediate_dynamic_while", + False, + ): + if i_tp % SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP != 0: + padded = ( + math.ceil(i_tp / SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP) + * SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP + * moe_tp_degree + ) + self.moe_intermediate_pad_size = max(padded - self.intermediate_size, 0) + self.intermediate_size = padded + + def _enable_moe_fused_nki_kernel(self): + """Enable fused MoE NKI kernel if the per-TP intermediate dimension is aligned.""" + i_tp = self.intermediate_size // self.neuron_config.moe_tp_degree + if getattr(self.neuron_config, "moe_fused_nki_kernel_enabled", False): + if i_tp % MOE_TKG_MK_INTERMEDIATE_PER_TP == 0: + self.moe_fused_nki_kernel_enabled = True + + def get_required_attributes(self) -> List[str]: + return [ + "head_dim", + "hidden_act", + "hidden_size", + "intermediate_size", + "max_position_embeddings", + "num_attention_heads", + "num_experts_per_tok", + "num_hidden_layers", + "num_key_value_heads", + "num_local_experts", + "rms_norm_eps", + "rope_theta", + "tie_word_embeddings", + "vocab_size", + "use_qk_norm", + "rotary_dim", + ] + + @classmethod + def get_neuron_config_cls(cls): + return MoENeuronConfig + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + + +class NeuronMiniMaxM2Attention(NeuronAttentionBase): + """ + MiniMax-M2 attention with two non-standard features: + + 1. QK normalization applied BEFORE reshape to per-head layout (on the full + Q/K projection output). Uses MiniMaxM2QKNorm with distributed all-reduce. + 2. 
Partial RoPE: rotary embeddings applied to only the first ``rotary_dim`` + dimensions of each head (64 out of 128). + """ + + def __init__(self, config: MiniMaxM2InferenceConfig): + self.rotary_dim = getattr(config, "rotary_dim", config.head_dim) + + # RotaryEmbedding sized to rotary_dim (64), not head_dim (128) + rotary_emb = RotaryEmbedding( + self.rotary_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rms_norm_eps=config.rms_norm_eps, + use_qk_norm=False, # handled by MiniMaxM2QKNorm below + ) + + # --- QK normalization (local per-rank, no all-reduce) --- + self.use_minimax_qk_norm = getattr(config, "use_qk_norm", True) + tp_degree = config.neuron_config.tp_degree + + if self.use_minimax_qk_norm: + q_per_rank = self.num_heads * self.head_dim + k_per_rank = self.num_key_value_heads * self.head_dim + + # Weight storage: padded to tp_degree * per_rank for SPMD selection + padded_q = self.num_heads * tp_degree * config.head_dim + padded_kv = self.num_key_value_heads * tp_degree + padded_k = padded_kv * config.head_dim + + self.q_norm = MiniMaxM2QKNorm( + q_per_rank, + eps=config.rms_norm_eps, + tp_degree=tp_degree, + padded_hidden_size=padded_q, + ) + self.k_norm = MiniMaxM2QKNorm( + k_per_rank, + eps=config.rms_norm_eps, + tp_degree=tp_degree, + padded_hidden_size=padded_k, + ) + + if not parallel_state.model_parallel_is_initialized(): + raise ValueError( + "NeuronMiniMaxM2Attention requires an initialized distributed environment. " + "Use neuronx_distributed to initialize." 
+ ) + + def prep_qkv_tensors( + self, + position_ids, + hidden_states, + past_key_value, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + skip_rope=False, + residual=None, + use_polar_compatible_rope=False, + ): + """Apply local QK norm on flat projection, reshape to heads, then partial RoPE.""" + Q, K, V, residual = self.get_qkv_proj()( + hidden_states=hidden_states, + rmsnorm=rmsnorm, + adapter_ids=adapter_ids, + residual=residual, + ) + + # QK norm on flat per-rank projection output BEFORE reshape (no all-reduce) + if self.use_minimax_qk_norm: + Q = self.q_norm(Q, self.rank_util) + K = self.k_norm(K, self.rank_util) + + bsz, q_len, _ = hidden_states.size() + if self.sequence_parallel_enabled: + q_len *= self.tensor_model_parallel_group.size() + + # Reshape to [B, S, num_heads, head_dim] then transpose to [B, H, S, D] + Q = ( + Q.view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + K = ( + K.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + V = ( + V.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + if not skip_rope: + Q, K, cos_cache, sin_cache = self.apply_rotary_embedding( + Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ) + + return Q, K, V, cos_cache, sin_cache, residual + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + """Apply partial rotary embeddings (first rotary_dim dimensions only).""" + from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + ) + + if not use_polar_compatible_rope and self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + if self.rotary_dim < self.head_dim: + Q_rot, Q_pass = Q[..., : self.rotary_dim], Q[..., self.rotary_dim :] + K_rot, K_pass = K[..., : 
self.rotary_dim], K[..., self.rotary_dim :] + Q_rot, K_rot = apply_rotary_pos_emb(Q_rot, K_rot, cos_cache, sin_cache) + Q = torch.cat([Q_rot, Q_pass], dim=-1) + K = torch.cat([K_rot, K_pass], dim=-1) + else: + Q, K = apply_rotary_pos_emb(Q, K, cos_cache, sin_cache) + + return Q, K, cos_cache, sin_cache + + def attention_block_tokengen_nki_kernel( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + active_mask=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + rotary_position_ids=None, + update_kv_per_layer=True, + active_block_table=None, + use_polar_compatible_rope=False, + ): + """ + Override base class to use nki-library attention_block_tkg kernel with + partial RoPE support (rotary_dim < head_dim). + + Uses the nki-library kernel instead of the compiler's private kernel. + QK norm is fused into the kernel via the flat QK RMSNorm feature, which + normalizes across all Q (or K) heads concatenated before head splitting. + """ + assert _HAS_NKILIB_ATTN_BLOCK, ( + "nki-library attention_block_tkg not available. " + "Install the nki-library fork with partial RoPE support." 
+ ) + + from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, + gather_from_tensor_model_parallel_region_with_dim, + reduce_scatter_to_tensor_model_parallel_region_with_dim, + ) + from neuronx_distributed_inference.modules.attention.attention_base import ( + EPDispatchOption, + get_data_parallel_attention_dp_group, + ) + # NKI 0.3.0: use kernel[lnc_int] instead of kernel[(nc(lnc),)] + + if ( + self.sequence_parallel_enabled + and self.tensor_model_parallel_group is not None + ): + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=self.tensor_model_parallel_group, + ) + + # Get shapes + bsz, s_tkg, h = hidden_states.shape + h_out = h // 2 if self.is_eagle3_draft else h + num_q_heads = self.num_heads + + # Prepare rmsnorm params + rmsnorm_enabled = rmsnorm is not None + W_gamma = rmsnorm.weight.data.unsqueeze(0) if rmsnorm is not None else None + + # Prepare RoPE params + rope_contiguous_layout = not use_polar_compatible_rope + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb( + hidden_states, rotary_position_ids + ) + # Take first half and reshape to [dim//2, batch_size, seq_len] + cos_cache = cos_cache[..., : cos_cache.shape[-1] // 2].permute(2, 0, 1) + sin_cache = sin_cache[..., : sin_cache.shape[-1] // 2].permute(2, 0, 1) + elif use_polar_compatible_rope: + from neuronx_distributed.modules.attention.utils import precompute_freqs_cis + + rotary_freqs = precompute_freqs_cis( + self.head_dim, + self.neuron_config.max_context_length * 2, + self.rope_theta, + self.use_scaled_rope, + device=hidden_states.device, + ) + rotary_freqs = rotary_freqs[position_ids] + cos_cache = rotary_freqs.cos().permute(2, 0, 1) + sin_cache = rotary_freqs.sin().permute(2, 0, 1) + else: + cos_cache = None + sin_cache = None + 
+ # Prepare attention mask: merge active_mask and transpose for kernel layout + attention_mask = attention_mask.expand(-1, num_q_heads, -1, -1) + expected_active_mask_shape = (bsz, 1, s_tkg, s_tkg) + if s_tkg == 1: + active_mask = torch.ones( + expected_active_mask_shape, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + else: + assert active_mask.shape == expected_active_mask_shape, ( + f"{active_mask.shape} != {expected_active_mask_shape}" + ) + active_mask = active_mask.expand(-1, num_q_heads, -1, -1) + attention_mask[:, :, :, -s_tkg:] = active_mask + # Transpose to [S_ctx, B, q_heads, S_tkg] for nki-library kernel + attention_mask = attention_mask.permute(3, 0, 1, 2) + + # Prepare KV cache + K_prior, V_prior = past_key_value[:2] + K_prior = K_prior.data + V_prior = V_prior.data + update_cache_in_kernel = ( + update_kv_per_layer and self.attn_block_tkg_nki_kernel_cache_update + ) + sink = ( + self.get_learned_sinks().data.unsqueeze(-1) + if self.learned_sinks_size is not None + else None + ) + kv_cache_update_idx = position_ids[:, :1].to(torch.int32) + + # Prepare output projection + W_out = self.get_o_proj().o_proj.weight.data + if self.o_bias: + W_out_bias = ( + self.get_o_proj().o_proj.bias.data / self.tp_degree + ).unsqueeze(0) + else: + W_out_bias = None + + # Prepare QKV projection + W_qkv = self.get_qkv_proj().Wqkv.weight.data + bias_qkv = ( + self.get_qkv_proj().Wqkv.bias.data.unsqueeze(0) if self.qkv_bias else None + ) + + grid = self.logical_nc_config + + # Prepare flat QK norm weights (per-rank slice via SPMD rank selection) + # The kernel expects [1, per_rank_width] weights for each of Q and K. 
+ flat_qk_norm_enabled = self.use_minimax_qk_norm + flat_qk_W_Q = None + flat_qk_W_K = None + if flat_qk_norm_enabled: + # Q norm: select per-rank slice from padded weight + q_norm_weight = self.q_norm.weight.data # [padded_q_hidden_size] + q_per_rank = self.q_norm.hidden_size + if self.q_norm.tp_degree > 1: + q_w_reshaped = q_norm_weight.view(self.q_norm.tp_degree, q_per_rank) + rank_index = self.rank_util.rank[:1] + flat_qk_W_Q = torch.index_select( + q_w_reshaped, 0, rank_index + ) # [1, q_per_rank] + else: + flat_qk_W_Q = q_norm_weight[:q_per_rank].unsqueeze(0) # [1, q_per_rank] + + # K norm: select per-rank slice from padded weight + k_norm_weight = self.k_norm.weight.data # [padded_k_hidden_size] + k_per_rank = self.k_norm.hidden_size + if self.k_norm.tp_degree > 1: + k_w_reshaped = k_norm_weight.view(self.k_norm.tp_degree, k_per_rank) + rank_index = self.rank_util.rank[:1] + flat_qk_W_K = torch.index_select( + k_w_reshaped, 0, rank_index + ) # [1, k_per_rank] + else: + flat_qk_W_K = k_norm_weight[:k_per_rank].unsqueeze(0) # [1, k_per_rank] + + attn_output, K, V = attention_block_tkg[grid]( + # -- input + X=hidden_states, + X_hidden_dim_actual=getattr(self.config, "original_hidden_size", None), + # -- rmsnorm X + rmsnorm_X_enabled=rmsnorm_enabled, + rmsnorm_X_eps=self.rms_norm_eps, + rmsnorm_X_gamma=W_gamma, + # -- qkv projections + W_qkv=W_qkv, + bias_qkv=bias_qkv, + quantization_type_qkv=NkilibQuantizationType.NONE, + weight_dequant_scale_qkv=None, + input_dequant_scale_qkv=None, + # -- Q/K processing: flat QK RMSNorm (before head split) + rmsnorm_QK_flat_enabled=flat_qk_norm_enabled, + rmsnorm_QK_flat_eps=self.rms_norm_eps if flat_qk_norm_enabled else 0.0, + rmsnorm_QK_flat_W_Q=flat_qk_W_Q, + rmsnorm_QK_flat_W_K=flat_qk_W_K, + # -- Q/K processing: per-head pre-RoPE RMSNorm (disabled) + rmsnorm_QK_pre_rope_enabled=False, + rmsnorm_QK_pre_rope_eps=0.0, + rmsnorm_QK_pre_rope_W_Q=None, + rmsnorm_QK_pre_rope_W_K=None, + # -- Q/K processing: RoPE with partial 
rotary_dim + cos=cos_cache, + sin=sin_cache, + rope_contiguous_layout=rope_contiguous_layout, + rotary_dim=self.rotary_dim, + # -- Q/K processing: post-RoPE RMSNorm (disabled) + rmsnorm_QK_post_rope_enabled=False, + rmsnorm_QK_post_rope_eps=0.0, + rmsnorm_QK_post_rope_W_Q=None, + rmsnorm_QK_post_rope_W_K=None, + # -- attention + K_cache_transposed=self.k_cache_transposed, + active_blocks_table=( + active_block_table.to(torch.uint32) + if active_block_table is not None + else None + ), + K_cache=K_prior, + V_cache=V_prior, + attention_mask=attention_mask, + sink=sink, + softmax_scale=None, + # -- KV cache update + update_cache=update_cache_in_kernel, + kv_cache_update_idx=kv_cache_update_idx, + # -- output projection + W_out=W_out, + bias_out=W_out_bias, + quantization_type_out=NkilibQuantizationType.NONE, + weight_dequant_scale_out=None, + input_dequant_scale_out=None, + transposed_out=False, + # -- output + out_in_sb=False, + ) + + # Reshape and reduce output + attn_output = attn_output.reshape((bsz, s_tkg, h_out)) + if self.sequence_parallel_enabled: + attn_output = reduce_scatter_to_sequence_parallel_region( + attn_output, 1, process_group=self.tensor_model_parallel_group + ) + else: + if self.ep_dispatch_cc_option == EPDispatchOption.AR_AG: + attn_output = reduce_from_tensor_model_parallel_region( + attn_output, process_group=self.tensor_model_parallel_group + ) + elif self.ep_dispatch_cc_option == EPDispatchOption.RS_AG: + attn_output = reduce_scatter_to_tensor_model_parallel_region_with_dim( + attn_output, + partition_dim=0, + process_group=self.tensor_model_parallel_group, + ) + elif self.ep_dispatch_cc_option == EPDispatchOption.AG_AR: + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, + gather_dim=0, + process_group=get_data_parallel_attention_dp_group(), + ) + else: + raise ValueError( + f"Unknown EPDispatchOption: {self.ep_dispatch_cc_option}" + ) + + # KV cache handling + if update_cache_in_kernel: + KV = past_key_value + 
        else:
            # KV cache was NOT updated inside the NKI kernel: reshape the
            # kernel's raw K/V outputs into the rank-4 [B, N, S, D] layout
            # expected by kv_cache_manager.update_kv_by_layer_id.
            # K from kernel: [head_dim, bsz, q_len] (dBS)
            # V from kernel: [bsz, q_len, head_dim] (BSd)
            # Target: [B, 1, S, D] (BNSd) or [B, 1, D, S] (BNdS) for transposed K
            K = K.permute(1, 0, 2) if self.k_cache_transposed else K.permute(1, 2, 0)
            K = K.unsqueeze(1)
            V = V.unsqueeze(1)
            KV = (K, V)

        return attn_output, KV, cos_cache, sin_cache


class NeuronMiniMaxM2DecoderLayer(nn.Module):
    """MiniMax-M2 decoder layer: attention + MoE with ModuleMarker wrappers."""

    def __init__(self, config: MiniMaxM2InferenceConfig, layer_idx: int):
        # NOTE(review): layer_idx is accepted but never stored or used here —
        # confirm whether any per-layer behavior was meant to depend on it.
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = NeuronMiniMaxM2Attention(config=config)
        # When enabled, the fused MoE NKI kernel absorbs the post-attention
        # layernorm (see the branch below and forward()).
        self.moe_fused_nki_kernel_enabled = getattr(
            config, "moe_fused_nki_kernel_enabled", False
        )

        self.input_layernorm = get_rmsnorm_cls()(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = get_rmsnorm_cls()(
            config.hidden_size, eps=config.rms_norm_eps
        )

        # Fused MoE kernel absorbs post-attention layernorm
        if self.moe_fused_nki_kernel_enabled:
            self.block_sparse_moe = initialize_minimax_m2_moe_module(
                config=config,
                rmsnorm=self.post_attention_layernorm,
                init_tkg_module=True,
            )
        else:
            self.block_sparse_moe = initialize_minimax_m2_moe_module(config=config)

        self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled
        self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled
        # Input RMSNorm is fused into the QKV kernel only when sequence
        # parallelism is disabled.
        self.qkv_kernel_fused_rmsnorm = not self.sequence_parallel_enabled

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, ...]:
        # Pre-norm residual around self-attention.
        residual = hidden_states
        hidden_states = ModuleMarkerStartWrapper()(hidden_states)

        # If the QKV kernel can fuse the input RMSNorm, hand the norm module
        # down to self_attn instead of applying it here.
        qkv_fused_rmsnorm = None
        if self.input_layernorm:
            if self.qkv_kernel_enabled and self.qkv_kernel_fused_rmsnorm:
                qkv_fused_rmsnorm = self.input_layernorm
            else:
                hidden_states = self.input_layernorm(hidden_states)

        # Self-attention
        hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            rmsnorm=qkv_fused_rmsnorm,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # MoE (fused-kernel path applies the post-attention layernorm itself)
        residual = hidden_states
        if not self.moe_fused_nki_kernel_enabled:
            hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.block_sparse_moe(hidden_states, padding_mask)[0]
        hidden_states = residual + hidden_states

        hidden_states = ModuleMarkerEndWrapper()(hidden_states)
        return (hidden_states, present_key_value, cos_cache, sin_cache, None)


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------


class NeuronMiniMaxM2Model(NeuronBaseModel):
    """Traceable MiniMax-M2 base model."""

    def setup_attr_for_model(self, config: MiniMaxM2InferenceConfig):
        # Attributes NeuronBaseModel expects to exist before init_model runs.
        self.on_device_sampling = (
            config.neuron_config.on_device_sampling_config is not None
        )
        self.tp_degree = config.neuron_config.tp_degree
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.max_batch_size = config.neuron_config.max_batch_size
        self.buckets = config.neuron_config.buckets

    def init_model(self, config: MiniMaxM2InferenceConfig):
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Vocab-sharded embedding table.
        self.embed_tokens = ParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            dtype=config.neuron_config.torch_dtype,
            shard_across_embedding=True,
        )
        self.layers = nn.ModuleList(
            [
                NeuronMiniMaxM2DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps)
        # lm_head output stays sharded when on-device sampling handles it.
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            gather_output=not self.on_device_sampling,
            bias=False,
        )


# ---------------------------------------------------------------------------
# CausalLM wrapper
# ---------------------------------------------------------------------------


class NeuronMiniMaxM2ForCausalLM(NeuronBaseForCausalLM):
    """MiniMax-M2 causal language model for NxDI inference."""

    _model_cls = NeuronMiniMaxM2Model

    @staticmethod
    def load_hf_model(model_path, **kwargs):
        # NOTE(review): deliberately returns None — no host-side HF model is
        # materialized; weights come from get_state_dict below. Confirm all
        # callers tolerate a None reference model.
        return None

    @classmethod
    def get_config_cls(cls):
        return MiniMaxM2InferenceConfig

    @staticmethod
    def convert_hf_to_neuron_state_dict(
        state_dict: dict, config: MiniMaxM2InferenceConfig
    ) -> dict:
        return convert_minimax_m2_hf_to_neuron_state_dict(state_dict, config)

    def enable_context_encoding(self):
        # Record which sub-model is being compiled so get_compiler_args can
        # select per-phase flags.
        self.compile_tag = CONTEXT_ENCODING_MODEL_TAG
        super().enable_context_encoding()

    def enable_token_generation(self):
        self.compile_tag = TOKEN_GENERATION_MODEL_TAG
        super().enable_token_generation()

    def get_compiler_args(self):
        """Compiler arguments tuned for MiniMax-M2 MoE.

        Uses -O1 by default. -O2 was tested but provides no scratchpad memory
        savings vs -O1 (identical 22 GB tensor allocation at 62 layers TP=32).
        """
        # NOTE(review): both branches assign "-O1", so this branch on
        # compile_tag is currently dead code — presumably kept as a hook for
        # future per-phase optimization levels; consider collapsing it.
        if self.compile_tag == TOKEN_GENERATION_MODEL_TAG:
            opt_level = "-O1"
        else:
            opt_level = "-O1"

        args = f"--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer {opt_level}"
        args += (
            " --tensorizer-options='--enable-ccop-compute-overlap "
            "--cc-pipeline-tiling-factor=2'"
        )
        args += " --auto-cast=none"
        args += " --internal-enable-dge-levels vector_dynamic_offsets"
        args += " --internal-hlo2tensorizer-options='--verify-hlo=true'"

        if self.neuron_config.scratchpad_page_size:
            args += (
                f" --hbm-scratchpad-page-size={self.neuron_config.scratchpad_page_size}"
            )

        if self.neuron_config.attn_block_tkg_nki_kernel_enabled:
            assert self.neuron_config.attn_block_tkg_nki_kernel_cascaded_attention, (
                "attn_block_tkg_nki_kernel_enabled requires attn_block_tkg_nki_kernel_cascaded_attention"
            )
            # NOTE(review): side effect in a getter — this mutates
            # neuron_config. Confirm get_compiler_args is guaranteed to run
            # before tracing/compilation reads pre_rope_rmsnorm.
            self.neuron_config.pre_rope_rmsnorm = True
            args += " --internal-max-instruction-limit=15000000"

        return args

    @classmethod
    def get_state_dict(cls, model_name_or_path: str, config: InferenceConfig) -> dict:
        """Load and convert state dict from a HuggingFace safetensors checkpoint."""
        import json
        import os

        from safetensors import safe_open

        if os.path.isdir(model_name_or_path):
            index_path = os.path.join(
                model_name_or_path, "model.safetensors.index.json"
            )
            if os.path.exists(index_path):
                # Sharded checkpoint: stream every shard listed in the index
                # into a single host-side dict.
                with open(index_path, "r") as f:
                    index = json.load(f)

                model_sd: Dict[str, Any] = {}
                shard_files = sorted(set(index["weight_map"].values()))
                for i, shard_file in enumerate(shard_files):
                    # Progress every 20 shards to keep logs readable.
                    if i % 20 == 0:
                        print(
                            f" Loading shard {i + 1}/{len(shard_files)}: {shard_file}"
                        )
                    shard_path = os.path.join(model_name_or_path, shard_file)
                    with safe_open(shard_path, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            model_sd[key] = f.get_tensor(key)

                print(
                    f" Loaded {len(model_sd)} parameters from {len(shard_files)} shards"
                )

                # Strip the "model." prefix from parameter names.
                for param_name in list(model_sd.keys()):
                    if param_name.startswith(cls._STATE_DICT_MODEL_PREFIX):
                        new_name = param_name.replace(
                            cls._STATE_DICT_MODEL_PREFIX,
                            cls._NEW_STATE_DICT_MODEL_PREFIX,
                            1,
                        )
                        model_sd[new_name] = model_sd.pop(param_name)

                model_sd = cls.convert_hf_to_neuron_state_dict(model_sd, config)

                if getattr(config, "tie_word_embeddings", False):
                    cls.update_state_dict_for_tied_weights(model_sd)

                if cls._FUSED_PREFIX:
                    for param_name in list(model_sd.keys()):
                        model_sd[f"{cls._FUSED_PREFIX}.{param_name}"] = model_sd.pop(
                            param_name
                        )

                return model_sd
            else:
                # Unsharded directory: fall back to the generic loader.
                # NOTE(review): this path skips the prefix-strip/convert steps
                # above — confirm that is intentional for single-file
                # checkpoints.
                from neuronx_distributed_inference.modules.checkpoint import (
                    load_state_dict,
                )

                return load_state_dict(model_name_or_path)
        else:
            return super().get_state_dict(model_name_or_path, config)
diff --git a/src/neuronx_distributed_inference/utils/constants.py b/src/neuronx_distributed_inference/utils/constants.py
index effca933..83186e50 100644
--- a/src/neuronx_distributed_inference/utils/constants.py
+++ b/src/neuronx_distributed_inference/utils/constants.py
@@ -16,6 +16,8 @@
 from neuronx_distributed_inference.models.pixtral.modeling_pixtral import NeuronPixtralForCausalLM
 from neuronx_distributed_inference.models.pixtral.modeling_pixtral_vision import NeuronPixtralForImageEncoding
 from neuronx_distributed_inference.models.gemma3.modeling_gemma3 import NeuronGemma3ForCausalLM
+from neuronx_distributed_inference.models.mimo_v2.modeling_mimo_v2 import NeuronMiMoV2ForCausalLM
+from neuronx_distributed_inference.models.minimax_m2.modeling_minimax_m2 import NeuronMiniMaxM2ForCausalLM
 
 END_TO_END_MODEL = "e2e_model"
 CONTEXT_ENCODING_MODEL = "context_encoding_model"
@@ -61,6 +63,8 @@
 "qwen3": {"causal-lm": NeuronQwen3ForCausalLM},
 "qwen3_moe": {"causal-lm": NeuronQwen3MoeForCausalLM},
 "gemma3": {"causal-lm": NeuronGemma3ForCausalLM},
+"mimo_v2": {"causal-lm": NeuronMiMoV2ForCausalLM},
+"minimax_m2": {"causal-lm": NeuronMiniMaxM2ForCausalLM},
 "qwen3_vl": {"causal-lm": NeuronQwen3VLForCausalLM, "image-encoding": NeuronQwen3VLForImageEncoding},
 }