From 8fe971ff56de695969a6fe20f82497184e76f6ec Mon Sep 17 00:00:00 2001
From: rupengliu-meta
Date: Wed, 3 Dec 2025 16:09:03 -0800
Subject: [PATCH] Remove a branch with pl.when in fetching bkv

Signed-off-by: rupengliu-meta
---
 .../ragged_paged_attention/v3/kernel.py       | 23 ++++++++-----------
 .../ragged_paged_attention/v3/kernel_hd64.py  | 23 ++++++++-----------
 2 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py b/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
index f10e7962e..fd4b2ee4e 100644
--- a/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
+++ b/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
@@ -463,19 +463,16 @@ def loop_body(i, offset):
                 unroll=False,
             )
 
-            # Fetch kv directly from new kv.
-            @pl.when(bkv_sz_frm_new > 0)
-            def _fetch_bkv_from_new_kv():
-                new_kv_len_start = q_end - kv_left_frm_new
-                debug_print("[RPA debug] new_kv_len_start={}",
-                            new_kv_len_start)
-                debug_print("[RPA debug] offset_in_bkv={}", offset)
-                _async_copy(
-                    kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-                    vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-                    sem,
-                    wait,
-                )
+            size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
+            new_kv_len_start = q_end - kv_left_frm_new
+            debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+            debug_print("[RPA debug] offset_in_bkv={}", offset)
+            _async_copy(
+                kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+                vmem_ref.at[pl.ds(offset, size)],
+                sem,
+                wait,
+            )
 
             return kv_len_start + offset, bkv_sz_frm_new
         else:
diff --git a/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py b/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
index e94ee3579..bdcda357b 100644
--- a/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
+++ b/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
@@ -520,19 +520,16 @@ def loop_body(i, offset):
                 unroll=False,
             )
 
-            # Fetch kv directly from new kv.
-            @pl.when(bkv_sz_frm_new > 0)
-            def _fetch_bkv_from_new_kv():
-                new_kv_len_start = q_end - kv_left_frm_new
-                debug_print("[RPA debug] new_kv_len_start={}",
-                            new_kv_len_start)
-                debug_print("[RPA debug] offset_in_bkv={}", offset)
-                _async_copy(
-                    kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-                    vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-                    sem,
-                    wait,
-                )
+            size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
+            new_kv_len_start = q_end - kv_left_frm_new
+            debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+            debug_print("[RPA debug] offset_in_bkv={}", offset)
+            _async_copy(
+                kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+                vmem_ref.at[pl.ds(offset, size)],
+                sem,
+                wait,
+            )
 
             return kv_len_start + offset, bkv_sz_frm_new
         else: