Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,19 +463,16 @@ def loop_body(i, offset):
unroll=False,
)

# Fetch kv directly from new kv.
@pl.when(bkv_sz_frm_new > 0)
def _fetch_bkv_from_new_kv():
new_kv_len_start = q_end - kv_left_frm_new
debug_print("[RPA debug] new_kv_len_start={}",
new_kv_len_start)
debug_print("[RPA debug] offset_in_bkv={}", offset)
_async_copy(
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
sem,
wait,
)
size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so if size==0, then the dma will be a no-op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, exactly

new_kv_len_start = q_end - kv_left_frm_new
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
debug_print("[RPA debug] offset_in_bkv={}", offset)
_async_copy(
kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
vmem_ref.at[pl.ds(offset, size)],
sem,
wait,
)

return kv_len_start + offset, bkv_sz_frm_new
else:
Expand Down
23 changes: 10 additions & 13 deletions tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,19 +520,16 @@ def loop_body(i, offset):
unroll=False,
)

# Fetch kv directly from new kv.
@pl.when(bkv_sz_frm_new > 0)
def _fetch_bkv_from_new_kv():
new_kv_len_start = q_end - kv_left_frm_new
debug_print("[RPA debug] new_kv_len_start={}",
new_kv_len_start)
debug_print("[RPA debug] offset_in_bkv={}", offset)
_async_copy(
kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
sem,
wait,
)
size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
new_kv_len_start = q_end - kv_left_frm_new
debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
debug_print("[RPA debug] offset_in_bkv={}", offset)
_async_copy(
kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
vmem_ref.at[pl.ds(offset, size)],
sem,
wait,
)

return kv_len_start + offset, bkv_sz_frm_new
else:
Expand Down