@@ -456,7 +456,12 @@ def _async_copy(src, dst, sem, wait):
     else:
       cp.start()

-  def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
+  def _fetch_bkv(seq_idx,
+                 bkv_idx,
+                 bkv_sem_idx,
+                 *,
+                 is_full_fetch=False,
+                 wait=False):
     sem = sems.at[0, bkv_sem_idx]
     vmem_ref = bkv_x2_ref.at[bkv_sem_idx]

@@ -534,10 +539,29 @@ def _fetch_bkv_from_new_kv():
           wait,
       )

+      # NOTE(chengjiyao): This condition is true for the first two bkv fetches.
+      # We need to ensure the bkv_x2_ref VMEM buffer is fully initialized to
+      # avoid potential NaN values in regions not overwritten by actual data.
+      # This is done by padding the remaining parts of the buffer with data
+      # from the KV cache. This special handling is only strictly necessary
+      # until both buffers in the double buffer (bkv_x2_ref) have been written
+      # to at least once.
+      @pl.when(is_full_fetch)
+      def _make_sure_bkv_vmem_is_not_nan():
+        effective_sz = offset + bkv_sz_frm_new
+        remaining_sz = bkv_sz - effective_sz
+        _async_copy(
+            cache_hbm_ref.at[pl.ds(0, remaining_sz)],
+            vmem_ref.at[pl.ds(effective_sz, remaining_sz)],
+            sem,
+            wait,
+        )
+
       return kv_len_start + offset, bkv_sz_frm_new
     else:
       offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
-      dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
+      sz = lax.select(is_full_fetch, bkv_sz, offset + bkv_sz_frm_new)
+      dst = vmem_ref.at[pl.ds(0, sz)]
       _async_copy(
           src=dst,
           dst=dst,
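
To make the new full-fetch path concrete, here is a minimal standalone sketch of the padding idea behind _make_sure_bkv_vmem_is_not_nan (plain NumPy rather than Pallas, with invented sizes): after the real copy fills the first effective_sz rows of a freshly allocated block, the tail is filled with rows from the cache so that no element is ever read uninitialized.

# Toy model of the full-fetch padding; not the kernel code itself.
import numpy as np

bkv_sz, head_dim = 8, 4
cache = np.arange(bkv_sz * head_dim, dtype=np.float32).reshape(bkv_sz, head_dim)
vmem = np.full((bkv_sz, head_dim), np.nan, dtype=np.float32)  # stand-in for uninitialized VMEM

effective_sz = 5                      # rows covered by the actual fetch
remaining_sz = bkv_sz - effective_sz  # rows that would otherwise stay uninitialized

vmem[:effective_sz] = cache[:effective_sz]  # the normal copy (simplified here)
vmem[effective_sz:] = cache[:remaining_sz]  # the padding copy added by this change
assert not np.isnan(vmem).any()             # every row now holds finite data
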
@@ -664,11 +688,18 @@ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
         wait,
     )

-  def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
-    return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
+  def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
+    return _fetch_bkv(seq_idx,
+                      bkv_idx,
+                      bkv_sem_idx,
+                      is_full_fetch=is_full_fetch)

-  def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
-    return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)
+  def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
+    return _fetch_bkv(seq_idx,
+                      bkv_idx,
+                      bkv_sem_idx,
+                      is_full_fetch=is_full_fetch,
+                      wait=True)

   def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
     return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
@@ -726,7 +757,7 @@ def strided_load(ref, start, step):
     vec = ref[start::step]
     return vec

-  def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
+  def strided_load_bkv(bkv_sem_idx, start, step):
     assert start % kv_packing == 0
     assert step % kv_packing == 0
     start //= kv_packing
@@ -735,7 +766,6 @@ def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
         bkv_sz * step, actual_head_dim_x2))

     kv = strided_load(kv_ref, start, step)
-    kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
     bitwidth = 32 // kv_packing
     repack_ty = jnp.dtype(f"uint{bitwidth}")
     lst = []
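
For context on the dropped bkv_mask argument, the snippet below reproduces the removed masking on a toy array (invented shapes, outside the kernel): rows at or beyond actual_bkv_sz used to be zeroed with lax.select. Once the double buffer is guaranteed to hold finite data everywhere, this defensive zeroing is no longer needed to keep NaNs out of the computation.

# Illustration of the masking removed by this change; not the kernel code.
import jax.numpy as jnp
from jax import lax

bkv_sz, actual_head_dim_x2 = 8, 4
kv = jnp.ones((bkv_sz, actual_head_dim_x2), dtype=jnp.float32)
actual_bkv_sz = 5  # valid rows in this block

bkv_shape = (bkv_sz, actual_head_dim_x2)
bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape, 0) < actual_bkv_sz
kv_masked = lax.select(bkv_mask, kv, jnp.zeros_like(kv))  # old behavior: zero the tail rows
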
@@ -809,31 +839,36 @@ def prefetch_next_bq():
     def compute_with_bkv(bkv_idx, _):
       # Create bitmask for KV.
       assert bkv_sz % kv_packing == 0
-      actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
-      bkv_shape = (bkv_sz, actual_head_dim_x2)
-      bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
-                                      0) < actual_bkv_sz

       # Get next bkv ids.
       bkv_sem_idx = sem_ids_ref[1]
-      next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
-          seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
+      next_seq_idx, next_bq_idx_for_kv, next_bkv_idx, next_bkv_sem_idx = (
+          get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx))

       # Prefetch next bkv
       @pl.when(next_seq_idx < num_seqs)
       def prefetch_next_bkv():
         sem_ids_ref[1] = next_bkv_sem_idx
-        start_fetch_bkv(next_seq_idx, next_bkv_idx,
-                        next_bkv_sem_idx)
+        start_fetch_bkv(
+            next_seq_idx,
+            next_bkv_idx,
+            next_bkv_sem_idx,
+            is_full_fetch=next_seq_idx + next_bq_idx_for_kv +
+            next_bkv_idx < 2,
+        )

       # Wait for cur bq if not ready yet
       @pl.when(bkv_idx == bkv_idx_start)
       def wait_cur_bq():
         wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)

       # Wait for cur bkv
-      offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
-                                         bkv_sem_idx)
+      offset, update_sz = wait_fetch_bkv(
+          seq_idx,
+          bkv_idx,
+          bkv_sem_idx,
+          is_full_fetch=seq_idx + bq_idx + bkv_idx < 2,
+      )

       # Start updating bkv to kv cache if applicable.
       # Only needed in first bq loop.
@@ -862,7 +897,6 @@ def update_cur_bkv_to_cache():
           bkv_sem_idx,
           kv_head_start,
           num_kv_heads,
-          bkv_mask=bkv_mask,
       )
       assert len(bkv_lst) == kv_packing
       for i in range(kv_packing):
@@ -946,7 +980,7 @@ def update_cur_bkv_to_cache():
   @pl.when(seq_idx == 0)
   def prologue():
     start_fetch_bq(0, 0, 0)
-    start_fetch_bkv(0, bkv_idx_start, 0)
+    start_fetch_bkv(0, bkv_idx_start, 0, is_full_fetch=True)

   @pl.when(seq_idx < decode_end)
   def process_decode():
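
As a sanity check of the is_full_fetch condition, the toy loop below (ordinary Python, assuming a single sequence, a single bq block, bkv_idx_start == 0, and a buffer index that simply alternates) shows that seq_idx + bq_idx + bkv_idx < 2 holds for exactly the first two bkv fetches, i.e. once for each half of the bkv_x2_ref double buffer, matching the NOTE above: after both halves have been written once, full fetches are no longer needed.

# Toy schedule walk-through; not the kernel code.
seq_idx, bq_idx = 0, 0
for bkv_idx in range(5):
  bkv_sem_idx = bkv_idx % 2  # assumed: double buffer alternates between its two halves
  is_full_fetch = seq_idx + bq_idx + bkv_idx < 2
  print(f"bkv_idx={bkv_idx} buffer={bkv_sem_idx} is_full_fetch={is_full_fetch}")
# -> is_full_fetch is True only for bkv_idx 0 (buffer 0) and bkv_idx 1 (buffer 1)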