Commit bf54fc5

[RPA] Pipeline flash attention in default kernel (#1203)

Signed-off-by: Jacob Platin <jacobplatin@google.com>

Parent: 8f30493

1 file changed (+72, −19):

tpu_inference/kernels/ragged_paged_attention/v3/kernel.py

@@ -319,7 +319,7 @@ def debug_print(msg, *args):
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)
 
-  def flash_attention(
+  def flash_attention_step1_qk_softmax(
       q,  # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
       k,  # [bkv_sz, head_dim]
       v,  # [bkv_sz, head_dim]
@@ -335,7 +335,6 @@ def flash_attention(
     assert k.dtype == v.dtype
     head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
     head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-    head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
     def load_with_init(ref, init_val):
       return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
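
The `load_with_init` helper that both steps rely on is how the kernel initializes its running softmax state without a separate zeroing pass: on the first KV block (`bkv_idx == 0`) the scratch refs typically still hold values from an earlier grid step, so reads are masked with the init value instead. A toy illustration of the pattern, with the ref simulated by a plain array and `bkv_idx` passed explicitly for the sketch:

import jax.numpy as jnp

def load_with_init(ref_value, bkv_idx, init_val):
  # On the first KV block the scratch buffer holds stale data, so behave
  # as if it contained init_val; on later blocks, read it as-is.
  return jnp.where(bkv_idx == 0, jnp.full_like(ref_value, init_val),
                   ref_value)

stale = jnp.full((4, 1), 123.0)        # leftover contents of the scratch ref
print(load_with_init(stale, 0, 0.0))   # first block: all zeros
print(load_with_init(stale, 1, 0.0))   # later blocks: actual contents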
@@ -376,15 +375,32 @@ def load_with_init(ref, init_val):
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
 
-    pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
-    if v_scale is not None:
-      pv *= v_scale
-
     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
     l_prev = load_with_init(head_l_ref, 0.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
+
+    return p, exp_m_diff
+
+  def flash_attention_step2_pv(
+      q_shape_0,
+      v,  # [bkv_sz, head_dim]
+      p,  # from step1
+      exp_m_diff,  # from step1
+      *,
+      bkv_idx,
+      kv_head_idx,
+  ):
+    head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+    def load_with_init(ref, init_val):
+      return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
+                       ref[...])
+
+    pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
+    if v_scale is not None:
+      pv *= v_scale
     o_prev = load_with_init(head_acc_ref, 0.0)
     o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
     head_acc_ref[...] = o_curr
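
The split leaves the online-softmax math unchanged: step 1 produces the un-normalized probabilities `p` and the rescale factor `exp_m_diff`, while step 2 is the only part that touches `v`, so it can safely lag one iteration behind. A minimal sketch of the same recurrence in plain JAX outside Pallas (function names here are illustrative, not part of the kernel; `v_scale` and masking are omitted):

import jax.numpy as jnp

def step1_qk_softmax(q, k, m_prev, l_prev):
  # Scores plus running-max / row-sum updates; no V involved yet.
  s = jnp.einsum("nd,md->nm", q, k)
  m_curr = jnp.maximum(m_prev, s.max(axis=-1, keepdims=True))
  p = jnp.exp(s - m_curr)                   # un-normalized probabilities
  exp_m_diff = jnp.exp(m_prev - m_curr)     # rescales the old state
  l_curr = exp_m_diff * l_prev + p.sum(axis=-1, keepdims=True)
  return p, exp_m_diff, m_curr, l_curr

def step2_pv(p, v, exp_m_diff, acc_prev):
  # The only work that consumes V; it depends on step 1's outputs and
  # can therefore be deferred by one iteration.
  return exp_m_diff * acc_prev + jnp.einsum("nm,md->nd", p, v)

Iterating both steps over all KV blocks and dividing the final accumulator by `l` reproduces softmax(QKᵀ)V exactly; the refactor changes scheduling, not results.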
@@ -839,6 +855,11 @@ def update_cur_bkv_to_cache():
 
       # Flash attention with cur bkv and bq
      # NOTE: kv_packing is divided by 2 because k and v are packed together.
+      prev_bq_shape_0 = None
+      prev_kv_head_bv = None
+      prev_kv_head_idx = None
+      prev_kv_head_p = None
+      prev_kv_head_exp_m_diff = None
       heads_per_load = max(1, kv_packing // 2)
       for kv_head_start in range(0, actual_num_kv_heads,
                                  heads_per_load):
@@ -850,21 +871,53 @@ def update_cur_bkv_to_cache():
         )
         assert len(bkv_lst) == heads_per_load
         for i in range(heads_per_load):
-          kv_head_idx = kv_head_start + i
-          if kv_head_idx >= actual_num_kv_heads:
+          cur_kv_head_idx = kv_head_start + i
+          if cur_kv_head_idx >= actual_num_kv_heads:
             break
-          bq = load_bq(bq_sem_idx,
-                       kv_head_idx,
-                       actual_bq_sz=actual_bq_sz)
+
+          cur_kv_head_bq = load_bq(bq_sem_idx,
+                                   cur_kv_head_idx,
+                                   actual_bq_sz=actual_bq_sz)
           bk, bv = bkv_lst[i]
-          flash_attention(
-              bq,
-              bk,
-              bv,
-              bq_idx=bq_idx,
-              bkv_idx=bkv_idx,
-              kv_head_idx=kv_head_idx,
-          )
+          # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+          # and `flash_attention_step2_pv` to pipeline the computation.
+          # `step2_pv` for the previous KV head, which depends on the softmax
+          # output, is overlapped with `step1_qk_softmax` for the current KV
+          # head, reducing overall wait times.
+          cur_kv_head_p, cur_kv_head_exp_m_diff = (
+              flash_attention_step1_qk_softmax(
+                  cur_kv_head_bq,
+                  bk,
+                  bv,
+                  bq_idx=bq_idx,
+                  bkv_idx=bkv_idx,
+                  kv_head_idx=cur_kv_head_idx,
+              ))
+          if prev_bq_shape_0 is not None:
+            flash_attention_step2_pv(
+                prev_bq_shape_0,
+                prev_kv_head_bv,
+                prev_kv_head_p,
+                prev_kv_head_exp_m_diff,
+                bkv_idx=bkv_idx,
+                kv_head_idx=prev_kv_head_idx,
+            )
+          prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+          prev_kv_head_bv = bv
+          prev_kv_head_p = cur_kv_head_p
+          prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+          prev_kv_head_idx = cur_kv_head_idx
+
+      # Execute pv of last attention head.
+      assert prev_bq_shape_0 is not None
+      flash_attention_step2_pv(
+          prev_bq_shape_0,
+          prev_kv_head_bv,
+          prev_kv_head_p,
+          prev_kv_head_exp_m_diff,
+          bkv_idx=bkv_idx,
+          kv_head_idx=prev_kv_head_idx,
+      )
 
     lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
 
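The loop change in the last hunk is a classic two-stage software pipeline: each iteration issues step 1 for the current KV head and retires step 2 for the previous one, with an explicit drain after the loop for the final head. Stripped of the kernel details, the rotation looks like this (hypothetical `step1`/`step2` stand-ins for the kernel's two functions):

def pipelined_attention(heads, step1, step2):
  # step2 for head i runs in the iteration that issues step1 for head
  # i + 1, so the two can overlap instead of serializing.
  prev_state = None
  for head in heads:
    cur_state = step1(head)        # softmax stats for the current head
    if prev_state is not None:
      step2(prev_state)            # PV accumulation for the previous head
    prev_state = cur_state
  assert prev_state is not None    # drain: PV of the last head
  step2(prev_state)

Since the kernel's head loops are Python loops unrolled at Pallas trace time, the rotation only reorders the emitted operations, giving the compiler independent work to overlap.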