File tree Expand file tree Collapse file tree 2 files changed +20
-26
lines changed
tpu_inference/kernels/ragged_paged_attention/v3 Expand file tree Collapse file tree 2 files changed +20
-26
lines changed Original file line number Diff line number Diff line change @@ -463,19 +463,16 @@ def loop_body(i, offset):
463463 unroll = False ,
464464 )
465465
466- # Fetch kv directly from new kv.
467- @pl .when (bkv_sz_frm_new > 0 )
468- def _fetch_bkv_from_new_kv ():
469- new_kv_len_start = q_end - kv_left_frm_new
470- debug_print ("[RPA debug] new_kv_len_start={}" ,
471- new_kv_len_start )
472- debug_print ("[RPA debug] offset_in_bkv={}" , offset )
473- _async_copy (
474- kv_hbm_ref .at [pl .ds (new_kv_len_start , bkv_sz_frm_new )],
475- vmem_ref .at [pl .ds (offset , bkv_sz_frm_new )],
476- sem ,
477- wait ,
478- )
466+ size = lax .select (bkv_sz_frm_new > 0 , bkv_sz_frm_new , 0 )
467+ new_kv_len_start = q_end - kv_left_frm_new
468+ debug_print ("[RPA debug] new_kv_len_start={}" , new_kv_len_start )
469+ debug_print ("[RPA debug] offset_in_bkv={}" , offset )
470+ _async_copy (
471+ kv_hbm_ref .at [pl .ds (new_kv_len_start , size )],
472+ vmem_ref .at [pl .ds (offset , size )],
473+ sem ,
474+ wait ,
475+ )
479476
480477 return kv_len_start + offset , bkv_sz_frm_new
481478 else :
Original file line number Diff line number Diff line change @@ -520,19 +520,16 @@ def loop_body(i, offset):
520520 unroll = False ,
521521 )
522522
523- # Fetch kv directly from new kv.
524- @pl .when (bkv_sz_frm_new > 0 )
525- def _fetch_bkv_from_new_kv ():
526- new_kv_len_start = q_end - kv_left_frm_new
527- debug_print ("[RPA debug] new_kv_len_start={}" ,
528- new_kv_len_start )
529- debug_print ("[RPA debug] offset_in_bkv={}" , offset )
530- _async_copy (
531- kv_hbm_ref .at [pl .ds (new_kv_len_start , bkv_sz_frm_new )],
532- vmem_ref .at [pl .ds (offset , bkv_sz_frm_new )],
533- sem ,
534- wait ,
535- )
523+ size = lax .select (bkv_sz_frm_new > 0 , bkv_sz_frm_new , 0 )
524+ new_kv_len_start = q_end - kv_left_frm_new
525+ debug_print ("[RPA debug] new_kv_len_start={}" , new_kv_len_start )
526+ debug_print ("[RPA debug] offset_in_bkv={}" , offset )
527+ _async_copy (
528+ kv_hbm_ref .at [pl .ds (new_kv_len_start , size )],
529+ vmem_ref .at [pl .ds (offset , size )],
530+ sem ,
531+ wait ,
532+ )
536533
537534 return kv_len_start + offset , bkv_sz_frm_new
538535 else :
You can’t perform that action at this time.
0 commit comments