
feat: add Gemma4 31B support for standalone and vLLM plugin mode on MI355X#546

Open
ClementLinCF wants to merge 18 commits into main from gemma4-dev

Conversation

@ClementLinCF

@ClementLinCF ClementLinCF commented Apr 12, 2026

Motivation

Add Gemma 4 31B-it support for ATOM on AMD Instinct MI355X (gfx950 / CDNA4), covering both Config A (ATOM standalone server) and Config B (vLLM + ATOM OOT plugin).
Gemma 4 introduces architectural features not present in other supported models: heterogeneous attention layers with different head_dim (sliding=256, global=512), per-layer-type RoPE configurations, and a head_dim > 256 that exceeds the CK Flash Attention limit. This PR handles all of these while maintaining backward compatibility with existing models (LLaMA, DeepSeek, Kimi-K2.5, Qwen, etc.).

Technical Details

New: Gemma 4 model implementation

  • atom/models/gemma4.py (+620 lines): Full inference-only Gemma 4 text backbone with AITER-optimized operators — AITER RMSNorm, ProportionalRotaryEmbedding (from aiter), Gluon Paged Attention, fused logit softcapping, CUDA graph support via @support_torch_compile.

Heterogeneous head_dim support

  • atom/model_engine/model_runner.py: New _get_per_layer_kv_dims() method returns per-layer (num_kv_heads, head_dim) tuples for models with mixed sliding/global attention. KV cache allocation, block size computation, and cache binding all branch on this to handle layers with different dimensions. Original homogeneous path preserved in else branches.
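A minimal sketch of the per-layer KV-dim idea (names and the `layer_types` convention are illustrative stand-ins, not ATOM's actual code): each decoder layer gets its own `(num_kv_heads, head_dim)` tuple depending on whether it is a sliding-window or global-attention layer.

```python
# Hypothetical sketch: derive per-layer KV-cache dimensions for a model that
# mixes sliding-window (head_dim=256) and global (head_dim=512) attention.
def get_per_layer_kv_dims(layer_types, num_kv_heads,
                          sliding_head_dim=256, global_head_dim=512):
    """Return one (num_kv_heads, head_dim) tuple per decoder layer."""
    dims = []
    for layer_type in layer_types:
        head_dim = (sliding_head_dim if layer_type == "sliding_attention"
                    else global_head_dim)
        dims.append((num_kv_heads, head_dim))
    return dims

# Example: four layers alternating sliding / global attention
dims = get_per_layer_kv_dims(
    ["sliding_attention", "full_attention"] * 2, num_kv_heads=8)
```

Downstream KV-cache allocation can then size each layer's blocks from its own tuple instead of a single model-global head_dim.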

CK Flash Attention head_dim > 256 fallback

  • atom/model_ops/attention_mha.py: Added _sdpa_varlen_fallback() — PyTorch SDPA fallback with GQA head expansion for prefill when head_dim > 256 (CK limitation). Original CK path unchanged for head_dim <= 256.

  • atom/plugin/attention_mha.py: Added _flash_attn_varlen_with_fallback() and _triton_flash_attn_varlen() for the plugin extend/prefill paths. Automatic dispatch based on head_dim. Also added qkv is not None null checks for Gemma 4 code paths.
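The dispatch logic can be sketched as follows (function and backend names are assumptions, not the plugin's actual identifiers): CK handles head_dim up to its 256 limit, with Triton FA2 as the first fallback and PyTorch SDPA as the last resort.

```python
# Illustrative dispatcher for the prefill/extend varlen attention path.
CK_MAX_HEAD_DIM = 256  # hard limit of the CK varlen flash attention kernel

def pick_varlen_backend(head_dim, triton_available=True):
    if head_dim <= CK_MAX_HEAD_DIM:
        return "ck_flash_attn"   # fast path, unchanged for existing models
    if triton_available:
        return "triton_fa2"      # first fallback for head_dim > 256
    return "torch_sdpa"          # last-resort fallback

# Gemma 4: sliding layers (256) stay on CK, global layers (512) fall back
backend = pick_varlen_backend(512)
```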

Plugin mode (vLLM) support

  • atom/plugin/attention.py: Added swa_head_dim tracking for heterogeneous sliding window attention workspace allocation. Rebased plugin onto vLLM v1 AttentionBackend for API compatibility.

  • atom/plugin/vllm/register.py and atom/plugin/vllm/model_wrapper.py: Registered Gemma4ForConditionalGeneration in both model registries (+1 line each).

Multimodal config support

  • atom/config.py: Added "gemma4": "text_config" to _MULTIMODAL_MODEL_TYPES and get_hf_text_config() extraction in Config.init so engine code can access num_hidden_layers, head_dim, rope_theta, etc. directly. Removed _resolve_atom_text_config() fallback now that HF transformers natively supports Gemma 4.
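A hedged sketch of the text-config extraction described above; the mapping and helper mirror `_MULTIMODAL_MODEL_TYPES` / `get_hf_text_config` in spirit but are simplified stand-ins rather than the real implementation.

```python
from types import SimpleNamespace

# Map multimodal model types to the attribute holding their text sub-config.
_MULTIMODAL_MODEL_TYPES = {"gemma4": "text_config"}

def get_hf_text_config(hf_config):
    """Return the nested text config for multimodal models, else the config itself."""
    attr = _MULTIMODAL_MODEL_TYPES.get(getattr(hf_config, "model_type", ""))
    if attr and hasattr(hf_config, attr):
        return getattr(hf_config, attr)
    return hf_config

# Example with simple stand-in config objects
text = SimpleNamespace(model_type="gemma4_text", num_hidden_layers=62)
mm = SimpleNamespace(model_type="gemma4", text_config=text)
cfg = get_hf_text_config(mm)
```

After extraction, engine code can read `cfg.num_hidden_layers` and friends directly without knowing the model is multimodal.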

Dependency

  • aiter gemma4-dev: ProportionalRotaryEmbedding for Gemma 4 RoPE, rmsnorm2d_fwd weight dtype check fix.
  • HF transformers >= 5.5.3: Native Gemma4Config / Gemma4TextConfig support.

Test Plan

ATOM Plugin:

```shell
vllm serve "$MODEL" \
    --host 0.0.0.0 --port "$PORT" \
    --tensor-parallel-size "$TP" \
    --trust-remote-code \
    --max-model-len 4096 \
    > /tmp/vllm_plugin_server.log 2>&1 &
```

ATOM standalone:

```shell
python -m atom.entrypoints.openai_server \
    --model "$MODEL" \
    --tensor-parallel-size "$TP" \
    --max-model-len 4096 \
    $EAGER_FLAG \
    > /tmp/atom_server.log 2>&1 &
```

Test Result

Concurrency = 2 (low concurrency)

| Metric | Config A (Standalone) | Config B (vLLM Plugin) |
|---|---|---|
| Output tok/s | 169.10 | 136.97 |
| Mean TTFT | 58.89 ms | 178.64 ms |
| Median TTFT | 53.57 ms | 89.83 ms |
| P99 TTFT | 98.28 ms | 1537.27 ms |
| Mean TPOT | 11.46 ms | 13.30 ms |
| Median TPOT | 11.46 ms | 13.02 ms |
| P99 TPOT | 11.61 ms | 19.53 ms |
| Mean E2EL | 1513.80 ms | 1867.69 ms |
| Failed | 0 | 0 |

Concurrency = 40 (high concurrency)

| Metric | Config A (Standalone) | Config B (vLLM Plugin) |
|---|---|---|
| Output tok/s | 2018.86 | 1681.18 |
| Mean TTFT | 645.75 ms | 1066.13 ms |
| Median TTFT | 705.84 ms | 304.08 ms |
| P99 TTFT | 727.92 ms | 2045.75 ms |
| Mean TPOT | 14.85 ms | 15.50 ms |
| Median TPOT | 14.29 ms | 15.46 ms |
| P99 TPOT | 17.53 ms | 29.84 ms |
| Mean E2EL | 2532.09 ms | 3035.50 ms |
| Failed | 0 | 0 |

Submission Checklist

ClementLinCF and others added 16 commits April 7, 2026 22:33
- Rebase AiterBackend onto vLLM v1 AttentionBackend so plugin passes
  the v1 attention selector's method checks
- Add Triton FA2 / PyTorch SDPA fallback for prefill when head_dim > 256
  (CK flash_attn_varlen_func limitation)
- Guard qkv and position slicing against None values
Gemma 4 uses head_dim=256 for sliding window layers and head_dim=512
for full attention layers. Two issues fixed:

1. swa_workspace was allocated with model-global head_dim (512) from
   get_head_size(), causing assertion failure in cp_mha_gather_cache
   when the KV cache for sliding window layers has head_dim=256.
   Fix: derive swa_head_dim from actual sliding window layer impl.

2. CK flash attention kernel hard-limits head_dim to 256, crashing
   extend_forward for full attention layers under high concurrency.
   Fix: fallback to torch SDPA with GQA head repeating + LSE
   computation when head_dim > 256.
Enable Gemma 4 31B-IT to run in ATOM standalone (Config A) on MI355X
with TP=8. This required several model-specific adaptations:

- New model files: atom/models/gemma4.py (Attention, MLP, DecoderLayer,
  Model, ForCausalLM) and atom/model_config/gemma4.py (Gemma4TextConfig)
- RMSNorm: use standard x*weight formula (not Gemma1/2 x*(1+weight))
- v_norm: add missing value normalization (RMSNorm without scale)
- k_eq_v: apply v_norm to raw k before k_norm for global attention
- layer_scalar: apply once at end of layer (not twice to sub-blocks)
- Attention scaling: use 1.0 (not 1/sqrt(head_dim)) since q/k norms
  already control magnitudes
- Config: resolve gemma4 multimodal config and gemma4_text sub-config
- KV cache: per-layer allocation for heterogeneous head_dim (512 global
  vs 256 sliding)
- Prefill attention: SDPA fallback for head_dim > 256 (CK limitation)
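The RMSNorm point above is worth making concrete. A pure-Python reference (illustrative, not the AITER kernel) contrasting the Gemma 1/2 convention, `x * (1 + weight)`, with the standard `x * weight` formula this port uses:

```python
import math

def rmsnorm(x, weight, eps=1e-6, gemma_legacy=False):
    """RMS-normalize x; gemma_legacy=True applies the Gemma 1/2 (1 + weight) scale."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    scale = [(1.0 + w) if gemma_legacy else w for w in weight]
    return [xi / rms * si for xi, si in zip(x, scale)]

# With weight = 1.0 everywhere, the legacy formula scales outputs by 2x,
# which is why applying the wrong convention silently corrupts activations.
std = rmsnorm([3.0, 4.0], [1.0, 1.0])
legacy = rmsnorm([3.0, 4.0], [1.0, 1.0], gemma_legacy=True)
```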
_Gemma4RMSNorm now dispatches to aiter.rmsnorm2d_fwd when with_scale=True,
falling back to pure PyTorch for with_scale=False (v_norm). This is enabled
by the upstream aiter rmsnorm dtype fix (5df37c1) which auto-casts weight
to input dtype, preventing the silent data corruption that previously
blocked AITER kernel usage.

With this change, --enforce-eager is no longer required — CUDA graph
capture works correctly and provides ~2x decode throughput improvement.
- Extract _sdpa_varlen_attn as single shared SDPA varlen fallback
  (replaces 3 near-identical copies across plugin and model_ops)
- Log warning when Triton FA2 fails and falls back to SDPA
- Log warning when vLLM v1 AttentionBackend import fails in
  AiterBackendDecoratorForPluginMode
- Clean up commit hash reference in _Gemma4RMSNorm docstring
Set accept_output_buffer=True in vllmAiterAttentionBackendMethods so
vLLM pre-allocates the output buffer at a fixed address for CUDA graph
replay. Pass the buffer through forward() → forward_impl_plugin_mode()
and use it when available instead of torch.empty().

Root cause: PIECEWISE CUDA graph pieces read output from the address
captured during warmup, but ATOM allocated a new tensor each call →
stale data → garbage output.

Three files changed:
- atom/plugin/attention.py: accept_output_buffer = True
- atom/model_ops/attention_mha.py: pass output= to plugin forward
- atom/plugin/attention_mha.py: conditional output buffer allocation
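The buffer contract can be illustrated in miniature (a list stands in for a tensor; names are hypothetical): when the caller supplies a pre-allocated output buffer, the forward writes into it in place so CUDA graph replay always reads from the same address, and only allocates when no buffer is given.

```python
def attention_forward(values, output=None):
    """Fill a caller-provided output buffer in place; allocate only as fallback."""
    if output is None:
        output = [0.0] * len(values)  # old behavior: fresh allocation each call
    for i, v in enumerate(values):
        output[i] = v                 # write into the caller's fixed buffer
    return output

buf = [0.0, 0.0, 0.0]
result = attention_forward([1.0, 2.0, 3.0], output=buf)
# result is buf: the same object, so its address is stable across replays
```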
…upport

HF transformers >= 5.5.3 natively supports Gemma4Config/Gemma4TextConfig.
Remove the ATOM-side fallback config and use HF directly.

- Delete atom/model_config/gemma4.py (226 lines, no longer needed)
- Remove _resolve_atom_text_config() try/except fallback in config.py
- Update atom/models/gemma4.py to import from transformers.models.gemma4
Re-add comments that were inadvertently removed during the Gemma 4
heterogeneous attention support (bf992a1), and add a new comment
explaining the per-layer branch for clarity.
Planning/documentation file that is not imported by any runtime code.

Made-with: Cursor
Incomplete experimental kernel, not imported by any runtime code.

Made-with: Cursor
Copilot AI review requested due to automatic review settings April 12, 2026 12:30
from aiter.rotary_embedding import get_rope

from atom.config import Config, QuantizationConfig
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig
Contributor


⚠️ [ruff] <F401> reported by reviewdog 🐶
transformers.models.gemma4.configuration_gemma4.Gemma4Config imported but unused

Suggested change
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig

Comment on lines +123 to +124
import torch.nn.functional as F

Contributor


⚠️ [ruff] <F401> reported by reviewdog 🐶
torch.nn.functional imported but unused

Suggested change
import torch.nn.functional as F

Contributor

Copilot AI left a comment


Pull request overview

Adds Gemma 4 31B-it support to ATOM on MI355X across both standalone server mode and vLLM plugin mode, including required handling for Gemma4’s heterogeneous attention head dimensions and FlashAttention head-dim limits.

Changes:

  • Introduces an inference-only Gemma4 text backbone implementation and registers it for vLLM plugin usage.
  • Extends KV-cache sizing/allocation/binding to support per-layer (num_kv_heads, head_dim) heterogeneity.
  • Adds attention fallbacks for head_dim > 256 (CK limitation) in both standalone and plugin attention paths, plus multimodal text-config extraction.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
atom/plugin/vllm/register.py Registers Gemma4 arch to use the ATOM causal LM wrapper in vLLM plugin mode.
atom/plugin/vllm/model_wrapper.py Maps Gemma4 HF architecture name to the ATOM Gemma4 implementation class.
atom/plugin/attention.py Tracks SWA head-dim for workspace sizing; attempts to rebase backend class onto vLLM v1 AttentionBackend.
atom/plugin/attention_mha.py Adds CK→(Triton/SDPA) fallback paths for varlen attention when head_dim > 256, plus some null-guarding.
atom/models/gemma4.py New Gemma4 inference-only model implementation with AITER-optimized components.
atom/model_ops/attention_mha.py Adds SDPA fallback for standalone prefill varlen attention when head_dim > 256.
atom/model_engine/model_runner.py Adds per-layer KV-dim discovery and uses it for heterogeneous KV cache allocation/binding and memory estimation.
atom/config.py Resolves multimodal configs to text_config early so engine code can directly access text attributes.


Comment on lines +1388 to +1400
# Rebase onto vLLM v1 AttentionBackend so cls inherits default method
# stubs (get_preferred_block_size, validate_configuration, supports_*,
# etc.) that the vLLM v1 attention selector requires.
try:
from vllm.v1.attention.backend import (
AttentionBackend as _VllmAttnBackend,
)
cls.__bases__ = (_VllmAttnBackend,)
except ImportError:
logger.warning(
"vllm.v1.attention.backend not found; %s will not inherit "
"vLLM v1 AttentionBackend stubs — attention selector may "
"reject this backend.", cls.__name__)

Copilot AI Apr 12, 2026


cls.__bases__ = (_VllmAttnBackend,) can raise TypeError (e.g., incompatible layout / metaclass) but the code only catches ImportError, which would crash plugin initialization at import time. Consider avoiding runtime base-class mutation (e.g., create/return a new subclass with the desired bases) or at least catch Exception/TypeError and fall back gracefully.

Suggested change
# Rebase onto vLLM v1 AttentionBackend so cls inherits default method
# stubs (get_preferred_block_size, validate_configuration, supports_*,
# etc.) that the vLLM v1 attention selector requires.
try:
from vllm.v1.attention.backend import (
AttentionBackend as _VllmAttnBackend,
)
cls.__bases__ = (_VllmAttnBackend,)
except ImportError:
logger.warning(
"vllm.v1.attention.backend not found; %s will not inherit "
"vLLM v1 AttentionBackend stubs — attention selector may "
"reject this backend.", cls.__name__)
# Create a subclass that also inherits the vLLM v1 AttentionBackend
# so cls gets the default method stubs (get_preferred_block_size,
# validate_configuration, supports_*, etc.) required by the vLLM v1
# attention selector. Avoid mutating __bases__ at runtime because it
# can raise TypeError for incompatible layouts/metaclasses.
try:
from vllm.v1.attention.backend import (
AttentionBackend as _VllmAttnBackend,
)
if not issubclass(cls, _VllmAttnBackend):
cls = type(
cls.__name__,
(cls, _VllmAttnBackend),
{
"__module__": cls.__module__,
"__qualname__": cls.__qualname__,
"__doc__": cls.__doc__,
},
)
except ImportError:
logger.warning(
"vllm.v1.attention.backend not found; %s will not inherit "
"vLLM v1 AttentionBackend stubs — attention selector may "
"reject this backend.", cls.__name__)
except TypeError as exc:
logger.warning(
"Unable to create vLLM AttentionBackend-compatible subclass "
"for %s; keeping original class without vLLM v1 stubs: %s",
cls.__name__,
exc,
)

Comment on lines +961 to +964
ws = self.world_size
swa_kv = max(hf_config.num_key_value_heads // ws, 1)
glo_kv = max(num_global_kv_heads // ws, 1)
dims = []

Copilot AI Apr 12, 2026


_get_per_layer_kv_dims() computes per-rank KV heads via max(num_*_kv_heads // ws, 1) without the divisibility checks used in _get_num_kv_heads(). For unsupported TP sizes this can silently under-allocate KV cache and fail later in harder-to-debug ways; consider mirroring the same >= ws / % ws == 0 assertions for both sliding and global KV heads.

ki = k[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
vi = v[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
num_q_heads, num_kv_heads = qi.shape[1], ki.shape[1]
if num_q_heads != num_kv_heads:

Copilot AI Apr 12, 2026


In the GQA expansion path, rep = num_q_heads // num_kv_heads is used without asserting num_q_heads % num_kv_heads == 0. If they’re not divisible this will either produce incorrect head replication or trigger shape errors in SDPA; add an explicit divisibility assertion (or handle the remainder) before repeating heads.

Suggested change
if num_q_heads != num_kv_heads:
if num_q_heads != num_kv_heads:
if num_q_heads % num_kv_heads != 0:
raise ValueError(
f"GQA head expansion requires num_q_heads to be divisible "
f"by num_kv_heads, got num_q_heads={num_q_heads} and "
f"num_kv_heads={num_kv_heads}"
)

1. Quark exclude list name remapping: Quark uses `model.language_model.*`
   naming but ATOM uses `model.*`. Added `quant_exclude_name_mapping` to
   Gemma4ForCausalLM and passed it through model_wrapper.py so self_attn
   layers are correctly identified as excluded (BF16).

2. ASM FP4 GEMM K-dimension alignment: The ASM kernel
   `f4gemm_bf16_per1x32Fp4_BpreShuffle` requires K_packed to be a multiple
   of 128. At TP=8, down_proj has K_packed=1344 (not aligned). Added weight
   K padding in process_weights_after_loading and runtime input K padding
   in forward to align to 128.

3. Triton preshuffle M-dimension padding: When M < 32, pad input x and
   x_scale to prevent out-of-bounds reads in the preshuffle GEMM path.
k_pad = self._fp4_k_padded
elem_size = x.element_size()
fp4_elem_size = 1
x_logical_k = x.shape[-1] * elem_size // fp4_elem_size
Contributor


⚠️ [ruff] <F841> reported by reviewdog 🐶
Local variable x_logical_k is assigned to but never used

Suggested change
x_logical_k = x.shape[-1] * elem_size // fp4_elem_size
x.shape[-1] * elem_size // fp4_elem_size

Comment on lines +10 to +11
import torch.nn.functional as F
from torch import nn
Contributor


⚠️ [ruff] <F401> reported by reviewdog 🐶
torch.nn.functional imported but unused

Suggested change
import torch.nn.functional as F
from torch import nn
from torch import nn

Use aiter's rmsnorm2d_fwd_with_add to fuse the residual addition and
RMSNorm into a single kernel, eliminating one memory round-trip per
fusion point. Two fusion points per layer (input_layernorm and
pre_feedforward_layernorm) × 60 layers = 120 fewer kernel launches.
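A pure-Python reference of the fused pattern (sketch only; the real `rmsnorm2d_fwd_with_add` operates on tensors in a single kernel): the residual add and the RMSNorm happen in one pass, and the post-add activations are returned as the new residual.

```python
import math

def fused_add_rmsnorm(x, residual, weight, eps=1e-6):
    """Residual add + RMSNorm in one step; returns (normed, new_residual)."""
    h = [a + b for a, b in zip(x, residual)]  # residual add, fused in-kernel
    rms = math.sqrt(sum(v * v for v in h) / len(h) + eps)
    normed = [hi / rms * w for hi, w in zip(h, weight)]
    return normed, h

normed, new_residual = fused_add_rmsnorm([1.0, 2.0], [0.5, -0.5], [1.0, 1.0])
```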
Copilot AI review requested due to automatic review settings April 13, 2026 14:47
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.



Comment on lines +167 to +168
xs_padded = torch.zeros(m_padded, x_scale.shape[-1], dtype=torch.uint8, device=x.device)
xs_padded[:m] = x_scale

Copilot AI Apr 13, 2026


In the gemm_afp4wfp4_preshuffle path, the new padding logic reshapes x_scale to a block-major shape (view(x_scale.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1)) and then pads it as if it were token-major (xs_padded = torch.zeros(m_padded, ...) and xs_padded[:m] = x_scale). When m is not a multiple of MXFP4_QUANT_BLOCK_SIZE, this will raise a shape mismatch at runtime and/or feed incorrectly-shaped scales into the preshuffle kernel. Consider padding in the same representation expected by the kernel (e.g., pad before reshaping, or pad the block dimension to ceil(m/BS)), and avoid mixing token- vs block-shaped scales.

Suggested change
xs_padded = torch.zeros(m_padded, x_scale.shape[-1], dtype=torch.uint8, device=x.device)
xs_padded[:m] = x_scale
if m >= MXFP4_QUANT_BLOCK_SIZE:
scale_rows_padded = m_padded // MXFP4_QUANT_BLOCK_SIZE
xs_padded = torch.zeros(
scale_rows_padded, x_scale.shape[-1], dtype=torch.uint8, device=x.device
)
xs_padded[: x_scale.shape[0]] = x_scale
else:
xs_padded = torch.zeros(
m_padded, x_scale.shape[-1], dtype=torch.uint8, device=x.device
)
xs_padded[:m] = x_scale

k_pad = self._fp4_k_padded
elem_size = x.element_size()
fp4_elem_size = 1
x_logical_k = x.shape[-1] * elem_size // fp4_elem_size

Copilot AI Apr 13, 2026


x_logical_k is computed but never used. If it’s not needed for subsequent padding logic, remove it to avoid confusion; if it is needed, incorporate it into the computation (or add an assertion/comment explaining the intent).

Suggested change
x_logical_k = x.shape[-1] * elem_size // fp4_elem_size

Comment on lines +1388 to +1395
# Rebase onto vLLM v1 AttentionBackend so cls inherits default method
# stubs (get_preferred_block_size, validate_configuration, supports_*,
# etc.) that the vLLM v1 attention selector requires.
try:
from vllm.v1.attention.backend import (
AttentionBackend as _VllmAttnBackend,
)
cls.__bases__ = (_VllmAttnBackend,)

Copilot AI Apr 13, 2026


Mutating cls.__bases__ at runtime is fragile (can raise TypeError if base layouts are incompatible) and it also drops the original ATOM AttentionBackend base class from the MRO. It would be safer to define a small adapter subclass that inherits from vllm.v1.attention.backend.AttentionBackend (or uses multiple inheritance) rather than rebasing an existing class in-place.

Suggested change
# Rebase onto vLLM v1 AttentionBackend so cls inherits default method
# stubs (get_preferred_block_size, validate_configuration, supports_*,
# etc.) that the vLLM v1 attention selector requires.
try:
from vllm.v1.attention.backend import (
AttentionBackend as _VllmAttnBackend,
)
cls.__bases__ = (_VllmAttnBackend,)
# Return a small adapter subclass instead of rebasing cls in-place.
# This preserves the original ATOM backend in the MRO while also
# inheriting the default vLLM v1 AttentionBackend method stubs
# (get_preferred_block_size, validate_configuration, supports_*,
# etc.) required by the vLLM v1 attention selector.
try:
from vllm.v1.attention.backend import (
AttentionBackend as _VllmAttnBackend,
)
if not issubclass(cls, _VllmAttnBackend):
cls = type(
cls.__name__,
(cls, _VllmAttnBackend),
{
"__module__": cls.__module__,
"__doc__": cls.__doc__,
"__qualname__": cls.__qualname__,
},
)

Comment on lines +53 to +76
for i in range(num_seqs):
sq_s, sq_e = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
sk_s, sk_e = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
qi = q[sq_s:sq_e].transpose(0, 1).unsqueeze(0)
ki = k[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
vi = v[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
num_q_heads, num_kv_heads = qi.shape[1], ki.shape[1]
if num_q_heads != num_kv_heads:
rep = num_q_heads // num_kv_heads
ki = ki.repeat_interleave(rep, dim=1)
vi = vi.repeat_interleave(rep, dim=1)
oi = F.scaled_dot_product_attention(
qi, ki, vi, scale=softmax_scale, is_causal=causal,
)
out[sq_s:sq_e] = oi.squeeze(0).transpose(0, 1)

if return_lse:
scores = torch.matmul(qi * softmax_scale, ki.transpose(-1, -2))
if causal and (sq_e - sq_s) == (sk_e - sk_s):
mask = torch.triu(
torch.full_like(scores, float("-inf")), diagonal=1)
scores = scores + mask
lse[:, sq_s:sq_e] = torch.logsumexp(scores, dim=-1).squeeze(0)


Copilot AI Apr 13, 2026


_sdpa_varlen_attn iterates over sequences in Python and calls scaled_dot_product_attention per-sequence (plus an extra matmul/logsumexp when return_lse=True). For large batches (many sequences) this fallback will be extremely slow and can dominate prefill/extend latency when head_dim > 256. Consider a more vectorized fallback (e.g., pad to max_seqlen and run SDPA in a single call, or use a compiled/kernel varlen implementation when available) to keep the fallback usable in production.

Suggested change
for i in range(num_seqs):
sq_s, sq_e = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
sk_s, sk_e = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
qi = q[sq_s:sq_e].transpose(0, 1).unsqueeze(0)
ki = k[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
vi = v[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
num_q_heads, num_kv_heads = qi.shape[1], ki.shape[1]
if num_q_heads != num_kv_heads:
rep = num_q_heads // num_kv_heads
ki = ki.repeat_interleave(rep, dim=1)
vi = vi.repeat_interleave(rep, dim=1)
oi = F.scaled_dot_product_attention(
qi, ki, vi, scale=softmax_scale, is_causal=causal,
)
out[sq_s:sq_e] = oi.squeeze(0).transpose(0, 1)
if return_lse:
scores = torch.matmul(qi * softmax_scale, ki.transpose(-1, -2))
if causal and (sq_e - sq_s) == (sk_e - sk_s):
mask = torch.triu(
torch.full_like(scores, float("-inf")), diagonal=1)
scores = scores + mask
lse[:, sq_s:sq_e] = torch.logsumexp(scores, dim=-1).squeeze(0)
# Expand KV heads once, then execute a single batched SDPA over padded
# tensors instead of launching one SDPA per sequence from Python.
num_kv_heads = k.shape[1]
if num_heads_q != num_kv_heads:
rep = num_heads_q // num_kv_heads
k = k.repeat_interleave(rep, dim=1)
v = v.repeat_interleave(rep, dim=1)
q_starts = cu_seqlens_q[:-1].tolist()
q_ends = cu_seqlens_q[1:].tolist()
k_starts = cu_seqlens_k[:-1].tolist()
k_ends = cu_seqlens_k[1:].tolist()
q_lens = [q_e - q_s for q_s, q_e in zip(q_starts, q_ends)]
k_lens = [k_e - k_s for k_s, k_e in zip(k_starts, k_ends)]
max_q = max(q_lens) if q_lens else 0
max_k = max(k_lens) if k_lens else 0
if max_q == 0 or max_k == 0:
if return_lse:
lse.zero_()
return out, lse
return out
batch_q = q.new_zeros((num_seqs, num_heads_q, max_q, q.shape[-1]))
batch_k = k.new_zeros((num_seqs, num_heads_q, max_k, k.shape[-1]))
batch_v = v.new_zeros((num_seqs, num_heads_q, max_k, v.shape[-1]))
attn_mask = q.new_full((num_seqs, 1, max_q, max_k), float("-inf"))
for i, (sq_s, sq_e, sk_s, sk_e, q_len, k_len) in enumerate(
zip(q_starts, q_ends, k_starts, k_ends, q_lens, k_lens)
):
if q_len == 0 or k_len == 0:
continue
batch_q[i, :, :q_len] = q[sq_s:sq_e].transpose(0, 1)
batch_k[i, :, :k_len] = k[sk_s:sk_e].transpose(0, 1)
batch_v[i, :, :k_len] = v[sk_s:sk_e].transpose(0, 1)
attn_mask[i, :, :q_len, :k_len] = 0
if causal:
causal_mask = torch.triu(
torch.full((q_len, k_len), float("-inf"),
dtype=q.dtype, device=q.device),
diagonal=1,
)
attn_mask[i, :, :q_len, :k_len] = (
attn_mask[i, :, :q_len, :k_len] + causal_mask
)
batch_out = F.scaled_dot_product_attention(
batch_q,
batch_k,
batch_v,
attn_mask=attn_mask,
scale=softmax_scale,
is_causal=False,
)
for i, (sq_s, sq_e, q_len) in enumerate(zip(q_starts, q_ends, q_lens)):
if q_len == 0:
continue
out[sq_s:sq_e] = batch_out[i, :, :q_len].transpose(0, 1)
if return_lse:
scores = torch.matmul(
batch_q.to(torch.float32) * softmax_scale,
batch_k.transpose(-1, -2).to(torch.float32),
)
scores = scores + attn_mask.to(torch.float32)
batch_lse = torch.logsumexp(scores, dim=-1)
for i, (sq_s, sq_e, q_len) in enumerate(zip(q_starts, q_ends, q_lens)):
if q_len == 0:
continue
lse[:, sq_s:sq_e] = batch_lse[i, :, :q_len]

Comment on lines +9 to +20
import torch
import torch.nn.functional as F
from torch import nn

from aiter import gelu_tanh_and_mul
from aiter.dist.parallel_state import get_tp_group
from aiter.rotary_embedding import get_rope

from atom.config import Config, QuantizationConfig
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig
from atom.model_loader.loader import load_model_in_plugin_mode
from atom.model_ops.base_attention import Attention

Copilot AI Apr 13, 2026


torch.nn.functional as F and Gemma4Config are imported but not used in this module. Please remove unused imports to keep the file clean (and to avoid failing any lint/static checks).
