@@ -1645,6 +1645,33 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
16451645 # Convert from (L, N, P) to (N, P, L)
16461646 self .W_UK_T = W_UK .permute (1 , 2 , 0 )
16471647
1648+ def _concat_k_nope_k_pe (
1649+ self , k_nope : torch .Tensor , k_pe : torch .Tensor
1650+ ) -> torch .Tensor :
1651+ """
1652+ Efficiently concatenate k_nope and k_pe tensors along the last dimension.
1653+
1654+ This function avoids the performance penalty of torch.cat with expanded
1655+ non-contiguous tensors by pre-allocating the output and using direct copies.
1656+
1657+ Args:
1658+ k_nope: Tensor of shape [..., nope_dim]
1659+ k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim]
1660+ or [..., pe_dim]
1661+
1662+ Returns:
1663+ Tensor of shape [..., nope_dim + pe_dim]
1664+ """
1665+ k = torch .empty (
1666+ (* k_nope .shape [:- 1 ], k_nope .shape [- 1 ] + k_pe .shape [- 1 ]),
1667+ dtype = k_nope .dtype ,
1668+ device = k_nope .device ,
1669+ )
1670+ # Direct copies with efficient broadcasting
1671+ k [..., : k_nope .shape [- 1 ]] = k_nope
1672+ k [..., k_nope .shape [- 1 ] :] = k_pe
1673+ return k
1674+
16481675 def _compute_prefill_context (
16491676 self ,
16501677 q : torch .Tensor ,
@@ -1681,7 +1708,7 @@ def _compute_prefill_context(
16811708 )
16821709 k_nope , v = kv_nope .split ([self .qk_nope_head_dim , self .v_head_dim ], dim = - 1 )
16831710
1684- k = torch . cat (( k_nope , k_pe . expand (( * k_nope . shape [: - 1 ], - 1 ))), dim = - 1 )
1711+ k = self . _concat_k_nope_k_pe ( k_nope , k_pe )
16851712
16861713 attn_output , attn_softmax_lse = self ._run_prefill_context_chunk (
16871714 prefill = prefill_metadata ,
@@ -1785,7 +1812,7 @@ def _context_parallel_compute_prefill_context(
17851812 - 1 , self .num_heads , self .qk_nope_head_dim + self .v_head_dim
17861813 )
17871814 k_nope , v = kv_nope .split ([self .qk_nope_head_dim , self .v_head_dim ], dim = - 1 )
1788- k = torch . cat (( k_nope , k_pe . expand (( * k_nope . shape [: - 1 ], - 1 ))), dim = - 1 )
1815+ k = self . _concat_k_nope_k_pe ( k_nope , k_pe )
17891816
17901817 attn_output , attn_softmax_lse = self ._run_prefill_context_chunk (
17911818 prefill = prefill_metadata ,
@@ -1834,7 +1861,7 @@ def _forward_prefill(
18341861 )
18351862 k_nope , v = kv_nope .split ([self .qk_nope_head_dim , self .v_head_dim ], dim = - 1 )
18361863
1837- k = torch . cat (( k_nope , k_pe . expand (( * k_nope . shape [: - 1 ], - 1 ))), dim = - 1 )
1864+ k = self . _concat_k_nope_k_pe ( k_nope , k_pe )
18381865
18391866 output_prefill = self ._run_prefill_new_tokens (
18401867 prefill = attn_metadata .prefill ,
0 commit comments