Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/atom-vllm-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -703,4 +703,3 @@ jobs:
docker rmi "${tag}" || true
done
fi

2 changes: 0 additions & 2 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def rope_cache(self, q, k, v, qkv, position, fwd_ctx: ForwardContext):
elif use_triton_attn and self.rotary_emb is not None:
self.per_token_quant = False
k_scale = v_scale = self.kv_scale

q, k, k_cache, v_cache = fused_qk_rope_reshape_and_cache(
q,
k,
Expand Down Expand Up @@ -477,7 +476,6 @@ def prefill_attention(
window_size=sliding_window,
sink_ptr=self.sinks,
)

return o

def prefill_attention_triton(
Expand Down
12 changes: 11 additions & 1 deletion atom/models/minimax_m2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
maybe_prefix,
)
from atom.utils import envs
from atom.plugin.prepare import is_vllm
from atom.utils.decorators import support_torch_compile
from torch import nn
from transformers import PretrainedConfig
Expand Down Expand Up @@ -132,6 +133,7 @@ def __init__(
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
tp_size = self.tp_size
self.layer_num = layer_num

self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
Expand Down Expand Up @@ -202,6 +204,7 @@ def __init__(
layer_num=layer_num,
use_mla=False,
rotary_emb=self.rotary_emb,
prefix=f"{prefix}.attn",
)

@staticmethod
Expand Down Expand Up @@ -244,8 +247,14 @@ def forward(
k = (k * torch.rsqrt(k_var + self.rms_norm_eps) * self.k_norm.weight).to(
orig_dtype
)

# TODO: is_vllm will be removed after vllm plugin supporting q,k,v rather than qkv
if is_vllm():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this part of code and avoid adding too much vllm/sglang related codes in the ATOM core files?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, if is_vllm() will be deleted after vllm-atom using q,k,v rather than qkv.

qkv = torch.cat([q, k, v], dim=-1)
attn_output = self.attn(
query=q, key=k, value=v, positions=positions, q_scale=None, qkv=qkv
)

attn_output = self.attn(q, k, v, positions)
output = self.o_proj(attn_output)
return output

Expand Down Expand Up @@ -322,6 +331,7 @@ def forward(
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

hidden_states = self.block_sparse_moe(hidden_states)

return hidden_states, residual
Expand Down
14 changes: 12 additions & 2 deletions atom/plugin/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def rope_cache_plugin_mode(
attn_metadata = attention_metadata

use_triton_attn = self.sliding_window != -1 or self.head_dim != 128
# use_triton_attn = True
self.use_triton_attn = use_triton_attn

if (
Expand Down Expand Up @@ -162,10 +163,20 @@ def rope_cache_plugin_mode(
output_zeros=False,
)
else:
# for asm paged attention
asm_layout = True
if use_triton_attn:
asm_layout = False
if self.rotary_emb is not None:
assert position is not None
q, k = self.rotary_emb(position, q, k)
if self.q_norm is not None:
q = self.q_norm(q)
if self.k_norm is not None:
k = self.k_norm(k)
new_value_cache = new_value_cache.view(
num_blocks, num_kv_heads, head_size, block_size
)
if self.kv_cache_dtype == "fp8":
aiter.reshape_and_cache_with_pertoken_quant(
k,
Expand All @@ -175,7 +186,7 @@ def rope_cache_plugin_mode(
k_scale,
v_scale,
attn_metadata.slot_mapping,
asm_layout=True,
asm_layout=asm_layout,
)
else:
aiter.reshape_and_cache(
Expand Down Expand Up @@ -590,7 +601,6 @@ def forward_impl_plugin_mode(
if value is not None:
value = value[:num_actual_tokens]
output_actual_tokens = output[:num_actual_tokens]

# rope and cache flush fusion. ATOM always use shuffle layout for kv cache
result = self.rope_cache_plugin_mode(
q=query,
Expand Down
1 change: 1 addition & 0 deletions atom/plugin/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _generate_atom_config_from_vllm_config(config: Any) -> PluginConfig:

return Config(
model=vllm_model_config.model,
trust_remote_code=getattr(vllm_model_config, "trust_remote_code", False),
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=vllm_scheduler_config.max_num_seqs,
max_model_len=max_model_len,
Expand Down
2 changes: 2 additions & 0 deletions atom/plugin/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from atom.models.qwen3_moe import Qwen3MoeForCausalLM
from atom.models.glm4_moe import Glm4MoeForCausalLM
from atom.models.deepseek_v2 import DeepseekV3ForCausalLM
from atom.models.minimax_m2 import MiniMaxM2ForCausalLM
from atom.config import Config
from atom.plugin.prepare import is_vllm, is_sglang

Expand All @@ -14,6 +15,7 @@
"Qwen3MoeForCausalLM": Qwen3MoeForCausalLM,
"Glm4MoeForCausalLM": Glm4MoeForCausalLM,
"DeepseekV3ForCausalLM": DeepseekV3ForCausalLM,
"MiniMaxM2ForCausalLM": MiniMaxM2ForCausalLM,
}


Expand Down
1 change: 1 addition & 0 deletions atom/plugin/vllm/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5MoeForConditionalGeneration_",
"Qwen3_5ForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5ForConditionalGeneration_",
"KimiK25ForConditionalGeneration": "atom.plugin.vllm.models.kimi_k25:KimiK25ForConditionalGeneration_",
"MiniMaxM2ForCausalLM": "atom.models.minimax_m2:MiniMaxM2ForCausalLM",
}


Expand Down
1 change: 1 addition & 0 deletions atom/plugin/vllm/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"Qwen3_5ForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5:Qwen3_5MoeForConditionalGeneration",
"KimiK25ForConditionalGeneration": "atom.plugin.vllm.models.kimi_k25:KimiK25ForConditionalGeneration",
"MiniMaxM2ForCausalLM": ATOM_MOE_CAUSAL_LM_MODEL_WRAPPER,
}


Expand Down
78 changes: 78 additions & 0 deletions recipes/atom_vllm/MiniMax-M2.5.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# MiniMax-M2.5 with ATOM vLLM Plugin Backend

This recipe shows how to run `MiniMaxAI/MiniMax-M2.5` (HF `architectures[0]`: `MiniMaxM2ForCausalLM`, MoE + FP8 weights) with the ATOM vLLM plugin backend. For background on the plugin backend, see [ATOM vLLM Plugin Backend](../../docs/vllm_plugin_backend_guide.md).

The checkpoint uses custom modeling code; keep `--trust-remote-code` on the server command line.

## Step 1: Pull the OOT Docker

```bash
docker pull rocm/atom-dev:vllm-latest
```

## Step 2: Launch vLLM Server

The ATOM vLLM plugin backend keeps the standard vLLM CLI, server APIs, and general usage flow compatible with upstream vLLM. For general server options and API usage, refer to the [official vLLM documentation](https://docs.vllm.ai/en/latest/).

The following matches the internal benchmark entry (`--kv_cache_dtype fp8 -tp 2 --trust-remote-code` in `.github/benchmark/models.json`). On multi-GPU hosts, use tensor parallel size 2 or adjust to your topology.

```bash
vllm serve MiniMaxAI/MiniMax-M2.5 \
--host localhost \
--port 8000 \
--async-scheduling \
--tensor-parallel-size 2 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
--kv-cache-dtype fp8 \
--no-enable-prefix-caching \
--enforce-eager
```

> Note: `--enforce-eager` disables CUDA graph capture, so the `cudagraph_mode` setting passed via `--compilation-config` above has no effect while this flag is present. Keep `--enforce-eager` for debugging or first bring-up; remove it when you want CUDA graph performance.

Caveat: the upstream `config.json` may advertise MTP-related fields; the current ATOM `MiniMaxM2ForCausalLM` path targets the main transformer. If you hit load or shape errors around MTP modules, compare with native ATOM server behavior and upstream vLLM release notes.

## Step 3: Performance Benchmark

```bash
vllm bench serve \
--host localhost \
--port 8000 \
--model MiniMaxAI/MiniMax-M2.5 \
--dataset-name random \
--random-input-len 8000 \
--random-output-len 1000 \
--random-range-ratio 0.8 \
--max-concurrency 64 \
--num-prompts 640 \
--trust-remote-code \
--percentile-metrics ttft,tpot,itl,e2el
```

## Step 4: Accuracy Validation

Nightly OOT accuracy uses `gsm8k` with **3-shot** in `.github/scripts/atom_oot_test.sh` (same as other full-validation models). For a local check:

```bash
lm_eval --model local-completions \
--model_args model=MiniMaxAI/MiniMax-M2.5,base_url=http://localhost:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False,trust_remote_code=True \
--tasks gsm8k \
--num_fewshot 3 \
--output_path ./lm_eval_minimax_m25_gsm8k
```

Reference metric (tracking baseline for this model family; replace with your run output and keep the raw JSON path next to the table):

- Internal tracking: `accuracy_baseline` **0.9401** for `MiniMaxAI/MiniMax-M2.5` in `.github/benchmark/models_accuracy.json` (see `_baseline_note` there for HF card context).
- OOT gate: `accuracy_test_threshold` **0.92** on `exact_match,flexible-extract` (see `atom-vllm-oot-test.yaml` nightly matrix).

Example table shape after `lm_eval` (fill `Value` / `Stderr` from your console or `${output_path}` JSON):

```text
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 3|exact_match|↑ |0.9287|± |0.0071|
| | |strict-match | 3|exact_match|↑ |0.9272|± |0.0072|
```

Raw results JSON: `<path-to-lm_eval-output-*.json>` (from `--output_path` above).
Loading