Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
786dd86
fix(plugin): support Gemma 4 attention with head_dim > 256 fallback
ClementLinCF Apr 7, 2026
c9add87
fix(plugin): handle Gemma 4 heterogeneous head_dim in extend path
ClementLinCF Apr 7, 2026
bf992a1
feat: add Gemma 4 standalone support for ATOM server mode
ClementLinCF Apr 8, 2026
2a3e8f9
feat: use AITER rmsnorm kernel in _Gemma4RMSNorm, enable CUDA graphs
ClementLinCF Apr 8, 2026
bb4a53f
refactor: consolidate SDPA fallbacks, add logging to silent exceptions
ClementLinCF Apr 8, 2026
db594dc
fix(plugin): CUDA graph output corruption (accept_output_buffer)
ClementLinCF Apr 8, 2026
2cf338d
Register gemma4
MHYangAMD Apr 9, 2026
354fb32
Revert "fix(plugin): CUDA graph output corruption (accept_output_buffer)"
MHYangAMD Apr 9, 2026
db7ecfc
refactor: remove Gemma4 config fallback, use HF transformers native s…
ClementLinCF Apr 11, 2026
3fdac16
docs: restore deleted comments in model_runner.py KV cache logic
ClementLinCF Apr 11, 2026
320faa1
docs: restore deleted comments in attention_mha.py
ClementLinCF Apr 11, 2026
953e43f
style: align gemma4.py file header with project convention
ClementLinCF Apr 11, 2026
a618cf9
style: align gemma4_moe_gelu.py file header with project convention
ClementLinCF Apr 11, 2026
fa2f5af
cleanup: remove unused gemma4_ops_optimization.py
ClementLinCF Apr 11, 2026
c669a10
cleanup: remove unused gemma4_moe_gelu.py
ClementLinCF Apr 11, 2026
108145d
docs: restore deleted comment in attention_mha.py plugin path
ClementLinCF Apr 11, 2026
ac0fc6a
feat: add MXFP4 quantization support for Gemma4
ClementLinCF Apr 13, 2026
82a77d8
perf: fuse residual add + RMSNorm in Gemma4DecoderLayer
ClementLinCF Apr 13, 2026
2a52286
fix: use 256-byte K alignment for MXFP4 FP4 GEMM when 128-align has n…
ClementLinCF Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
LayerQuantConfig,
get_quant_parser,
)
from atom.utils import envs, get_open_port
from atom.utils import envs, get_hf_text_config, get_open_port
from atom.utils.distributed.utils import stateless_init_torch_distributed_process_group
from torch.distributed import ProcessGroup, ReduceOp
from transformers import AutoConfig, GenerationConfig, PretrainedConfig
Expand Down Expand Up @@ -485,6 +485,7 @@ def _remap_layer_name(name: str) -> list[str]:
_MULTIMODAL_MODEL_TYPES: dict[str, str] = {
# Maps multimodal model_type -> key in config_dict for the text sub-config
"kimi_k25": "text_config",
"gemma4": "text_config",
}

# multimodal models fully supported by plugin mode
Expand All @@ -493,6 +494,7 @@ def _remap_layer_name(name: str) -> list[str]:
}



def get_hf_config(model: str, trust_remote_code: bool = False) -> PretrainedConfig:
config_dict, _ = PretrainedConfig.get_config_dict(
model,
Expand Down Expand Up @@ -793,6 +795,15 @@ def __post_init__(self):
self.hf_config = get_hf_config(
self.model, trust_remote_code=self.trust_remote_code
)
# For multimodal models (e.g. Gemma4), resolve to the text sub-config
# so engine code can access num_hidden_layers, head_dim, etc. directly.
_original_hf_config = self.hf_config
self.hf_config = get_hf_text_config(self.hf_config)
if _original_hf_config is not self.hf_config:
for attr in ("architectures", "model_type"):
orig_val = getattr(_original_hf_config, attr, None)
if orig_val is not None and getattr(self.hf_config, attr, None) is None:
setattr(self.hf_config, attr, orig_val)
# transformers 5+ exposes rope_parameters; <5 often only rope_scaling + rope_theta.
# Synthesize when missing or None so GPT-OSS YaRN (rope_type in rope_scaling) is preserved.
if getattr(self.hf_config, "rope_parameters", None) is None:
Expand Down
184 changes: 136 additions & 48 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"Qwen3NextForCausalLM": "atom.models.qwen3_next.Qwen3NextForCausalLM",
"KimiK25ForConditionalGeneration": "atom.models.kimi_k25.KimiK25ForCausalLM",
"MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM",
"Gemma4ForConditionalGeneration": "atom.models.gemma4.Gemma4ForCausalLM",
}
# seed = 34567
# np.random.seed(seed)
Expand Down Expand Up @@ -940,6 +941,34 @@ def _get_total_num_layers(self):
total += getattr(draft_hf, "num_nextn_predict_layers", 1)
return total

