From 46ea76fdedbd0550c4e769fb71a73ba4d646e018 Mon Sep 17 00:00:00 2001 From: perzhang Date: Sun, 12 Apr 2026 04:00:11 +0000 Subject: [PATCH] [feat](minimax): support minimax-2.5 in atom-vllm mode --- .github/workflows/atom-vllm-test.yaml | 1 - atom/model_ops/attention_mha.py | 2 - atom/models/minimax_m2.py | 12 ++++- atom/plugin/attention_mha.py | 14 ++++- atom/plugin/config.py | 1 + atom/plugin/register.py | 2 + atom/plugin/vllm/model_wrapper.py | 1 + atom/plugin/vllm/register.py | 1 + recipes/atom_vllm/MiniMax-M2.5.md | 78 +++++++++++++++++++++++++++ 9 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 recipes/atom_vllm/MiniMax-M2.5.md diff --git a/.github/workflows/atom-vllm-test.yaml b/.github/workflows/atom-vllm-test.yaml index 7ad6ce485..b8424898b 100644 --- a/.github/workflows/atom-vllm-test.yaml +++ b/.github/workflows/atom-vllm-test.yaml @@ -703,4 +703,3 @@ jobs: docker rmi "${tag}" || true done fi - diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 23a1bed01..6a4cb2067 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -172,7 +172,6 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext): elif use_triton_attn and self.rotary_emb is not None: self.per_token_quant = False k_scale = v_scale = self.kv_scale - q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache( q, k, @@ -477,7 +476,6 @@ def prefill_attention( window_size=sliding_window, sink_ptr=self.sinks, ) - return o def prefill_attention_triton( diff --git a/atom/models/minimax_m2.py b/atom/models/minimax_m2.py index f25712b01..0c99da7bf 100644 --- a/atom/models/minimax_m2.py +++ b/atom/models/minimax_m2.py @@ -22,6 +22,7 @@ maybe_prefix, ) from atom.utils import envs +from atom.plugin.prepare import is_vllm from atom.utils.decorators import support_torch_compile from torch import nn from transformers import PretrainedConfig @@ -132,6 +133,7 @@ def __init__( super().__init__() self.tp_size = 
get_tensor_model_parallel_world_size() tp_size = self.tp_size + self.layer_num = layer_num self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 @@ -202,6 +204,7 @@ def __init__( layer_num=layer_num, use_mla=False, rotary_emb=self.rotary_emb, + prefix=f"{prefix}.attn", ) @staticmethod @@ -244,8 +247,14 @@ def forward( k = (k * torch.rsqrt(k_var + self.rms_norm_eps) * self.k_norm.weight).to( orig_dtype ) + + # TODO: is_vllm will be removed after vllm plugin supporting q,k,v rather than qkv + if is_vllm(): + qkv = torch.cat([q, k, v], dim=-1) + attn_output = self.attn( + query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv + ) - attn_output = self.attn(q, k, v, positions) output = self.o_proj(attn_output) return output @@ -322,6 +331,7 @@ def forward( hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) return hidden_states, residual diff --git a/atom/plugin/attention_mha.py b/atom/plugin/attention_mha.py index f2121ac0a..9f53cecf8 100644 --- a/atom/plugin/attention_mha.py +++ b/atom/plugin/attention_mha.py @@ -100,6 +100,7 @@ def rope_cache_plugin_mode( attn_metadata = attention_metadata use_triton_attn = self.sliding_window != -1 or self.head_dim != 128 + # use_triton_attn = True self.use_triton_attn = use_triton_attn if ( @@ -162,10 +163,20 @@ def rope_cache_plugin_mode( output_zeros=False, ) else: + # for asm paged attention + asm_layout = True + if use_triton_attn: + asm_layout = False + if self.rotary_emb is not None: + assert position is not None + q, k = self.rotary_emb(position, q, k) if self.q_norm is not None: q = self.q_norm(q) if self.k_norm is not None: k = self.k_norm(k) + new_value_cache = new_value_cache.view( + num_blocks, num_kv_heads, head_size, block_size + ) if self.kv_cache_dtype == "fp8": aiter.reshape_and_cache_with_pertoken_quant( k, @@ -175,7 
+186,7 @@ def rope_cache_plugin_mode( k_scale, v_scale, attn_metadata.slot_mapping, - asm_layout=True, + asm_layout=asm_layout, ) else: aiter.reshape_and_cache( @@ -590,7 +601,6 @@ def forward_impl_plugin_mode( if value is not None: value = value[:num_actual_tokens] output_actual_tokens = output[:num_actual_tokens] - # rope and cache flush fusion. ATOM always use shuffle layout for kv cache result = self.rope_cache_plugin_mode( q=query, diff --git a/atom/plugin/config.py b/atom/plugin/config.py index 82aca61e2..d75161c04 100644 --- a/atom/plugin/config.py +++ b/atom/plugin/config.py @@ -83,6 +83,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig: return Config( model=vllm_model_config.model, + trust_remote_code=getattr(vllm_model_config, "trust_remote_code", False), max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=vllm_scheduler_config.max_num_seqs, max_model_len=max_model_len, diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 16853d6ef..8042055e7 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -4,6 +4,7 @@ from atom.models.qwen3_moe import Qwen3MoeForCausalLM from atom.models.glm4_moe import Glm4MoeForCausalLM from atom.models.deepseek_v2 import DeepseekV3ForCausalLM +from atom.models.minimax_m2 import MiniMaxM2ForCausalLM from atom.config import Config from atom.plugin.prepare import is_vllm, is_sglang @@ -14,6 +15,7 @@ "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM, "Glm4MoeForCausalLM": Glm4MoeForCausalLM, "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, + "MiniMaxM2ForCausalLM": MiniMaxM2ForCausalLM, } diff --git a/atom/plugin/vllm/model_wrapper.py b/atom/plugin/vllm/model_wrapper.py index 17112d103..0d38688b4 100644 --- a/atom/plugin/vllm/model_wrapper.py +++ b/atom/plugin/vllm/model_wrapper.py @@ -45,6 +45,7 @@ "Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5MoeForConditionalGeneration_", "Qwen3_5ForConditionalGeneration": 
"atom.models.qwen3_5:Qwen3_5ForConditionalGeneration_", "KimiK25ForConditionalGeneration": "atom.plugin.vllm.models.kimi_k25:KimiK25ForConditionalGeneration_", + "MiniMaxM2ForCausalLM": "atom.models.minimax_m2:MiniMaxM2ForCausalLM", } diff --git a/atom/plugin/vllm/register.py b/atom/plugin/vllm/register.py index 0bc9d7416..9a0468f18 100644 --- a/atom/plugin/vllm/register.py +++ b/atom/plugin/vllm/register.py @@ -30,6 +30,7 @@ "Qwen3_5ForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5MoeForConditionalGeneration", "KimiK25ForConditionalGeneration": "atom.plugin.vllm.models.kimi_k25:KimiK25ForConditionalGeneration", + "MiniMaxM2ForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER, } diff --git a/recipes/atom_vllm/MiniMax-M2.5.md b/recipes/atom_vllm/MiniMax-M2.5.md new file mode 100644 index 000000000..cbaf1d4b7 --- /dev/null +++ b/recipes/atom_vllm/MiniMax-M2.5.md @@ -0,0 +1,78 @@ +# MiniMax-M2.5 with ATOM vLLM Plugin Backend + +This recipe shows how to run `MiniMaxAI/MiniMax-M2.5` (HF `architectures[0]`: `MiniMaxM2ForCausalLM`, MoE + FP8 weights) with the ATOM vLLM plugin backend. For background on the plugin backend, see [ATOM vLLM Plugin Backend](../../docs/vllm_plugin_backend_guide.md). + +The checkpoint uses custom modeling code; keep `--trust-remote-code` on the server command line. + +## Step 1: Pull the OOT Docker + +```bash +docker pull rocm/atom-dev:vllm-latest +``` + +## Step 2: Launch vLLM Server + +The ATOM vLLM plugin backend keeps the standard vLLM CLI, server APIs, and general usage flow compatible with upstream vLLM. For general server options and API usage, refer to the [official vLLM documentation](https://docs.vllm.ai/en/latest/). + +The following matches the internal benchmark entry (`--kv_cache_dtype fp8 -tp 2 --trust-remote-code` in `.github/benchmark/models.json`). On multi-GPU hosts, use tensor parallel size 2 or adjust to your topology. 
+ +```bash +vllm serve MiniMaxAI/MiniMax-M2.5 \ + --host localhost \ + --port 8000 \ + --async-scheduling \ + --tensor-parallel-size 2 \ + --trust-remote-code \ + --gpu_memory_utilization 0.9 \ + --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \ + --kv-cache-dtype fp8 \ + --no-enable-prefix-caching \ + --enforce-eager +``` + +Caveat: the upstream `config.json` may advertise MTP-related fields; the current ATOM `MiniMaxM2ForCausalLM` path targets the main transformer. If you hit load or shape errors around MTP modules, compare with native ATOM server behavior and upstream vLLM release notes. + +## Step 3: Performance Benchmark + +```bash +vllm bench serve \ + --host localhost \ + --port 8000 \ + --model MiniMaxAI/MiniMax-M2.5 \ + --dataset-name random \ + --random-input-len 8000 \ + --random-output-len 1000 \ + --random-range-ratio 0.8 \ + --max-concurrency 64 \ + --num-prompts 640 \ + --trust_remote_code \ + --percentile-metrics ttft,tpot,itl,e2el +``` + +## Step 4: Accuracy Validation + +Nightly OOT accuracy uses `gsm8k` with **3-shot** in `.github/scripts/atom_oot_test.sh` (same as other full-validation models). For a local check: + +```bash +lm_eval --model local-completions \ + --model_args model=MiniMaxAI/MiniMax-M2.5,base_url=http://localhost:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False,trust_remote_code=True \ + --tasks gsm8k \ + --num_fewshot 3 \ + --output_path ./lm_eval_minimax_m25_gsm8k +``` + +Reference metric (tracking baseline for this model family; replace with your run output and keep the raw JSON path next to the table): + +- Internal tracking: `accuracy_baseline` **0.9401** for `MiniMaxAI/MiniMax-M2.5` in `.github/benchmark/models_accuracy.json` (see `_baseline_note` there for HF card context). +- OOT gate: `accuracy_test_threshold` **0.92** on `exact_match,flexible-extract` (see `atom-vllm-oot-test.yaml` nightly matrix). 
+ +Example table shape after `lm_eval` (fill `Value` / `Stderr` from your console or `${output_path}` JSON): + +```text +|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr| +|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:| +|gsm8k| 3|flexible-extract| 3|exact_match|↑ |0.9287|± |0.0071| +| | |strict-match | 3|exact_match|↑ |0.9272|± |0.0072| +``` + +Raw results JSON: `<path-to-results.json>` — replace with the file written under the `--output_path` directory above.