Skip to content

Commit 7056450

Browse files
yaochengji and kyuyeunk authored
[Kernel] Remove KV masking by performing full bkv fetches in the first 2 steps (#1240)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 1916525 commit 7056450

File tree

3 files changed

+57
-23
lines changed

3 files changed

+57
-23
lines changed

tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,7 @@ def gen_random(shape, dtype):
9999
(0, 0),
100100
(0, 0),
101101
),
102-
constant_values=jnp.nan,
102+
constant_values=0,
103103
).reshape(
104104
-1,
105105
page_size,
@@ -122,7 +122,7 @@ def gen_random(shape, dtype):
122122
kv_cache,
123123
((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
124124
(0, 0)),
125-
constant_values=jnp.nan,
125+
constant_values=0,
126126
)
127127
page_indices = jnp.stack(page_indices_list, axis=0)
128128
page_indices = jnp.pad(

tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py

Lines changed: 54 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -456,7 +456,12 @@ def _async_copy(src, dst, sem, wait):
456456
else:
457457
cp.start()
458458

459-
def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
459+
def _fetch_bkv(seq_idx,
460+
bkv_idx,
461+
bkv_sem_idx,
462+
*,
463+
is_full_fetch=False,
464+
wait=False):
460465
sem = sems.at[0, bkv_sem_idx]
461466
vmem_ref = bkv_x2_ref.at[bkv_sem_idx]
462467

@@ -534,10 +539,29 @@ def _fetch_bkv_from_new_kv():
534539
wait,
535540
)
536541

542+
# NOTE(chengjiyao): This condition is true for the first two bkv fetches.
543+
# We need to ensure the bkv_x2_ref VMEM buffer is fully initialized to
544+
# avoid potential NaN values in regions not overwritten by actual data.
545+
# This is done by padding the remaining parts of the buffer with data
546+
# from the KV cache. This special handling is only strictly necessary
547+
# until both buffers in the double buffer (bkv_x2_ref) have been written
548+
# to at least once.
549+
@pl.when(is_full_fetch)
550+
def _make_sure_bkv_vmem_is_not_nan():
551+
effective_sz = offset + bkv_sz_frm_new
552+
remaining_sz = bkv_sz - effective_sz
553+
_async_copy(
554+
cache_hbm_ref.at[pl.ds(0, remaining_sz)],
555+
vmem_ref.at[pl.ds(effective_sz, remaining_sz)],
556+
sem,
557+
wait,
558+
)
559+
537560
return kv_len_start + offset, bkv_sz_frm_new
538561
else:
539562
offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
540-
dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
563+
sz = lax.select(is_full_fetch, bkv_sz, offset + bkv_sz_frm_new)
564+
dst = vmem_ref.at[pl.ds(0, sz)]
541565
_async_copy(
542566
src=dst,
543567
dst=dst,
@@ -664,11 +688,18 @@ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
664688
wait,
665689
)
666690

667-
def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
668-
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
691+
def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
692+
return _fetch_bkv(seq_idx,
693+
bkv_idx,
694+
bkv_sem_idx,
695+
is_full_fetch=is_full_fetch)
669696

670-
def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
671-
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)
697+
def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
698+
return _fetch_bkv(seq_idx,
699+
bkv_idx,
700+
bkv_sem_idx,
701+
is_full_fetch=is_full_fetch,
702+
wait=True)
672703

673704
def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
674705
return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
@@ -726,7 +757,7 @@ def strided_load(ref, start, step):
726757
vec = ref[start::step]
727758
return vec
728759

729-
def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
760+
def strided_load_bkv(bkv_sem_idx, start, step):
730761
assert start % kv_packing == 0
731762
assert step % kv_packing == 0
732763
start //= kv_packing
@@ -735,7 +766,6 @@ def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
735766
bkv_sz * step, actual_head_dim_x2))
736767

737768
kv = strided_load(kv_ref, start, step)
738-
kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
739769
bitwidth = 32 // kv_packing
740770
repack_ty = jnp.dtype(f"uint{bitwidth}")
741771
lst = []
@@ -809,31 +839,36 @@ def prefetch_next_bq():
809839
def compute_with_bkv(bkv_idx, _):
810840
# Create bitmask for KV.
811841
assert bkv_sz % kv_packing == 0
812-
actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
813-
bkv_shape = (bkv_sz, actual_head_dim_x2)
814-
bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
815-
0) < actual_bkv_sz
816842

817843
# Get next bkv ids.
818844
bkv_sem_idx = sem_ids_ref[1]
819-
next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
820-
seq_idx, bq_idx, bkv_idx, bkv_sem_idx)
845+
next_seq_idx, next_bq_idx_for_kv, next_bkv_idx, next_bkv_sem_idx = (
846+
get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx))
821847

822848
# Prefetch next bkv
823849
@pl.when(next_seq_idx < num_seqs)
824850
def prefetch_next_bkv():
825851
sem_ids_ref[1] = next_bkv_sem_idx
826-
start_fetch_bkv(next_seq_idx, next_bkv_idx,
827-
next_bkv_sem_idx)
852+
start_fetch_bkv(
853+
next_seq_idx,
854+
next_bkv_idx,
855+
next_bkv_sem_idx,
856+
is_full_fetch=next_seq_idx + next_bq_idx_for_kv +
857+
next_bkv_idx < 2,
858+
)
828859

829860
# Wait for cur bq if not ready yet
830861
@pl.when(bkv_idx == bkv_idx_start)
831862
def wait_cur_bq():
832863
wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)
833864

834865
# Wait for cur bkv
835-
offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
836-
bkv_sem_idx)
866+
offset, update_sz = wait_fetch_bkv(
867+
seq_idx,
868+
bkv_idx,
869+
bkv_sem_idx,
870+
is_full_fetch=seq_idx + bq_idx + bkv_idx < 2,
871+
)
837872

838873
# Start updating bkv to kv cache if applicable.
839874
# Only needed in first bq loop.
@@ -862,7 +897,6 @@ def update_cur_bkv_to_cache():
862897
bkv_sem_idx,
863898
kv_head_start,
864899
num_kv_heads,
865-
bkv_mask=bkv_mask,
866900
)
867901
assert len(bkv_lst) == kv_packing
868902
for i in range(kv_packing):
@@ -946,7 +980,7 @@ def update_cur_bkv_to_cache():
946980
@pl.when(seq_idx == 0)
947981
def prologue():
948982
start_fetch_bq(0, 0, 0)
949-
start_fetch_bkv(0, bkv_idx_start, 0)
983+
start_fetch_bkv(0, bkv_idx_start, 0, is_full_fetch=True)
950984

951985
@pl.when(seq_idx < decode_end)
952986
def process_decode():

tpu_inference/runner/kv_cache.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def create_kv_caches(
8282
ShardingAxisName.ATTN_HEAD))
8383

8484
def _allocate() -> jax.Array:
85-
return jnp.empty(
85+
return jnp.zeros(
8686
shape=cache_shape,
8787
dtype=cache_dtype,
8888
)

0 commit comments

Comments
 (0)