
Commit 14f2ab3

[perf] Use direct copy (broadcast) instead of cat for k_nope/k_pe in MLA prefill
Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent ea6e583 · commit 14f2ab3

File tree

1 file changed: +30 -3 lines changed

vllm/v1/attention/backends/mla/common.py

Lines changed: 30 additions & 3 deletions
```diff
@@ -1645,6 +1645,33 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         # Convert from (L, N, P) to (N, P, L)
         self.W_UK_T = W_UK.permute(1, 2, 0)
 
+    def _concat_k_nope_k_pe(
+        self, k_nope: torch.Tensor, k_pe: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Efficiently concatenate k_nope and k_pe tensors along the last dimension.
+
+        This function avoids the performance penalty of torch.cat with expanded
+        non-contiguous tensors by pre-allocating the output and using direct copies.
+
+        Args:
+            k_nope: Tensor of shape [..., nope_dim]
+            k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim]
+                or [..., pe_dim]
+
+        Returns:
+            Tensor of shape [..., nope_dim + pe_dim]
+        """
+        k = torch.empty(
+            (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
+            dtype=k_nope.dtype,
+            device=k_nope.device,
+        )
+        # Direct copies with efficient broadcasting
+        k[..., : k_nope.shape[-1]] = k_nope
+        k[..., k_nope.shape[-1] :] = k_pe
+        return k
+
     def _compute_prefill_context(
         self,
         q: torch.Tensor,
```
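
For readers skimming the diff, the sketch below is a standalone, hypothetical rendition of the same trick (the name `concat_via_copy` and the shapes are illustrative, not from the commit). Assigning `k_pe` into the tail slice of a pre-allocated buffer broadcasts it across the head dimension, so no materialized `expand` view has to flow through `torch.cat`, and the result is bit-identical to the old path:

```python
import torch


def concat_via_copy(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    # Pre-allocate the output, then write each piece into its slice;
    # the second assignment broadcasts k_pe across the head dimension.
    k = torch.empty(
        (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
        dtype=k_nope.dtype,
        device=k_nope.device,
    )
    k[..., : k_nope.shape[-1]] = k_nope
    k[..., k_nope.shape[-1] :] = k_pe
    return k


k_nope = torch.randn(4096, 16, 128)  # [tokens, heads, nope_dim] (made-up sizes)
k_pe = torch.randn(4096, 1, 64)      # [tokens, 1, pe_dim], shared across heads
reference = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
assert torch.equal(concat_via_copy(k_nope, k_pe), reference)
```

The old path forced `torch.cat` to read a non-contiguous expanded view of `k_pe`; the copy path instead issues two plain slice writes into one contiguous buffer.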
```diff
@@ -1681,7 +1708,7 @@ def _compute_prefill_context(
         )
         k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)
 
         attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
             prefill=prefill_metadata,
@@ -1785,7 +1812,7 @@ def _context_parallel_compute_prefill_context(
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
         )
         k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)
 
         attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
             prefill=prefill_metadata,
@@ -1834,7 +1861,7 @@ def _forward_prefill(
         )
         k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+        k = self._concat_k_nope_k_pe(k_nope, k_pe)
 
         output_prefill = self._run_prefill_new_tokens(
             prefill=attn_metadata.prefill,
```
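
To check whether the swap pays off on a given GPU, a rough microbenchmark along these lines can compare the two paths. This is a sketch, not a measurement from the commit: it requires a CUDA device, and the token/head/dim sizes are assumptions loosely modeled on MLA prefill shapes, so any speedup will vary by hardware and shape.

```python
import torch


def bench(fn, iters: int = 100) -> float:
    """Average milliseconds per call, measured with CUDA events."""
    fn()  # warmup
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


tokens, heads, nope_dim, pe_dim = 8192, 16, 128, 64  # assumed sizes
k_nope = torch.randn(tokens, heads, nope_dim, device="cuda", dtype=torch.bfloat16)
k_pe = torch.randn(tokens, 1, pe_dim, device="cuda", dtype=torch.bfloat16)


def cat_path() -> torch.Tensor:
    # torch.cat must read the expanded, non-contiguous view of k_pe
    return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)


def copy_path() -> torch.Tensor:
    # Pre-allocate, then broadcast k_pe directly into the tail slice
    k = torch.empty(
        tokens, heads, nope_dim + pe_dim, device="cuda", dtype=torch.bfloat16
    )
    k[..., :nope_dim] = k_nope
    k[..., nope_dim:] = k_pe
    return k


print(f"torch.cat path:   {bench(cat_path):.3f} ms")
print(f"direct-copy path: {bench(copy_path):.3f} ms")
```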
