Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions atom/model_ops/mamba_ops/causal_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,7 @@ def causal_conv1d_update(
block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None,
validate_data=False,
return_packed_qkv: bool = False,
):
"""
x: Input tensor which can take the following shapes:
Expand Down Expand Up @@ -1289,9 +1290,20 @@ def causal_conv1d_update(

# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
num_tokens = x.shape[0]
query = torch.empty([num_tokens, k_dim_size, 1], dtype=x.dtype, device=x.device)
key = torch.empty([num_tokens, k_dim_size, 1], dtype=x.dtype, device=x.device)
value = torch.empty([num_tokens, v_dim_size, 1], dtype=x.dtype, device=x.device)
packed_qkv = None
if return_packed_qkv:
packed_qkv = torch.empty(
[num_tokens, k_dim_size * 2 + v_dim_size],
dtype=x.dtype,
device=x.device,
)
query = packed_qkv[:, :k_dim_size].unsqueeze(-1)
key = packed_qkv[:, k_dim_size : k_dim_size * 2].unsqueeze(-1)
value = packed_qkv[:, k_dim_size * 2 :].unsqueeze(-1)
else:
query = torch.empty([num_tokens, k_dim_size, 1], dtype=x.dtype, device=x.device)
key = torch.empty([num_tokens, k_dim_size, 1], dtype=x.dtype, device=x.device)
value = torch.empty([num_tokens, v_dim_size, 1], dtype=x.dtype, device=x.device)

stride_q_seq, stride_q_dim, stride_q_token = query.stride()
stride_k_seq, stride_k_dim, stride_k_token = key.stride()
Expand Down Expand Up @@ -1383,6 +1395,9 @@ def grid(META):
)
if unsqueeze:
out = out.squeeze(-1)
if return_packed_qkv:
return packed_qkv

query = query.squeeze(-1)
key = key.squeeze(-1)
value = value.squeeze(-1)
Expand Down
79 changes: 79 additions & 0 deletions atom/plugin/vllm/attention_backend/attention_gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
)

# from atom.model_ops.attentions.gdn_attn import GDNAttentionMetadata
from vllm import envs as vllm_envs
from vllm.forward_context import get_forward_context
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.model_executor.layers.fla.ops import (
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule_packed_decode,
)
from atom.model_ops.fla_ops.fused_sigmoid_gating import (
fused_sigmoid_gating_delta_rule_update,
Expand Down Expand Up @@ -162,6 +164,9 @@ def __init__(
self.head_k_dim = head_k_dim
self.head_v_dim = head_v_dim
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
self.enable_packed_recurrent_decode = (
vllm_envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
)

def rearrange_mixed_qkv(self, mixed_qkv):
if mixed_qkv is None:
Expand All @@ -182,6 +187,52 @@ def rearrange_mixed_qkv(self, mixed_qkv):
value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
return query.contiguous(), key.contiguous(), value.contiguous()

def _forward_core_decode_non_spec(
    self,
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    attn_metadata: GDNAttentionMetadata,
    conv_state: torch.Tensor,
    ssm_state: torch.Tensor,
    conv_weights: torch.Tensor,
) -> torch.Tensor:
    """Decode-only (non-speculative) fast path for the GDN core attention.

    Runs the causal conv1d state update with ``return_packed_qkv=True`` so the
    convolution writes q/k/v into one contiguous buffer, then hands that packed
    buffer to the fused recurrent gated delta-rule decode kernel, which writes
    its result into ``core_attn_out`` in place (via the ``out=`` argument).

    Args:
        mixed_qkv: Interleaved q/k/v projections; only the first
            ``attn_metadata.num_actual_tokens`` rows are used.
        b, a: Per-token gating inputs for the delta rule (sliced likewise).
        core_attn_out: Pre-allocated output tensor, updated in place and
            returned unchanged as an object.
        attn_metadata: Supplies ``num_actual_tokens`` and the per-token
            ``non_spec_state_indices_tensor`` used to address both caches.
        conv_state: Convolution state cache, updated in place by the kernel.
        ssm_state: Recurrent SSM state cache, updated in place by the kernel.
        conv_weights: Flattened conv1d weights for the update kernel.

    Returns:
        ``core_attn_out`` (the same tensor object that was passed in).
    """
    n_tokens = attn_metadata.num_actual_tokens
    # Both kernels index their state caches with the same per-token slots.
    state_indices = attn_metadata.non_spec_state_indices_tensor[:n_tokens]

    # Trim any padding rows before invoking the kernels.
    qkv_active = mixed_qkv[:n_tokens]
    gate_b = b[:n_tokens]
    gate_a = a[:n_tokens]

    # Conv state update; the packed return layout feeds the decode kernel
    # directly, avoiding a separate q/k/v repack.
    packed_qkv = causal_conv1d_update(
        qkv_active,
        conv_state,
        conv_weights,
        self.num_k_heads * self.head_k_dim // self.tp_size,
        self.num_v_heads * self.head_v_dim // self.tp_size,
        self.conv1d.bias,
        self.activation,
        conv_state_indices=state_indices,
        validate_data=False,
        return_packed_qkv=True,
    )

    # Fused recurrent decode step; writes through the view into core_attn_out.
    fused_recurrent_gated_delta_rule_packed_decode(
        mixed_qkv=packed_qkv,
        a=gate_a,
        b=gate_b,
        A_log=self.A_log,
        dt_bias=self.dt_bias,
        scale=self.head_k_dim**-0.5,
        initial_state=ssm_state,
        out=core_attn_out[:n_tokens].unsqueeze(1),
        ssm_state_indices=state_indices,
        use_qk_l2norm_in_kernel=True,
    )
    return core_attn_out

def forward(
self,
mixed_qkv: torch.Tensor,
Expand Down Expand Up @@ -217,6 +268,17 @@ def forward(
) # noqa: E501
compilation_config = forward_context.no_compile_layers
self_kv_cache = compilation_config[layer_name].kv_cache
virtual_engine = getattr(forward_context, "virtual_engine", None)
# vLLM <= 0.17 exposed per-virtual-engine KV caches via
# forward_context.virtual_engine. vLLM 0.19 no longer sets that field
# for this path and the layer cache is already the active cache tuple.
if (
virtual_engine is not None
and isinstance(self_kv_cache, (list, tuple))
and self_kv_cache
and isinstance(self_kv_cache[0], (list, tuple))
):
self_kv_cache = self_kv_cache[virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
Expand All @@ -231,6 +293,23 @@ def forward(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)

if (
self.enable_packed_recurrent_decode
and spec_sequence_masks is None
and attn_metadata.num_prefills == 0
and attn_metadata.num_decodes > 0
):
return self._forward_core_decode_non_spec(
mixed_qkv=mixed_qkv,
b=b,
a=a,
core_attn_out=core_attn_out,
attn_metadata=attn_metadata,
conv_state=conv_state,
ssm_state=ssm_state,
conv_weights=conv_weights,
)

if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
mixed_qkv_spec = mixed_qkv
Expand Down
Loading