diff --git a/atom/model_ops/mamba_ops/causal_conv1d.py b/atom/model_ops/mamba_ops/causal_conv1d.py
index 501d9db92..c99dea520 100644
--- a/atom/model_ops/mamba_ops/causal_conv1d.py
+++ b/atom/model_ops/mamba_ops/causal_conv1d.py
@@ -1205,6 +1205,7 @@ def causal_conv1d_update(
     block_idx_last_scheduled_token: torch.Tensor | None = None,
     initial_state_idx: torch.Tensor | None = None,
     validate_data=False,
+    return_packed_qkv: bool = False,
 ):
     """
     x: Input tensor which can take the following shapes:
@@ -1289,9 +1290,20 @@
     # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
     num_tokens = x.shape[0]
 
-    query = torch.empty([num_tokens, k_dim_size, 1], dtype=x.dtype, device=x.device)
-    key = torch.empty([num_tokens, k_dim_size, 1], dtype=x.dtype, device=x.device)
-    value = torch.empty([num_tokens, v_dim_size, 1], dtype=x.dtype, device=x.device)
+    packed_qkv = None
+    if return_packed_qkv:
+        packed_qkv = torch.empty(
+            [num_tokens, k_dim_size * 2 + v_dim_size],
+            dtype=x.dtype,
+            device=x.device,
+        )
+        query = packed_qkv[:, :k_dim_size].unsqueeze(-1)
+        key = packed_qkv[:, k_dim_size : k_dim_size * 2].unsqueeze(-1)
+        value = packed_qkv[:, k_dim_size * 2 :].unsqueeze(-1)
+    else:
+        query = torch.empty([num_tokens, k_dim_size, 1], dtype=x.dtype, device=x.device)
+        key = torch.empty([num_tokens, k_dim_size, 1], dtype=x.dtype, device=x.device)
+        value = torch.empty([num_tokens, v_dim_size, 1], dtype=x.dtype, device=x.device)
 
     stride_q_seq, stride_q_dim, stride_q_token = query.stride()
     stride_k_seq, stride_k_dim, stride_k_token = key.stride()
@@ -1383,6 +1395,9 @@ def grid(META):
     )
     if unsqueeze:
         out = out.squeeze(-1)
+    if return_packed_qkv:
+        return packed_qkv
+
     query = query.squeeze(-1)
     key = key.squeeze(-1)
    value = value.squeeze(-1)
diff --git a/atom/plugin/vllm/attention_backend/attention_gdn.py b/atom/plugin/vllm/attention_backend/attention_gdn.py
index 32a29f1d3..d81b5bca2 100644
--- a/atom/plugin/vllm/attention_backend/attention_gdn.py
+++ b/atom/plugin/vllm/attention_backend/attention_gdn.py
@@ -12,12 +12,14 @@
 )
 
 # from atom.model_ops.attentions.gdn_attn import GDNAttentionMetadata
+from vllm import envs as vllm_envs
 from vllm.forward_context import get_forward_context
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.v1.attention.backend import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 from vllm.model_executor.layers.fla.ops import (
     chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
+    fused_recurrent_gated_delta_rule_packed_decode,
 )
 from atom.model_ops.fla_ops.fused_sigmoid_gating import (
     fused_sigmoid_gating_delta_rule_update,
@@ -162,6 +164,9 @@ def __init__(
         self.head_k_dim = head_k_dim
         self.head_v_dim = head_v_dim
         self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
+        self.enable_packed_recurrent_decode = (
+            vllm_envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
+        )
 
     def rearrange_mixed_qkv(self, mixed_qkv):
         if mixed_qkv is None:
@@ -182,6 +187,52 @@
         value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
         return query.contiguous(), key.contiguous(), value.contiguous()
 
+    def _forward_core_decode_non_spec(
+        self,
+        mixed_qkv: torch.Tensor,
+        b: torch.Tensor,
+        a: torch.Tensor,
+        core_attn_out: torch.Tensor,
+        attn_metadata: GDNAttentionMetadata,
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor,
+        conv_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
+        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        mixed_qkv = mixed_qkv[:num_actual_tokens]
+        b = b[:num_actual_tokens]
+        a = a[:num_actual_tokens]
+
+        mixed_qkv_non_spec = causal_conv1d_update(
+            mixed_qkv,
+            conv_state,
+            conv_weights,
+            self.num_k_heads * self.head_k_dim // self.tp_size,
+            self.num_v_heads * self.head_v_dim // self.tp_size,
+            self.conv1d.bias,
+            self.activation,
+            conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
+            validate_data=False,
+            return_packed_qkv=True,
+        )
+
+        out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
+        fused_recurrent_gated_delta_rule_packed_decode(
+            mixed_qkv=mixed_qkv_non_spec,
+            a=a,
+            b=b,
+            A_log=self.A_log,
+            dt_bias=self.dt_bias,
+            scale=self.head_k_dim**-0.5,
+            initial_state=ssm_state,
+            out=out_buf,
+            ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
+            use_qk_l2norm_in_kernel=True,
+        )
+        return core_attn_out
+
     def forward(
         self,
         mixed_qkv: torch.Tensor,
@@ -217,6 +268,17 @@
         )  # noqa: E501
         compilation_config = forward_context.no_compile_layers
         self_kv_cache = compilation_config[layer_name].kv_cache
+        virtual_engine = getattr(forward_context, "virtual_engine", None)
+        # vLLM <= 0.17 exposed per-virtual-engine KV caches via
+        # forward_context.virtual_engine. vLLM 0.19 no longer sets that field
+        # for this path and the layer cache is already the active cache tuple.
+        if (
+            virtual_engine is not None
+            and isinstance(self_kv_cache, (list, tuple))
+            and self_kv_cache
+            and isinstance(self_kv_cache[0], (list, tuple))
+        ):
+            self_kv_cache = self_kv_cache[virtual_engine]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
         num_actual_tokens = attn_metadata.num_actual_tokens
@@ -231,6 +293,23 @@
             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
         )
 
+        if (
+            self.enable_packed_recurrent_decode
+            and spec_sequence_masks is None
+            and attn_metadata.num_prefills == 0
+            and attn_metadata.num_decodes > 0
+        ):
+            return self._forward_core_decode_non_spec(
+                mixed_qkv=mixed_qkv,
+                b=b,
+                a=a,
+                core_attn_out=core_attn_out,
+                attn_metadata=attn_metadata,
+                conv_state=conv_state,
+                ssm_state=ssm_state,
+                conv_weights=conv_weights,
+            )
+
         if spec_sequence_masks is not None:
             if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                 mixed_qkv_spec = mixed_qkv
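
A minimal sketch (not part of the patch) of the aliasing trick the return_packed_qkv path relies on: query/key/value become strided views into one contiguous [num_tokens, 2*k_dim + v_dim] buffer, so a kernel writing through the per-view strides fills the packed tensor in place and the decode kernel can consume it without a re-pack or concat. Names and shapes below are illustrative only.

import torch

num_tokens, k_dim, v_dim = 4, 8, 16
packed = torch.empty(num_tokens, 2 * k_dim + v_dim)

# Views share storage with `packed`; unsqueeze(-1) restores the
# [tokens, dim, 1] shape whose strides the conv kernel expects.
q = packed[:, :k_dim].unsqueeze(-1)
k = packed[:, k_dim : 2 * k_dim].unsqueeze(-1)
v = packed[:, 2 * k_dim :].unsqueeze(-1)

# Writing through the views lands directly in the packed buffer.
q.fill_(1.0); k.fill_(2.0); v.fill_(3.0)
assert q.data_ptr() == packed.data_ptr()
assert torch.equal(
    packed[0],
    torch.cat([torch.ones(k_dim), torch.full((k_dim,), 2.0), torch.full((v_dim,), 3.0)]),
)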