104 changes: 73 additions & 31 deletions tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
@@ -230,8 +230,9 @@ def _ragged_paged_attention_kernel(
# TODO(jevinjiang): merge these into one so we can save SMEM.
distribution_ref, # [3] (decode_end, prefill_end, mixed_end)
sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
bo_ids_ref, # [6] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
bq_fetch_ids_ref, # [2] (bq_sem_0_sz, bq_sem_1_sz)
# Input
q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
kv_hbm_ref, # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
@@ -562,50 +563,89 @@ def loop_body(i, states):
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
sem = sems.at[1, bq_sem_idx]
vmem_ref = bq_x2_ref.at[bq_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bq_idx={}", bq_idx)
debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bq_fetch_ids_ref[bq_sem_idx] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bq_fetch_ids_ref[bq_sem_idx]

debug_print("[RPA debug] sz (from scratch)={}", sz)
dst = vmem_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
sem = sems.at[2, bo_sem_idx]
vmem_ref = bo_x2_ref.at[bo_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_send_bo-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bo_idx={}", bo_idx)
debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bo_ids_ref[bo_sem_idx + 4] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bo_ids_ref[bo_sem_idx + 4]

debug_print("[RPA debug] sz (from scratch)={}", sz)

dst = o_hbm_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
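
Both helpers above follow the same pattern: the ragged block size sz is computed only when the copy is started, cached in a per-semaphore-slot scratch entry (bq_fetch_ids_ref for the q fetch, bo_ids_ref[4:] for the output send), and read back when the matching wait runs, so the wait path no longer reloads cu_q_lens_ref. The toy, non-Pallas sketch below illustrates only that bookkeeping; BQ_SZ, start_fetch_bq, wait_fetch_bq, and bq_fetch_sizes are illustrative stand-ins for the kernel's own names, and the async copy itself is elided.

import jax.numpy as jnp

BQ_SZ = 8  # stand-in for bq_sz: tokens per q block

def start_fetch_bq(cu_q_lens, seq_idx, bq_idx, bq_sem_idx, bq_fetch_sizes):
    # Start phase: compute the ragged block size once and cache it per DMA slot.
    q_len_start = cu_q_lens[seq_idx] + bq_idx * BQ_SZ
    q_end = cu_q_lens[seq_idx + 1]
    sz = jnp.minimum(BQ_SZ, q_end - q_len_start)
    # ... the real kernel issues the async copy of `sz` tokens here ...
    return bq_fetch_sizes.at[bq_sem_idx].set(sz)

def wait_fetch_bq(bq_sem_idx, bq_fetch_sizes):
    # Wait phase: read the cached size back instead of recomputing it.
    sz = bq_fetch_sizes[bq_sem_idx]
    # ... the real kernel waits on the copy of `sz` tokens here ...
    return sz

cu_q_lens = jnp.array([0, 5, 18], jnp.int32)  # two sequences: 5 and 13 tokens
sizes = jnp.full((2,), -1, jnp.int32)         # (bq_sem_0_sz, bq_sem_1_sz)
sizes = start_fetch_bq(cu_q_lens, seq_idx=1, bq_idx=1, bq_sem_idx=0, bq_fetch_sizes=sizes)
assert int(wait_fetch_bq(0, sizes)) == 5      # last, partial block of the 13-token sequence

The two-element jnp.full((2,), -1, jnp.int32) scratch added further down in this diff, and the two extra slots appended to bo_ids_ref, are the kernel's real counterparts of bq_fetch_sizes.
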
Expand Down Expand Up @@ -1445,10 +1485,12 @@ def ragged_paged_attention(
distribution,
# (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
jnp.zeros((3, ), jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
jnp.full((4, ), -1, jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bq_sem_0_sz, bq_sem_1_sz)
jnp.full((2, ), -1, jnp.int32),
)

scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
@@ -1487,8 +1529,8 @@ def ragged_paged_attention(
dtype=kv_cache.dtype),
],
input_output_aliases={
7: 0,
9: 1
8: 0,
Review comment (Collaborator):
Do you know why queries is not in the donate_argnames in jax.jit?

Reply (Contributor @rupeng-liu, Dec 3, 2025):
It depends on whether it will be used again after the attention. I took a quick look and didn't find where it is used again, so if the queries are not reused we could donate them. @bythew3i, thoughts?

(An illustrative donate_argnames sketch follows this file's diff.)

10: 1
},
name=scope_name,
))
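
The alias indices change from 7: 0, 9: 1 to 8: 0, 10: 1 because the new (bq_sem_0_sz, bq_sem_1_sz) scratch operand added above shifts the positional index of each aliased input by one.

On the donation question in the thread above, the standalone sketch below shows what donating queries at the jax.jit level could look like if no caller reads it after the attention call. fake_attention is a placeholder rather than the actual ragged_paged_attention entry point, and buffer donation is only honored on backends that support it (e.g. TPU/GPU); on CPU, JAX warns and keeps the buffer.

import jax
import jax.numpy as jnp

def fake_attention(queries, kv_cache):
    # Placeholder compute standing in for the real kernel call.
    return queries * 2.0, kv_cache

# Donating `queries` lets XLA reuse its buffer for an output of matching
# shape/dtype; the caller must not read `queries` again after the call.
attn = jax.jit(fake_attention, donate_argnames=("queries",))

q = jnp.ones((4, 8), jnp.float32)
kv = jnp.zeros((4, 8), jnp.float32)
out, kv = attn(q, kv)
# Reading `q` at this point would raise a donated/deleted-buffer error on TPU or GPU.

Whether this is safe is exactly the open question in the thread: it requires confirming that no caller reuses queries after ragged_paged_attention returns.
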
105 changes: 74 additions & 31 deletions tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
@@ -246,8 +246,9 @@ def _ragged_paged_attention_kernel(
# TODO(jevinjiang): merge these into one so we can save SMEM.
distribution_ref, # [3] (decode_end, prefill_end, mixed_end)
sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
bo_ids_ref, # [6] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
bq_fetch_ids_ref, # [2] (bq_sem_0_sz, bq_sem_1_sz)
# Input
q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
kv_hbm_ref, # [max_num_tokens, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
@@ -619,50 +620,90 @@ def loop_body(i, states):
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
sem = sems.at[1, bq_sem_idx]
vmem_ref = bq_x2_ref.at[bq_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bq_idx={}", bq_idx)
debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bq_fetch_ids_ref[bq_sem_idx] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bq_fetch_ids_ref[bq_sem_idx]

debug_print("[RPA debug] sz (from scratch)={}", sz)

dst = vmem_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
sem = sems.at[2, bo_sem_idx]
vmem_ref = bo_x2_ref.at[bo_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_send_bo-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bo_idx={}", bo_idx)
debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bo_ids_ref[bo_sem_idx + 4] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bo_ids_ref[bo_sem_idx + 4]

debug_print("[RPA debug] sz (from scratch)={}", sz)

dst = o_hbm_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
Expand Down Expand Up @@ -1511,10 +1552,12 @@ def ragged_paged_attention_hd64(
distribution,
# (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
jnp.zeros((3, ), jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
jnp.full((4, ), -1, jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bq_sem_0_sz, bq_sem_1_sz)
jnp.full((2, ), -1, jnp.int32),
)

scope_name = f"RPA-HD_64-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
@@ -1554,8 +1597,8 @@ def ragged_paged_attention_hd64(
dtype=kv_cache.dtype),
],
input_output_aliases={
7: 0,
9: 1
8: 0,
10: 1
},
name=scope_name,
))