feat: add Gemma4 31B support for standalone and vLLM plugin mode on MI355X #546
ClementLinCF wants to merge 18 commits into main
Conversation
- Rebase AiterBackend onto the vLLM v1 AttentionBackend so the plugin passes the v1 attention selector's method checks
- Add a Triton FA2 / PyTorch SDPA fallback for prefill when head_dim > 256 (CK flash_attn_varlen_func limitation)
- Guard qkv and position slicing against None values
Gemma 4 uses head_dim=256 for sliding window layers and head_dim=512 for full attention layers. Two issues fixed:
1. swa_workspace was allocated with the model-global head_dim (512) from get_head_size(), causing an assertion failure in cp_mha_gather_cache when the KV cache for sliding window layers has head_dim=256. Fix: derive swa_head_dim from the actual sliding window layer impl.
2. The CK flash attention kernel hard-limits head_dim to 256, crashing extend_forward for full attention layers under high concurrency. Fix: fall back to torch SDPA with GQA head repeating + LSE computation when head_dim > 256.
Enable Gemma 4 31B-IT to run in ATOM standalone (Config A) on MI355X with TP=8. This required several model-specific adaptations:
- New model files: atom/models/gemma4.py (Attention, MLP, DecoderLayer, Model, ForCausalLM) and atom/model_config/gemma4.py (Gemma4TextConfig)
- RMSNorm: use the standard x*weight formula (not the Gemma 1/2 x*(1+weight))
- v_norm: add the missing value normalization (RMSNorm without scale)
- k_eq_v: apply v_norm to raw k before k_norm for global attention
- layer_scalar: apply once at the end of the layer (not twice to sub-blocks)
- Attention scaling: use 1.0 (not 1/sqrt(head_dim)) since the q/k norms already control magnitudes
- Config: resolve the gemma4 multimodal config and the gemma4_text sub-config
- KV cache: per-layer allocation for heterogeneous head_dim (512 global vs 256 sliding)
- Prefill attention: SDPA fallback for head_dim > 256 (CK limitation)
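As an illustration of the RMSNorm point above, the two conventions differ only in how the learned weight is applied (pure-PyTorch sketch, not ATOM's kernel code):

```python
import torch

def rmsnorm_gemma12(x, weight, eps=1e-6):
    # Gemma 1/2 convention: scale the normalized input by (1 + weight)
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return normed * (1.0 + weight)

def rmsnorm_standard(x, weight, eps=1e-6):
    # Standard (LLaMA-style) convention, which Gemma 4 uses: scale by weight
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return normed * weight
```

Loading a checkpoint trained with one convention into code using the other silently skews every layer's activations, which is why the formula is called out explicitly.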
_Gemma4RMSNorm now dispatches to aiter.rmsnorm2d_fwd when with_scale=True, falling back to pure PyTorch for with_scale=False (v_norm). This is enabled by the upstream aiter rmsnorm dtype fix (5df37c1) which auto-casts weight to input dtype, preventing the silent data corruption that previously blocked AITER kernel usage. With this change, --enforce-eager is no longer required — CUDA graph capture works correctly and provides ~2x decode throughput improvement.
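The dispatch can be sketched like this; the aiter call is illustrative only (its exact signature is an assumption here), while the no-scale branch is the pure-PyTorch v_norm path:

```python
import torch

def gemma4_rmsnorm(x, weight=None, eps=1e-6):
    """Sketch of the dispatch described above: AITER kernel when a learned
    scale is present, pure PyTorch otherwise (v_norm has no scale)."""
    if weight is not None:
        import aiter
        # Assumed signature; the real call site follows aiter's API.
        return aiter.rmsnorm2d_fwd(x, weight, eps)
    # v_norm: RMSNorm without a learned scale
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
```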
- Extract _sdpa_varlen_attn as the single shared SDPA varlen fallback (replaces 3 near-identical copies across plugin and model_ops)
- Log a warning when Triton FA2 fails and falls back to SDPA
- Log a warning when the vLLM v1 AttentionBackend import fails in AiterBackendDecoratorForPluginMode
- Clean up the commit hash reference in the _Gemma4RMSNorm docstring
Set accept_output_buffer=True in vllmAiterAttentionBackendMethods so vLLM pre-allocates the output buffer at a fixed address for CUDA graph replay. Pass the buffer through forward() → forward_impl_plugin_mode() and use it when available instead of torch.empty().

Root cause: PIECEWISE CUDA graph pieces read output from the address captured during warmup, but ATOM allocated a new tensor each call → stale data → garbage output.

Three files changed:
- atom/plugin/attention.py: accept_output_buffer = True
- atom/model_ops/attention_mha.py: pass output= to the plugin forward
- atom/plugin/attention_mha.py: conditional output buffer allocation
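The shape of the fix can be sketched as follows (function name hypothetical; the real code threads `output=` through the plugin forward path):

```python
import torch

def forward_impl_sketch(q, output=None):
    """Sketch: write into the pre-allocated buffer when given, so CUDA graph
    replay sees a stable address; allocate only on the eager path."""
    if output is None:
        output = torch.empty_like(q)  # eager path: a fresh tensor is fine
    # Stand-in for the attention kernel writing its result in place:
    output.copy_(q)
    return output
```

During graph capture the caller passes the same buffer every step, so replayed graph pieces read valid data instead of a stale address from a freed tensor.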
…er)" This reverts commit db594dc.
…upport

HF transformers >= 5.5.3 natively supports Gemma4Config/Gemma4TextConfig. Remove the ATOM-side fallback config and use HF directly.
- Delete atom/model_config/gemma4.py (226 lines, no longer needed)
- Remove the _resolve_atom_text_config() try/except fallback in config.py
- Update atom/models/gemma4.py to import from transformers.models.gemma4
Re-add comments that were inadvertently removed during the Gemma 4 heterogeneous attention support (bf992a1), and add a new comment explaining the per-layer branch for clarity.
Made-with: Cursor
Planning/documentation file that is not imported by any runtime code. Made-with: Cursor
Incomplete experimental kernel, not imported by any runtime code. Made-with: Cursor
```python
from aiter.rotary_embedding import get_rope

from atom.config import Config, QuantizationConfig
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig
```

`transformers.models.gemma4.configuration_gemma4.Gemma4Config` imported but unused.

Suggested change:
```diff
-from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig
+from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
```
Pull request overview
Adds Gemma 4 31B-it support to ATOM on MI355X across both standalone server mode and vLLM plugin mode, including required handling for Gemma4’s heterogeneous attention head dimensions and FlashAttention head-dim limits.
Changes:
- Introduces an inference-only Gemma4 text backbone implementation and registers it for vLLM plugin usage.
- Extends KV-cache sizing/allocation/binding to support per-layer `(num_kv_heads, head_dim)` heterogeneity.
- Adds attention fallbacks for `head_dim > 256` (CK limitation) in both standalone and plugin attention paths, plus multimodal text-config extraction.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| atom/plugin/vllm/register.py | Registers Gemma4 arch to use the ATOM causal LM wrapper in vLLM plugin mode. |
| atom/plugin/vllm/model_wrapper.py | Maps Gemma4 HF architecture name to the ATOM Gemma4 implementation class. |
| atom/plugin/attention.py | Tracks SWA head-dim for workspace sizing; attempts to rebase backend class onto vLLM v1 AttentionBackend. |
| atom/plugin/attention_mha.py | Adds CK→(Triton/SDPA) fallback paths for varlen attention when head_dim > 256, plus some null-guarding. |
| atom/models/gemma4.py | New Gemma4 inference-only model implementation with AITER-optimized components. |
| atom/model_ops/attention_mha.py | Adds SDPA fallback for standalone prefill varlen attention when head_dim > 256. |
| atom/model_engine/model_runner.py | Adds per-layer KV-dim discovery and uses it for heterogeneous KV cache allocation/binding and memory estimation. |
| atom/config.py | Resolves multimodal configs to text_config early so engine code can directly access text attributes. |
```python
# Rebase onto vLLM v1 AttentionBackend so cls inherits default method
# stubs (get_preferred_block_size, validate_configuration, supports_*,
# etc.) that the vLLM v1 attention selector requires.
try:
    from vllm.v1.attention.backend import (
        AttentionBackend as _VllmAttnBackend,
    )
    cls.__bases__ = (_VllmAttnBackend,)
except ImportError:
    logger.warning(
        "vllm.v1.attention.backend not found; %s will not inherit "
        "vLLM v1 AttentionBackend stubs — attention selector may "
        "reject this backend.", cls.__name__)
```
`cls.__bases__ = (_VllmAttnBackend,)` can raise `TypeError` (e.g., incompatible layout/metaclass), but the code only catches `ImportError`, which would crash plugin initialization at import time. Consider avoiding runtime base-class mutation (e.g., create/return a new subclass with the desired bases) or at least catch `Exception`/`TypeError` and fall back gracefully.
Suggested change:
```python
# Create a subclass that also inherits the vLLM v1 AttentionBackend
# so cls gets the default method stubs (get_preferred_block_size,
# validate_configuration, supports_*, etc.) required by the vLLM v1
# attention selector. Avoid mutating __bases__ at runtime because it
# can raise TypeError for incompatible layouts/metaclasses.
try:
    from vllm.v1.attention.backend import (
        AttentionBackend as _VllmAttnBackend,
    )
    if not issubclass(cls, _VllmAttnBackend):
        cls = type(
            cls.__name__,
            (cls, _VllmAttnBackend),
            {
                "__module__": cls.__module__,
                "__qualname__": cls.__qualname__,
                "__doc__": cls.__doc__,
            },
        )
except ImportError:
    logger.warning(
        "vllm.v1.attention.backend not found; %s will not inherit "
        "vLLM v1 AttentionBackend stubs — attention selector may "
        "reject this backend.", cls.__name__)
except TypeError as exc:
    logger.warning(
        "Unable to create vLLM AttentionBackend-compatible subclass "
        "for %s; keeping original class without vLLM v1 stubs: %s",
        cls.__name__,
        exc,
    )
```
```python
ws = self.world_size
swa_kv = max(hf_config.num_key_value_heads // ws, 1)
glo_kv = max(num_global_kv_heads // ws, 1)
dims = []
```
`_get_per_layer_kv_dims()` computes per-rank KV heads via `max(num_*_kv_heads // ws, 1)` without the divisibility checks used in `_get_num_kv_heads()`. For unsupported TP sizes this can silently under-allocate the KV cache and fail later in harder-to-debug ways; consider mirroring the same `>= ws` / `% ws == 0` assertions for both sliding and global KV heads.
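The reviewer's suggestion amounts to something like the following sketch (helper name hypothetical), asserting divisibility instead of silently flooring:

```python
def per_rank_kv_heads(num_kv_heads: int, world_size: int) -> int:
    """Sketch of the mirrored TP check: fail loudly instead of
    max(num_kv_heads // world_size, 1)."""
    if num_kv_heads >= world_size:
        assert num_kv_heads % world_size == 0, (
            f"num_kv_heads={num_kv_heads} not divisible by TP={world_size}")
        return num_kv_heads // world_size
    # Fewer KV heads than ranks: heads must be replicated evenly.
    assert world_size % num_kv_heads == 0, (
        f"TP={world_size} not divisible by num_kv_heads={num_kv_heads}")
    return 1
```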
```python
ki = k[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
vi = v[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
num_q_heads, num_kv_heads = qi.shape[1], ki.shape[1]
if num_q_heads != num_kv_heads:
```
In the GQA expansion path, `rep = num_q_heads // num_kv_heads` is used without asserting `num_q_heads % num_kv_heads == 0`. If they're not divisible, this will either produce incorrect head replication or trigger shape errors in SDPA; add an explicit divisibility assertion (or handle the remainder) before repeating heads.
Suggested change:
```python
if num_q_heads != num_kv_heads:
    if num_q_heads % num_kv_heads != 0:
        raise ValueError(
            f"GQA head expansion requires num_q_heads to be divisible "
            f"by num_kv_heads, got num_q_heads={num_q_heads} and "
            f"num_kv_heads={num_kv_heads}"
        )
```
1. Quark exclude list name remapping: Quark uses `model.language_model.*` naming but ATOM uses `model.*`. Added `quant_exclude_name_mapping` to Gemma4ForCausalLM and passed it through model_wrapper.py so self_attn layers are correctly identified as excluded (BF16).
2. ASM FP4 GEMM K-dimension alignment: The ASM kernel `f4gemm_bf16_per1x32Fp4_BpreShuffle` requires K_packed to be a multiple of 128. At TP=8, down_proj has K_packed=1344 (not aligned). Added weight K padding in process_weights_after_loading and runtime input K padding in forward to align to 128.
3. Triton preshuffle M-dimension padding: When M < 32, pad the input x and x_scale to prevent out-of-bounds reads in the preshuffle GEMM path.
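The K-alignment part of point 2 can be sketched as follows (helper name hypothetical; the actual change pads both the weight at load time and the activation at runtime):

```python
import torch

def pad_last_dim_to_multiple(t: torch.Tensor, multiple: int = 128) -> torch.Tensor:
    """Zero-pad the last (K) dimension up to the next multiple of `multiple`.
    Zero padding is safe for a GEMM: the extra entries contribute nothing
    to the dot products."""
    pad = (-t.shape[-1]) % multiple
    if pad == 0:
        return t
    return torch.nn.functional.pad(t, (0, pad))
```

For example, the TP=8 down_proj case with K_packed=1344 would pad up to 1408 (11 × 128).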
Use aiter's rmsnorm2d_fwd_with_add to fuse the residual addition and RMSNorm into a single kernel, eliminating one memory round-trip per fusion point. Two fusion points per layer (input_layernorm and pre_feedforward_layernorm) × 60 layers = 120 fewer kernel launches.
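For reference, the unfused computation that rmsnorm2d_fwd_with_add replaces looks like this (pure-PyTorch sketch, not the AITER API):

```python
import torch

def add_rmsnorm_unfused(x, residual, weight, eps=1e-6):
    """Unfused version: the residual add materializes `h` to global memory
    before the norm kernel reads it back; the fused kernel does both in one
    launch, saving that round-trip."""
    h = x + residual
    normed = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps) * weight
    return normed, h  # the fused kernel also returns the updated residual
```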
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
```python
xs_padded = torch.zeros(m_padded, x_scale.shape[-1], dtype=torch.uint8, device=x.device)
xs_padded[:m] = x_scale
```
In the `gemm_afp4wfp4_preshuffle` path, the new padding logic reshapes `x_scale` to a block-major shape (`view(x_scale.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1)`) and then pads it as if it were token-major (`xs_padded = torch.zeros(m_padded, ...)` and `xs_padded[:m] = x_scale`). When `m` is not a multiple of `MXFP4_QUANT_BLOCK_SIZE`, this will raise a shape mismatch at runtime and/or feed incorrectly shaped scales into the preshuffle kernel. Consider padding in the same representation expected by the kernel (e.g., pad before reshaping, or pad the block dimension to `ceil(m/BS)`), and avoid mixing token- vs block-shaped scales.
Suggested change:
```python
if m >= MXFP4_QUANT_BLOCK_SIZE:
    scale_rows_padded = m_padded // MXFP4_QUANT_BLOCK_SIZE
    xs_padded = torch.zeros(
        scale_rows_padded, x_scale.shape[-1], dtype=torch.uint8, device=x.device
    )
    xs_padded[: x_scale.shape[0]] = x_scale
else:
    xs_padded = torch.zeros(
        m_padded, x_scale.shape[-1], dtype=torch.uint8, device=x.device
    )
    xs_padded[:m] = x_scale
```
```python
k_pad = self._fp4_k_padded
elem_size = x.element_size()
fp4_elem_size = 1
x_logical_k = x.shape[-1] * elem_size // fp4_elem_size
```
`x_logical_k` is computed but never used. If it's not needed for subsequent padding logic, remove it to avoid confusion; if it is needed, incorporate it into the computation (or add an assertion/comment explaining the intent).
Suggested change:
```diff
-x_logical_k = x.shape[-1] * elem_size // fp4_elem_size
```
```python
# Rebase onto vLLM v1 AttentionBackend so cls inherits default method
# stubs (get_preferred_block_size, validate_configuration, supports_*,
# etc.) that the vLLM v1 attention selector requires.
try:
    from vllm.v1.attention.backend import (
        AttentionBackend as _VllmAttnBackend,
    )
    cls.__bases__ = (_VllmAttnBackend,)
```
Mutating `cls.__bases__` at runtime is fragile (it can raise `TypeError` if base layouts are incompatible), and it also drops the original ATOM AttentionBackend base class from the MRO. It would be safer to define a small adapter subclass that inherits from `vllm.v1.attention.backend.AttentionBackend` (or uses multiple inheritance) rather than rebasing an existing class in place.
Suggested change:
```python
# Return a small adapter subclass instead of rebasing cls in-place.
# This preserves the original ATOM backend in the MRO while also
# inheriting the default vLLM v1 AttentionBackend method stubs
# (get_preferred_block_size, validate_configuration, supports_*,
# etc.) required by the vLLM v1 attention selector.
try:
    from vllm.v1.attention.backend import (
        AttentionBackend as _VllmAttnBackend,
    )
    if not issubclass(cls, _VllmAttnBackend):
        cls = type(
            cls.__name__,
            (cls, _VllmAttnBackend),
            {
                "__module__": cls.__module__,
                "__doc__": cls.__doc__,
                "__qualname__": cls.__qualname__,
            },
        )
```
```python
for i in range(num_seqs):
    sq_s, sq_e = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
    sk_s, sk_e = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
    qi = q[sq_s:sq_e].transpose(0, 1).unsqueeze(0)
    ki = k[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
    vi = v[sk_s:sk_e].transpose(0, 1).unsqueeze(0)
    num_q_heads, num_kv_heads = qi.shape[1], ki.shape[1]
    if num_q_heads != num_kv_heads:
        rep = num_q_heads // num_kv_heads
        ki = ki.repeat_interleave(rep, dim=1)
        vi = vi.repeat_interleave(rep, dim=1)
    oi = F.scaled_dot_product_attention(
        qi, ki, vi, scale=softmax_scale, is_causal=causal,
    )
    out[sq_s:sq_e] = oi.squeeze(0).transpose(0, 1)

    if return_lse:
        scores = torch.matmul(qi * softmax_scale, ki.transpose(-1, -2))
        if causal and (sq_e - sq_s) == (sk_e - sk_s):
            mask = torch.triu(
                torch.full_like(scores, float("-inf")), diagonal=1)
            scores = scores + mask
        lse[:, sq_s:sq_e] = torch.logsumexp(scores, dim=-1).squeeze(0)
```
`_sdpa_varlen_attn` iterates over sequences in Python and calls `scaled_dot_product_attention` per sequence (plus an extra matmul/logsumexp when `return_lse=True`). For large batches (many sequences) this fallback will be extremely slow and can dominate prefill/extend latency when head_dim > 256. Consider a more vectorized fallback (e.g., pad to max_seqlen and run SDPA in a single call, or use a compiled/kernel varlen implementation when available) to keep the fallback usable in production.
Suggested change (replace the per-sequence loop with a single batched SDPA over padded tensors):
```python
# Expand KV heads once, then execute a single batched SDPA over padded
# tensors instead of launching one SDPA per sequence from Python.
num_kv_heads = k.shape[1]
if num_heads_q != num_kv_heads:
    rep = num_heads_q // num_kv_heads
    k = k.repeat_interleave(rep, dim=1)
    v = v.repeat_interleave(rep, dim=1)
q_starts = cu_seqlens_q[:-1].tolist()
q_ends = cu_seqlens_q[1:].tolist()
k_starts = cu_seqlens_k[:-1].tolist()
k_ends = cu_seqlens_k[1:].tolist()
q_lens = [q_e - q_s for q_s, q_e in zip(q_starts, q_ends)]
k_lens = [k_e - k_s for k_s, k_e in zip(k_starts, k_ends)]
max_q = max(q_lens) if q_lens else 0
max_k = max(k_lens) if k_lens else 0
if max_q == 0 or max_k == 0:
    if return_lse:
        lse.zero_()
        return out, lse
    return out
batch_q = q.new_zeros((num_seqs, num_heads_q, max_q, q.shape[-1]))
batch_k = k.new_zeros((num_seqs, num_heads_q, max_k, k.shape[-1]))
batch_v = v.new_zeros((num_seqs, num_heads_q, max_k, v.shape[-1]))
attn_mask = q.new_full((num_seqs, 1, max_q, max_k), float("-inf"))
for i, (sq_s, sq_e, sk_s, sk_e, q_len, k_len) in enumerate(
    zip(q_starts, q_ends, k_starts, k_ends, q_lens, k_lens)
):
    if q_len == 0 or k_len == 0:
        continue
    batch_q[i, :, :q_len] = q[sq_s:sq_e].transpose(0, 1)
    batch_k[i, :, :k_len] = k[sk_s:sk_e].transpose(0, 1)
    batch_v[i, :, :k_len] = v[sk_s:sk_e].transpose(0, 1)
    attn_mask[i, :, :q_len, :k_len] = 0
    if causal:
        causal_mask = torch.triu(
            torch.full((q_len, k_len), float("-inf"),
                       dtype=q.dtype, device=q.device),
            diagonal=1,
        )
        attn_mask[i, :, :q_len, :k_len] = (
            attn_mask[i, :, :q_len, :k_len] + causal_mask
        )
batch_out = F.scaled_dot_product_attention(
    batch_q,
    batch_k,
    batch_v,
    attn_mask=attn_mask,
    scale=softmax_scale,
    is_causal=False,
)
for i, (sq_s, sq_e, q_len) in enumerate(zip(q_starts, q_ends, q_lens)):
    if q_len == 0:
        continue
    out[sq_s:sq_e] = batch_out[i, :, :q_len].transpose(0, 1)
if return_lse:
    scores = torch.matmul(
        batch_q.to(torch.float32) * softmax_scale,
        batch_k.transpose(-1, -2).to(torch.float32),
    )
    scores = scores + attn_mask.to(torch.float32)
    batch_lse = torch.logsumexp(scores, dim=-1)
    for i, (sq_s, sq_e, q_len) in enumerate(zip(q_starts, q_ends, q_lens)):
        if q_len == 0:
            continue
        lse[:, sq_s:sq_e] = batch_lse[i, :, :q_len]
```
```python
import torch
import torch.nn.functional as F
from torch import nn

from aiter import gelu_tanh_and_mul
from aiter.dist.parallel_state import get_tp_group
from aiter.rotary_embedding import get_rope

from atom.config import Config, QuantizationConfig
from transformers.models.gemma4.configuration_gemma4 import Gemma4Config, Gemma4TextConfig
from atom.model_loader.loader import load_model_in_plugin_mode
from atom.model_ops.base_attention import Attention
```
`torch.nn.functional as F` and `Gemma4Config` are imported but not used in this module. Please remove the unused imports to keep the file clean (and to avoid failing any lint/static checks).
Motivation
Add Gemma 4 31B-it support for ATOM on AMD Instinct MI355X (gfx950 / CDNA4), covering both Config A (ATOM standalone server) and Config B (vLLM + ATOM OOT plugin).
Gemma 4 introduces architectural features not present in other supported models: heterogeneous attention layers with different head_dim (sliding=256, global=512), per-layer-type RoPE configurations, and head_dim > 256, which exceeds the CK Flash Attention limit. This PR handles all of these while maintaining backward compatibility with existing models (LLaMA, DeepSeek, Kimi-K2.5, Qwen, etc.).
Technical Details
New: Gemma 4 model implementation
Heterogeneous head_dim support
CK Flash Attention head_dim > 256 fallback
atom/model_ops/attention_mha.py: Added _sdpa_varlen_fallback() — PyTorch SDPA fallback with GQA head expansion for prefill when head_dim > 256 (CK limitation). Original CK path unchanged for head_dim <= 256.
atom/plugin/attention_mha.py: Added _flash_attn_varlen_with_fallback() and _triton_flash_attn_varlen() for the plugin extend/prefill paths. Automatic dispatch based on head_dim. Also added qkv is not None null checks for Gemma 4 code paths.
Plugin mode (vLLM) support
atom/plugin/attention.py: Added swa_head_dim tracking for heterogeneous sliding window attention workspace allocation. Rebased plugin onto vLLM v1 AttentionBackend for API compatibility.
atom/plugin/vllm/register.py and atom/plugin/vllm/model_wrapper.py: Registered Gemma4ForConditionalGeneration in both model registries (+1 line each).
Multimodal config support
Dependency
Test Plan
ATOM Plugin:

```shell
vllm serve "$MODEL" \
    --host 0.0.0.0 --port "$PORT" \
    --tensor-parallel-size "$TP" \
    --trust-remote-code \
    --max-model-len 4096 \
    > /tmp/vllm_plugin_server.log 2>&1 &
```
ATOM standalone:

```shell
python -m atom.entrypoints.openai_server \
    --model "$MODEL" \
    --tensor-parallel-size "$TP" \
    --max-model-len 4096 \
    $EAGER_FLAG \
    > /tmp/atom_server.log 2>&1 &
```
Test Result
Concurrency = 2 (low concurrency)
Concurrency = 40 (high concurrency)
Submission Checklist