Commit 4b74209

[RPA] Revert previous changes due to numeric issue (#1242)
1 parent 61b9cbe · commit 4b74209

4 files changed (+24 −58 lines)

tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py

Lines changed: 2 additions & 2 deletions
@@ -99,7 +99,7 @@ def gen_random(shape, dtype):
             (0, 0),
             (0, 0),
         ),
-        constant_values=0,
+        constant_values=jnp.nan,
     ).reshape(
         -1,
         page_size,
@@ -122,7 +122,7 @@ def gen_random(shape, dtype):
         kv_cache,
         ((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
          (0, 0)),
-        constant_values=0,
+        constant_values=jnp.nan,
     )
     page_indices = jnp.stack(page_indices_list, axis=0)
     page_indices = jnp.pad(
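
A minimal sketch (separate from the test above, with made-up shapes) of why the test now pads with NaN instead of zeros: zero padding silently matches what an unmasked read would produce, while NaN padding poisons the output of any read the kernel should have masked out.

import jax.numpy as jnp

kv = jnp.ones((3, 4), dtype=jnp.float32)  # stand-in for valid KV rows
# Pad to a full page with NaN; a correct kernel never reads these rows.
padded = jnp.pad(kv, ((0, 2), (0, 0)), constant_values=jnp.nan)
# Any NaN in the kernel's output now signals an out-of-range read that
# zero padding would have hidden.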

tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py

Lines changed: 20 additions & 54 deletions
@@ -456,12 +456,7 @@ def _async_copy(src, dst, sem, wait):
     else:
       cp.start()

-  def _fetch_bkv(seq_idx,
-                 bkv_idx,
-                 bkv_sem_idx,
-                 *,
-                 is_full_fetch=False,
-                 wait=False):
+  def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False):
     sem = sems.at[0, bkv_sem_idx]
     vmem_ref = bkv_x2_ref.at[bkv_sem_idx]

@@ -539,29 +534,10 @@ def _fetch_bkv_from_new_kv():
           wait,
       )

-      # NOTE(chengjiyao): This condition is true for the first two bkv fetches.
-      # We need to ensure the bkv_x2_ref VMEM buffer is fully initialized to
-      # avoid potential NaN values in regions not overwritten by actual data.
-      # This is done by padding the remaining parts of the buffer with data
-      # from the KV cache. This special handling is only strictly necessary
-      # until both buffers in the double buffer (bkv_x2_ref) have been written
-      # to at least once.
-      @pl.when(is_full_fetch)
-      def _make_sure_bkv_vmem_is_not_nan():
-        effective_sz = offset + bkv_sz_frm_new
-        remaining_sz = bkv_sz - effective_sz
-        _async_copy(
-            cache_hbm_ref.at[pl.ds(0, remaining_sz)],
-            vmem_ref.at[pl.ds(effective_sz, remaining_sz)],
-            sem,
-            wait,
-        )
-
       return kv_len_start + offset, bkv_sz_frm_new
     else:
       offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
-      sz = lax.select(is_full_fetch, bkv_sz, offset + bkv_sz_frm_new)
-      dst = vmem_ref.at[pl.ds(0, sz)]
+      dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
       _async_copy(
           src=dst,
           dst=dst,
@@ -688,18 +664,11 @@ def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
         wait,
     )

-  def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
-    return _fetch_bkv(seq_idx,
-                      bkv_idx,
-                      bkv_sem_idx,
-                      is_full_fetch=is_full_fetch)
+  def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+    return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)

-  def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, is_full_fetch=False):
-    return _fetch_bkv(seq_idx,
-                      bkv_idx,
-                      bkv_sem_idx,
-                      is_full_fetch=is_full_fetch,
-                      wait=True)
+  def wait_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
+    return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, wait=True)

   def start_fetch_bq(seq_idx, bq_idx, bq_sem_idx):
     return _fetch_bq(seq_idx, bq_idx, bq_sem_idx)
@@ -757,7 +726,7 @@ def strided_load(ref, start, step):
     vec = ref[start::step]
     return vec

-  def strided_load_bkv(bkv_sem_idx, start, step):
+  def strided_load_bkv(bkv_sem_idx, start, step, *, bkv_mask):
     assert start % kv_packing == 0
     assert step % kv_packing == 0
     start //= kv_packing
@@ -766,6 +735,7 @@ def strided_load_bkv(bkv_sem_idx, start, step):
                                      bkv_sz * step, actual_head_dim_x2))

     kv = strided_load(kv_ref, start, step)
+    kv = lax.select(bkv_mask, kv, jnp.zeros_like(kv))
     bitwidth = 32 // kv_packing
     repack_ty = jnp.dtype(f"uint{bitwidth}")
     lst = []
@@ -839,36 +809,31 @@ def prefetch_next_bq():
     def compute_with_bkv(bkv_idx, _):
       # Create bitmask for KV.
       assert bkv_sz % kv_packing == 0
+      actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)
+      bkv_shape = (bkv_sz, actual_head_dim_x2)
+      bkv_mask = lax.broadcasted_iota(jnp.int32, bkv_shape,
+                                      0) < actual_bkv_sz

       # Get next bkv ids.
       bkv_sem_idx = sem_ids_ref[1]
-      next_seq_idx, next_bq_idx_for_kv, next_bkv_idx, next_bkv_sem_idx = (
-          get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx))
+      next_seq_idx, _, next_bkv_idx, next_bkv_sem_idx = get_next_bkv_ids(
+          seq_idx, bq_idx, bkv_idx, bkv_sem_idx)

       # Prefetch next bkv
       @pl.when(next_seq_idx < num_seqs)
       def prefetch_next_bkv():
         sem_ids_ref[1] = next_bkv_sem_idx
-        start_fetch_bkv(
-            next_seq_idx,
-            next_bkv_idx,
-            next_bkv_sem_idx,
-            is_full_fetch=next_seq_idx + next_bq_idx_for_kv +
-            next_bkv_idx < 2,
-        )
+        start_fetch_bkv(next_seq_idx, next_bkv_idx,
+                        next_bkv_sem_idx)

       # Wait for cur bq if not ready yet
       @pl.when(bkv_idx == bkv_idx_start)
       def wait_cur_bq():
         wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)

       # Wait for cur bkv
-      offset, update_sz = wait_fetch_bkv(
-          seq_idx,
-          bkv_idx,
-          bkv_sem_idx,
-          is_full_fetch=seq_idx + bq_idx + bkv_idx < 2,
-      )
+      offset, update_sz = wait_fetch_bkv(seq_idx, bkv_idx,
+                                         bkv_sem_idx)

       # Start updating bkv to kv cache if applicable.
       # Only needed in first bq loop.
@@ -897,6 +862,7 @@ def update_cur_bkv_to_cache():
           bkv_sem_idx,
           kv_head_start,
           num_kv_heads,
+          bkv_mask=bkv_mask,
       )
       assert len(bkv_lst) == kv_packing
       for i in range(kv_packing):
@@ -980,7 +946,7 @@ def update_cur_bkv_to_cache():
   @pl.when(seq_idx == 0)
   def prologue():
     start_fetch_bq(0, 0, 0)
-    start_fetch_bkv(0, bkv_idx_start, 0, is_full_fetch=True)
+    start_fetch_bkv(0, bkv_idx_start, 0)

   @pl.when(seq_idx < decode_end)
   def process_decode():
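
The replacement for the reverted full-fetch path is the iota-based mask built in compute_with_bkv and applied in strided_load_bkv. A minimal standalone sketch of that idiom (names and sizes below are illustrative, not the kernel's): rows at or beyond the valid KV length are zeroed, so leftover buffer contents can never reach the attention computation.

import jax.numpy as jnp
from jax import lax

bkv_sz, head_dim_x2 = 8, 4
kv_len, bkv_idx = 13, 1                        # 2nd block: only 5 valid rows
kv = jnp.full((bkv_sz, head_dim_x2), jnp.nan)  # pretend-uninitialized buffer
kv = kv.at[:5].set(1.0)                        # rows actually fetched

actual_bkv_sz = jnp.minimum(bkv_sz, kv_len - bkv_idx * bkv_sz)   # == 5
mask = lax.broadcasted_iota(jnp.int32, kv.shape, 0) < actual_bkv_sz
kv = lax.select(mask, kv, jnp.zeros_like(kv))  # garbage rows -> 0.0
assert not jnp.isnan(kv).any()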

tpu_inference/layers/common/attention_interface.py

Lines changed: 1 addition & 1 deletion
@@ -312,7 +312,7 @@ def sharded_ragged_paged_attention(
         func = ragged_paged_attention
         if use_hd64:
             func = functools.partial(ragged_paged_attention_hd64,
-                                     strict_sliding_window=False)
+                                     strict_sliding_window=True)
         else:
             func = ragged_paged_attention

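
The flag is pre-bound with functools.partial, so later call sites of func never see strict_sliding_window. A small sketch of that pattern (the stub signature below is assumed for illustration, not the real kernel's):

import functools

def ragged_paged_attention_hd64(*args, strict_sliding_window=False, **kwargs):
    ...  # hypothetical stand-in for the hd64 kernel entry point

func = functools.partial(ragged_paged_attention_hd64,
                         strict_sliding_window=True)
# func(*args, **kwargs) now always runs with strict sliding-window semantics.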

tpu_inference/runner/kv_cache.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def create_kv_caches(
                                ShardingAxisName.ATTN_HEAD))

     def _allocate() -> jax.Array:
-        return jnp.zeros(
+        return jnp.empty(
             shape=cache_shape,
             dtype=cache_dtype,
         )
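
For context, a minimal sketch of the reverted allocation (the cache_shape and cache_dtype values below are made up): jnp.empty signals that the initial cache contents are never meant to be read, which is safe only because the kernel above masks out-of-range KV rows before they can influence the output.

import jax
import jax.numpy as jnp

cache_shape = (1024, 16, 2, 64)   # assumed (pages, page_size, K/V, head_dim)
cache_dtype = jnp.bfloat16

def _allocate() -> jax.Array:
    # No zero-fill requested; every row is either written or masked out.
    return jnp.empty(shape=cache_shape, dtype=cache_dtype)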