def _get_per_layer_kv_dims(self):
"""Return per-layer (num_kv_heads, head_dim) for heterogeneous models.

For models like Gemma 4 where sliding and global attention layers
have different num_kv_heads and head_dim, returns a list of
(num_kv_heads, head_dim) tuples, one per layer.
Returns None for homogeneous models.
"""
hf_config = self.config.hf_config
layer_types = getattr(hf_config, "layer_types", None)
global_head_dim = getattr(hf_config, "global_head_dim", None)
num_global_kv_heads = getattr(hf_config, "num_global_key_value_heads", None)
if layer_types is None or global_head_dim is None or num_global_kv_heads is None:
return None
if global_head_dim == hf_config.head_dim:
return None

ws = self.world_size
swa_kv = max(hf_config.num_key_value_heads // ws, 1)
glo_kv = max(num_global_kv_heads // ws, 1)
dims = []
Comment on lines +961 to +964
Copy link

Copilot AI Apr 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_per_layer_kv_dims() computes per-rank KV heads via max(num_*_kv_heads // ws, 1) without the divisibility checks used in _get_num_kv_heads(). For unsupported TP sizes this can silently under-allocate KV cache and fail later in harder-to-debug ways; consider mirroring the same >= ws / % ws == 0 assertions for both sliding and global KV heads.

Copilot uses AI. Check for mistakes.
for lt in layer_types:
if lt == "full_attention":
dims.append((glo_kv, global_head_dim))
else:
dims.append((swa_kv, hf_config.head_dim))
return dims

def _compute_block_bytes(self):
"""Compute the TRUE per-block memory cost including all tensors.

Expand Down Expand Up @@ -1008,25 +1037,34 @@ def _compute_block_bytes(self):
)
block_bytes += self.num_gdn_attn_state * one_layer_byte
else:
# Standard attention: kv_cache [2, num_hidden_layers, blocks, ...]
# Note: allocate_kv_cache uses hf_config.num_hidden_layers for
# the standard path (draft layers use separate binding).
block_bytes = (
2
* hf_config.num_hidden_layers
* self.block_size
* num_kv_heads
* hf_config.head_dim
* kv_dtype_size
)
# kv_scale: [2, num_hidden_layers, blocks, kv_heads, phys_block_size]
block_bytes += (
2
* hf_config.num_hidden_layers
* num_kv_heads
* self.physical_block_size
* 4 # float32
)
per_layer_dims = self._get_per_layer_kv_dims()
if per_layer_dims is not None:
# Heterogeneous attention (e.g. Gemma 4): each layer may have
# different num_kv_heads / head_dim, so sum per-layer costs.
block_bytes = 0
for kv_h, hd in per_layer_dims:
block_bytes += 2 * self.block_size * kv_h * hd * kv_dtype_size
block_bytes += 2 * kv_h * self.physical_block_size * 4
else:
# Standard attention: kv_cache [2, num_hidden_layers, blocks, ...]
# Note: allocate_kv_cache uses hf_config.num_hidden_layers for
# the standard path (draft layers use separate binding).
block_bytes = (
2
* hf_config.num_hidden_layers
* self.block_size
* num_kv_heads
* hf_config.head_dim
* kv_dtype_size
)
# kv_scale: [2, num_hidden_layers, blocks, kv_heads, phys_block_size]
block_bytes += (
2
* hf_config.num_hidden_layers
* num_kv_heads
* self.physical_block_size
* 4 # float32
)
return block_bytes

def _estimate_cudagraph_overhead(self):
Expand Down Expand Up @@ -1205,16 +1243,37 @@ def allocate_kv_cache(self, num_kvcache_blocks):
device="cuda",
)
else:
self.kv_cache = torch.zeros(
2,
hf_config.num_hidden_layers,
self.num_physical_kvcache_blocks,
self.physical_block_size,
num_kv_heads,
hf_config.head_dim,
dtype=dtypes.d_dtypes[config.kv_cache_dtype],
device="cuda",
)
per_layer_dims = self._get_per_layer_kv_dims()
if per_layer_dims is not None:
self._per_layer_kv_cache = []
self._per_layer_kv_scale = []
kv_dt = dtypes.d_dtypes[config.kv_cache_dtype]
for kv_h, hd in per_layer_dims:
kc = torch.zeros(
2, self.num_physical_kvcache_blocks,
self.physical_block_size, kv_h, hd,
dtype=kv_dt, device="cuda",
)
sc = torch.zeros(
2, self.num_physical_kvcache_blocks,
kv_h, self.physical_block_size,
dtype=dtypes.fp32, device="cuda",
)
self._per_layer_kv_cache.append(kc)
self._per_layer_kv_scale.append(sc)
self.kv_cache = self._per_layer_kv_cache[0]
else:
self._per_layer_kv_cache = None
self.kv_cache = torch.zeros(
2,
hf_config.num_hidden_layers,
self.num_physical_kvcache_blocks,
self.physical_block_size,
num_kv_heads,
hf_config.head_dim,
dtype=dtypes.d_dtypes[config.kv_cache_dtype],
device="cuda",
)

self.kv_scale = torch.zeros(
2,
Expand All @@ -1224,9 +1283,10 @@ def allocate_kv_cache(self, num_kvcache_blocks):
self.physical_block_size,
dtype=dtypes.fp32,
device="cuda",
)
) if self._per_layer_kv_cache is None else None

# Build KVCacheConfig
# lirong TODO: This is a simple solution to build KVCacheConfig,
# TODO(lirong): This is a simple solution to build KVCacheConfig,
# models with only one type of attention, but not support multi-type of attention models.
# We need to support it by kv_cache_group in the future.

Expand All @@ -1237,7 +1297,12 @@ def allocate_kv_cache(self, num_kvcache_blocks):

kv_cache_tensors = []
layer_id = 0
x = 16 // self.kv_cache.element_size()
_elem_size = (
self.kv_cache.element_size()
if not isinstance(self.kv_cache, list) and hasattr(self.kv_cache, "element_size")
else (self._per_layer_kv_cache[0].element_size() if self._per_layer_kv_cache else 2)
)
x = 16 // _elem_size
for model_name, model in models_to_bind:
logger.info(
f"Binding KV cache for {model_name} model starting at layer_id={layer_id}"
Expand All @@ -1252,23 +1317,46 @@ def allocate_kv_cache(self, num_kvcache_blocks):
attn_idx = layer_id // self.full_attention_interval
else:
attn_idx = layer_id
k_cache = self.kv_cache[0, attn_idx].view(
self.num_physical_kvcache_blocks,
num_kv_heads,
hf_config.head_dim // x,
self.physical_block_size,
x,
)
v_cache = self.kv_cache[1, attn_idx].view(
self.num_physical_kvcache_blocks,
num_kv_heads,
hf_config.head_dim,
self.physical_block_size,
)

if self._per_layer_kv_cache is not None:
layer_kv = self._per_layer_kv_cache[attn_idx]
layer_kv_h = layer_kv.shape[3]
layer_hd = layer_kv.shape[4]
k_cache = layer_kv[0].view(
self.num_physical_kvcache_blocks,
layer_kv_h,
layer_hd // x,
self.physical_block_size,
x,
)
v_cache = layer_kv[1].view(
self.num_physical_kvcache_blocks,
layer_kv_h,
layer_hd,
self.physical_block_size,
)
if config.kv_cache_dtype == "fp8":
module.k_scale = self._per_layer_kv_scale[attn_idx][0]
module.v_scale = self._per_layer_kv_scale[attn_idx][1]
else:
k_cache = self.kv_cache[0, attn_idx].view(
self.num_physical_kvcache_blocks,
num_kv_heads,
hf_config.head_dim // x,
self.physical_block_size,
x,
)
v_cache = self.kv_cache[1, attn_idx].view(
self.num_physical_kvcache_blocks,
num_kv_heads,
hf_config.head_dim,
self.physical_block_size,
)
if config.kv_cache_dtype == "fp8":
module.k_scale = self.kv_scale[0, attn_idx]
module.v_scale = self.kv_scale[1, attn_idx]

module.max_model_len = self.config.max_model_len
if config.kv_cache_dtype == "fp8":
module.k_scale = self.kv_scale[0, attn_idx]
module.v_scale = self.kv_scale[1, attn_idx]

k_scale = module.k_scale
v_scale = module.v_scale
Expand Down
51 changes: 35 additions & 16 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
else 1.0
)
self.kv_scale = torch.tensor(self.kv_scale_float, dtype=torch.float32)
self.per_tensor_scale = self.kv_scale
self.per_token_quant = True
self.sinks = sinks
self.sliding_window = sliding_window if sliding_window is not None else -1
Expand Down Expand Up @@ -100,6 +101,7 @@ def forward_impl_server_mode(
return o

o: torch.Tensor

q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
Expand Down Expand Up @@ -450,33 +452,50 @@ def paged_attention_persistent_asm(

return output

@staticmethod
def _sdpa_varlen_fallback(q, k, v, cu_seqlens_q, cu_seqlens_k, softmax_scale, causal):
    """SDPA fallback for head_dim > 256 where CK is unsupported."""
    # Imported lazily so the plugin module is only pulled in when the
    # fallback path is actually taken.
    from atom.plugin.attention_mha import _sdpa_varlen_attn

    attend = _sdpa_varlen_attn
    return attend(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        softmax_scale,
        causal,
    )

@mark_trace(prefix="prefill_attention", torch_compile=False)
def prefill_attention(
    self, q, k, v, k_cache, v_cache, k_scale, v_scale, fwd_ctx: ForwardContext
):
    """Variable-length prefill attention for the batch in fwd_ctx.

    Dispatches to a torch-SDPA fallback when head_dim > 256 (the CK
    flash-attention kernel does not support such head dims); otherwise
    runs aiter.flash_attn_varlen_func with the batch metadata from
    fwd_ctx.attn_metadata.
    """
    # Variable-length attention takes key/value tensors directly as
    # input on the prefill path (k_cache/v_cache are not read here).
    attn_metadata = fwd_ctx.attn_metadata
    sliding_window = (
        (self.sliding_window, 0, 0)
        if self.sliding_window is not None
        else (-1, -1, 0)
    )

    if self.head_dim > 256:
        # NOTE(review): the SDPA fallback receives neither sliding_window
        # nor sink_ptr, so sliding-window layers with head_dim > 256 get
        # full causal attention here — confirm this is intended.
        o = self._sdpa_varlen_fallback(
            q,
            k,
            v,
            cu_seqlens_q=attn_metadata.cu_seqlens_q,
            cu_seqlens_k=attn_metadata.cu_seqlens_k,
            softmax_scale=self.scale,
            causal=True,
        )
    else:
        o = aiter.flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=attn_metadata.cu_seqlens_q,
            cu_seqlens_k=attn_metadata.cu_seqlens_k,
            max_seqlen_q=attn_metadata.max_seqlen_q,
            max_seqlen_k=attn_metadata.max_seqlen_k,
            min_seqlen_q=attn_metadata.min_seqlen_q,
            dropout_p=attn_metadata.dropout_p,
            softmax_scale=self.scale,
            causal=True,
            window_size=sliding_window,
            sink_ptr=self.sinks,
        )

    return o

Expand Down
Loading