Skip to content

Commit 06bff73

Browse files
Remove a branch with pl.when in fetching bkv (#1239)
1 parent e3b52bf commit 06bff73

File tree

2 files changed

+20
-26
lines changed

2 files changed

+20
-26
lines changed

tpu_inference/kernels/ragged_paged_attention/v3/kernel.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -463,19 +463,16 @@ def loop_body(i, offset):
463463
unroll=False,
464464
)
465465

466-
# Fetch kv directly from new kv.
467-
@pl.when(bkv_sz_frm_new > 0)
468-
def _fetch_bkv_from_new_kv():
469-
new_kv_len_start = q_end - kv_left_frm_new
470-
debug_print("[RPA debug] new_kv_len_start={}",
471-
new_kv_len_start)
472-
debug_print("[RPA debug] offset_in_bkv={}", offset)
473-
_async_copy(
474-
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
475-
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
476-
sem,
477-
wait,
478-
)
466+
size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
467+
new_kv_len_start = q_end - kv_left_frm_new
468+
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
469+
debug_print("[RPA debug] offset_in_bkv={}", offset)
470+
_async_copy(
471+
kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
472+
vmem_ref.at[pl.ds(offset, size)],
473+
sem,
474+
wait,
475+
)
479476

480477
return kv_len_start + offset, bkv_sz_frm_new
481478
else:

tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -520,19 +520,16 @@ def loop_body(i, offset):
520520
unroll=False,
521521
)
522522

523-
# Fetch kv directly from new kv.
524-
@pl.when(bkv_sz_frm_new > 0)
525-
def _fetch_bkv_from_new_kv():
526-
new_kv_len_start = q_end - kv_left_frm_new
527-
debug_print("[RPA debug] new_kv_len_start={}",
528-
new_kv_len_start)
529-
debug_print("[RPA debug] offset_in_bkv={}", offset)
530-
_async_copy(
531-
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
532-
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
533-
sem,
534-
wait,
535-
)
523+
size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
524+
new_kv_len_start = q_end - kv_left_frm_new
525+
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
526+
debug_print("[RPA debug] offset_in_bkv={}", offset)
527+
_async_copy(
528+
kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
529+
vmem_ref.at[pl.ds(offset, size)],
530+
sem,
531+
wait,
532+
)
536533

537534
return kv_len_start + offset, bkv_sz_frm_new
538535
else:

0 commit comments

Comments
 (0)