@@ -456,12 +456,7 @@ def _async_copy(src, dst, sem, wait):
456456 else :
457457 cp .start ()
458458
459- def _fetch_bkv (seq_idx ,
460- bkv_idx ,
461- bkv_sem_idx ,
462- * ,
463- is_full_fetch = False ,
464- wait = False ):
459+ def _fetch_bkv (seq_idx , bkv_idx , bkv_sem_idx , * , wait = False ):
465460 sem = sems .at [0 , bkv_sem_idx ]
466461 vmem_ref = bkv_x2_ref .at [bkv_sem_idx ]
467462
@@ -539,29 +534,10 @@ def _fetch_bkv_from_new_kv():
539534 wait ,
540535 )
541536
542- # NOTE(chengjiyao): This condition is true for the first two bkv fetches.
543- # We need to ensure the bkv_x2_ref VMEM buffer is fully initialized to
544- # avoid potential NaN values in regions not overwritten by actual data.
545- # This is done by padding the remaining parts of the buffer with data
546- # from the KV cache. This special handling is only strictly necessary
547- # until both buffers in the double buffer (bkv_x2_ref) have been written
548- # to at least once.
549- @pl .when (is_full_fetch )
550- def _make_sure_bkv_vmem_is_not_nan ():
551- effective_sz = offset + bkv_sz_frm_new
552- remaining_sz = bkv_sz - effective_sz
553- _async_copy (
554- cache_hbm_ref .at [pl .ds (0 , remaining_sz )],
555- vmem_ref .at [pl .ds (effective_sz , remaining_sz )],
556- sem ,
557- wait ,
558- )
559-
560537 return kv_len_start + offset , bkv_sz_frm_new
561538 else :
562539 offset = jnp .minimum (kv_left_frm_cache , page_size * bkv_p )
563- sz = lax .select (is_full_fetch , bkv_sz , offset + bkv_sz_frm_new )
564- dst = vmem_ref .at [pl .ds (0 , sz )]
540+ dst = vmem_ref .at [pl .ds (0 , offset + bkv_sz_frm_new )]
565541 _async_copy (
566542 src = dst ,
567543 dst = dst ,
@@ -688,18 +664,11 @@ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
688664 wait ,
689665 )
690666
691- def start_fetch_bkv (seq_idx , bkv_idx , bkv_sem_idx , * , is_full_fetch = False ):
692- return _fetch_bkv (seq_idx ,
693- bkv_idx ,
694- bkv_sem_idx ,
695- is_full_fetch = is_full_fetch )
667+ def start_fetch_bkv (seq_idx , bkv_idx , bkv_sem_idx ):
668+ return _fetch_bkv (seq_idx , bkv_idx , bkv_sem_idx )
696669
697- def wait_fetch_bkv (seq_idx , bkv_idx , bkv_sem_idx , * , is_full_fetch = False ):
698- return _fetch_bkv (seq_idx ,
699- bkv_idx ,
700- bkv_sem_idx ,
701- is_full_fetch = is_full_fetch ,
702- wait = True )
670+ def wait_fetch_bkv (seq_idx , bkv_idx , bkv_sem_idx ):
671+ return _fetch_bkv (seq_idx , bkv_idx , bkv_sem_idx , wait = True )
703672
704673 def start_fetch_bq (seq_idx , bq_idx , bq_sem_idx ):
705674 return _fetch_bq (seq_idx , bq_idx , bq_sem_idx )
@@ -757,7 +726,7 @@ def strided_load(ref, start, step):
757726 vec = ref [start ::step ]
758727 return vec
759728
760- def strided_load_bkv (bkv_sem_idx , start , step ):
729+ def strided_load_bkv (bkv_sem_idx , start , step , * , bkv_mask ):
761730 assert start % kv_packing == 0
762731 assert step % kv_packing == 0
763732 start //= kv_packing
@@ -766,6 +735,7 @@ def strided_load_bkv(bkv_sem_idx, start, step):
766735 bkv_sz * step , actual_head_dim_x2 ))
767736
768737 kv = strided_load (kv_ref , start , step )
738+ kv = lax .select (bkv_mask , kv , jnp .zeros_like (kv ))
769739 bitwidth = 32 // kv_packing
770740 repack_ty = jnp .dtype (f"uint{ bitwidth } " )
771741 lst = []
@@ -839,36 +809,31 @@ def prefetch_next_bq():
839809 def compute_with_bkv (bkv_idx , _ ):
840810 # Create bitmask for KV.
841811 assert bkv_sz % kv_packing == 0
812+ actual_bkv_sz = jnp .minimum (bkv_sz , kv_len - bkv_idx * bkv_sz )
813+ bkv_shape = (bkv_sz , actual_head_dim_x2 )
814+ bkv_mask = lax .broadcasted_iota (jnp .int32 , bkv_shape ,
815+ 0 ) < actual_bkv_sz
842816
843817 # Get next bkv ids.
844818 bkv_sem_idx = sem_ids_ref [1 ]
845- next_seq_idx , next_bq_idx_for_kv , next_bkv_idx , next_bkv_sem_idx = (
846- get_next_bkv_ids ( seq_idx , bq_idx , bkv_idx , bkv_sem_idx ) )
819+ next_seq_idx , _ , next_bkv_idx , next_bkv_sem_idx = get_next_bkv_ids (
820+ seq_idx , bq_idx , bkv_idx , bkv_sem_idx )
847821
848822 # Prefetch next bkv
849823 @pl .when (next_seq_idx < num_seqs )
850824 def prefetch_next_bkv ():
851825 sem_ids_ref [1 ] = next_bkv_sem_idx
852- start_fetch_bkv (
853- next_seq_idx ,
854- next_bkv_idx ,
855- next_bkv_sem_idx ,
856- is_full_fetch = next_seq_idx + next_bq_idx_for_kv +
857- next_bkv_idx < 2 ,
858- )
826+ start_fetch_bkv (next_seq_idx , next_bkv_idx ,
827+ next_bkv_sem_idx )
859828
860829 # Wait for cur bq if not ready yet
861830 @pl .when (bkv_idx == bkv_idx_start )
862831 def wait_cur_bq ():
863832 wait_fetch_bq (seq_idx , bq_idx , bq_sem_idx )
864833
865834 # Wait for cur bkv
866- offset , update_sz = wait_fetch_bkv (
867- seq_idx ,
868- bkv_idx ,
869- bkv_sem_idx ,
870- is_full_fetch = seq_idx + bq_idx + bkv_idx < 2 ,
871- )
835+ offset , update_sz = wait_fetch_bkv (seq_idx , bkv_idx ,
836+ bkv_sem_idx )
872837
873838 # Start updating bkv to kv cache if applicable.
874839 # Only needed in first bq loop.
@@ -897,6 +862,7 @@ def update_cur_bkv_to_cache():
897862 bkv_sem_idx ,
898863 kv_head_start ,
899864 num_kv_heads ,
865+ bkv_mask = bkv_mask ,
900866 )
901867 assert len (bkv_lst ) == kv_packing
902868 for i in range (kv_packing ):
@@ -980,7 +946,7 @@ def update_cur_bkv_to_cache():
980946 @pl .when (seq_idx == 0 )
981947 def prologue ():
982948 start_fetch_bq (0 , 0 , 0 )
983- start_fetch_bkv (0 , bkv_idx_start , 0 , is_full_fetch = True )
949+ start_fetch_bkv (0 , bkv_idx_start , 0 )
984950
985951 @pl .when (seq_idx < decode_end )
986952 def process_decode ():
0 commit comments