tpu_inference/kernels/ragged_paged_attention/v3/kernel.py (104 changes: 73 additions & 31 deletions)
@@ -230,8 +230,9 @@ def _ragged_paged_attention_kernel(
# TODO(jevinjiang): merge these into one so we can save SMEM.
distribution_ref, # [3] (decode_end, prefill_end, mixed_end)
sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
bo_ids_ref, # [6] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
bq_fetch_ids_ref, # [2] (bq_sem_0_sz, bq_sem_1_sz)
# Input
q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
kv_hbm_ref, # [max_num_tokens, num_kv_heads_x2 // kv_packing, kv_packing, head_dim]
@@ -562,50 +563,89 @@ def loop_body(i, states):
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
sem = sems.at[1, bq_sem_idx]
vmem_ref = bq_x2_ref.at[bq_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bq_idx={}", bq_idx)
debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bq_fetch_ids_ref[bq_sem_idx] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bq_fetch_ids_ref[bq_sem_idx]

debug_print("[RPA debug] sz (from scratch)={}", sz)
dst = vmem_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
sem = sems.at[2, bo_sem_idx]
vmem_ref = bo_x2_ref.at[bo_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_send_bo-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bo_idx={}", bo_idx)
debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bo_ids_ref[bo_sem_idx + 4] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bo_ids_ref[bo_sem_idx + 4]

debug_print("[RPA debug] sz (from scratch)={}", sz)

dst = o_hbm_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
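
Note on the hunk above: `_fetch_bq` and `_send_bo` now compute the block size `sz` only on the start path, stash it in SMEM scratch (`bq_fetch_ids_ref[bq_sem_idx]` and `bo_ids_ref[bo_sem_idx + 4]`), and read it back on the wait path instead of re-deriving it from `cu_q_lens_ref`. A minimal sketch of that pattern, using plain-Python stand-ins for the scratch and the async copy (the names and numbers below are illustrative, not the kernel's real refs):

```python
# Sketch of the "compute sz once at start, reuse it at wait" pattern.
# Plain-Python stand-ins only; no Pallas/TPU APIs are used here.

BQ_SZ = 128                       # assumed query block size
cu_q_lens = [0, 300, 512]         # stand-in for cu_q_lens_ref (cumulative query lengths)
bq_fetch_sz_scratch = [-1, -1]    # stand-in for bq_fetch_ids_ref, one slot per buffer

def start_fetch_bq(seq_idx: int, bq_idx: int, bq_sem_idx: int) -> int:
    """Start path: derive sz from cu_q_lens and cache it for the matching wait."""
    q_len_start = cu_q_lens[seq_idx] + bq_idx * BQ_SZ
    q_end = cu_q_lens[seq_idx + 1]
    sz = min(BQ_SZ, q_end - q_len_start)
    bq_fetch_sz_scratch[bq_sem_idx] = sz
    # ... issue the async copy of q[:, q_len_start : q_len_start + sz] here ...
    return sz

def wait_fetch_bq(bq_sem_idx: int) -> int:
    """Wait path: only the cached transfer size is needed to wait on the copy."""
    sz = bq_fetch_sz_scratch[bq_sem_idx]
    # ... wait on the async copy that covers sz rows ...
    return sz

# Double-buffered usage: each buffer's wait reuses the sz cached by its start.
assert start_fetch_bq(seq_idx=0, bq_idx=2, bq_sem_idx=0) == 44   # min(128, 300 - 256)
assert wait_fetch_bq(bq_sem_idx=0) == 44
```

The wait branches in the diff only need `sz` to size the refs handed to `_async_copy`, which is why the `seq_idx`/`bq_idx` bookkeeping and the extra `cu_q_lens_ref` reads can be dropped there.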
@@ -1445,10 +1485,12 @@ def ragged_paged_attention(
distribution,
# (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
jnp.zeros((3, ), jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
jnp.full((4, ), -1, jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bq_sem_0_sz, bq_sem_1_sz)
jnp.full((2, ), -1, jnp.int32),
)

scope_name = f"RPA-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
@@ -1487,8 +1529,8 @@ def ragged_paged_attention(
dtype=kv_cache.dtype),
],
input_output_aliases={
7: 0,
9: 1
8: 0,
10: 1
},
name=scope_name,
))
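
The `input_output_aliases` keys move from `{7: 0, 9: 1}` to `{8: 0, 10: 1}` because the keys are positional indices into `pallas_call`'s operand list, which is consistent with the new `bq_fetch_ids` scratch array being passed ahead of the two aliased buffers and shifting their positions by one. A minimal, standalone illustration of that aliasing semantics (toy kernel with assumed shapes, run in interpreter mode so it does not need a TPU):

```python
# Toy example of pallas_call input_output_aliases: the dict maps an *input
# position* to an *output position*. Adding an operand before the aliased
# input bumps its key, exactly as in the diff above.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def _inc(extra_ref, x_ref, o_ref):
    # x_ref (input position 1) is aliased to output 0 below.
    o_ref[...] = x_ref[...] + extra_ref[0]

extra = jnp.ones((1,), jnp.float32)   # extra operand in front, like the new scratch array
x = jnp.zeros((8, 128), jnp.float32)

y = pl.pallas_call(
    _inc,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    input_output_aliases={1: 0},      # without `extra` in front, this key would be 0
    interpret=True,                   # interpreter mode; on hardware the alias also donates x's buffer
)(extra, x)
assert float(y[0, 0]) == 1.0
```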
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py (105 changes: 74 additions & 31 deletions)
@@ -246,8 +246,9 @@ def _ragged_paged_attention_kernel(
# TODO(jevinjiang): merge these into one so we can save SMEM.
distribution_ref, # [3] (decode_end, prefill_end, mixed_end)
sem_ids_ref, # [3] (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
bo_ids_ref, # [4] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
bo_ids_ref, # [6] (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
bkv_update_ids_ref, # [6] (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
bq_fetch_ids_ref, # [2] (bq_sem_0_sz, bq_sem_1_sz)
# Input
q_hbm_ref, # [actual_num_kv_heads, max_num_tokens, num_q_heads_per_kv_head // q_packing, q_packing, head_dim]
kv_hbm_ref, # [max_num_tokens, num_kv_heads // kv_packing, kv_packing, actual_head_dim_x2]
@@ -619,50 +620,90 @@ def loop_body(i, states):
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
sem = sems.at[1, bq_sem_idx]
vmem_ref = bq_x2_ref.at[bq_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_fetch_bq-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bq_idx={}", bq_idx)
debug_print("[RPA debug] bq_sem_idx={}", bq_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bq_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bq_fetch_ids_ref[bq_sem_idx] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
q_hbm_ref.at[:, pl.ds(q_len_start, sz)],
vmem_ref.at[:, pl.ds(0, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bq_fetch_ids_ref[bq_sem_idx]

debug_print("[RPA debug] sz (from scratch)={}", sz)

dst = vmem_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def _send_bo(seq_idx, bo_idx, bo_sem_idx, *, wait=False):
sem = sems.at[2, bo_sem_idx]
vmem_ref = bo_x2_ref.at[bo_sem_idx]
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

debug_print(
"[RPA debug]"
f" -----------{'wait' if wait else 'start'}_send_bo-----------")
debug_print("[RPA debug] seq_idx={}", seq_idx)
debug_print("[RPA debug] bo_idx={}", bo_idx)
debug_print("[RPA debug] bo_sem_idx={}", bo_sem_idx)
debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)

if not wait:
# Calculate sz and store it in scratch
q_len_start = cu_q_lens_ref[seq_idx] + bo_idx * bq_sz
q_end = cu_q_lens_ref[seq_idx + 1]
sz = jnp.minimum(bq_sz, q_end - q_len_start)

# Store sz in scratch for later use
bo_ids_ref[bo_sem_idx + 4] = sz

debug_print("[RPA debug] q_len_start={}", q_len_start)
debug_print("[RPA debug] q_end={}", q_end)
debug_print("[RPA debug] sz={}", sz)

_async_copy(
vmem_ref.at[:, pl.ds(0, sz)],
o_hbm_ref.at[:, pl.ds(q_len_start, sz)],
sem,
wait,
)
else:
# Retrieve sz from scratch instead of recalculating
sz = bo_ids_ref[bo_sem_idx + 4]

debug_print("[RPA debug] sz (from scratch)={}", sz)

dst = o_hbm_ref.at[:, pl.ds(0, sz)]
_async_copy(
dst,
dst,
sem,
wait,
)

def start_fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx):
return _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx)
@@ -1511,10 +1552,12 @@ def ragged_paged_attention_hd64(
distribution,
# (bq_sem_idx, bkv_sem_idx, bo_sem_idx)
jnp.zeros((3, ), jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx)
jnp.full((4, ), -1, jnp.int32),
# (bo_sem_0_seq_idx, bo_sem_1_seq_idx, bo_sem_0_bo_idx, bo_sem_1_bo_idx, bo_sem_0_sz, bo_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bkv_sem_0_seq_idx, bkv_sem_1_seq_idx, bkv_sem_0_offset, bkv_sem_1_offset, bkv_sem_0_sz, bkv_sem_1_sz)
jnp.full((6, ), -1, jnp.int32),
# (bq_sem_0_sz, bq_sem_1_sz)
jnp.full((2, ), -1, jnp.int32),
)

scope_name = f"RPA-HD_64-bq_{bq_sz}-bkvp_{bkv_p}-p_{page_size}"
@@ -1554,8 +1597,8 @@ def ragged_paged_attention_hd64(
dtype=kv_cache.dtype),
],
input_output_aliases={
7: 0,
9: 1
8: 0,
10: 1
},
name=scope_name,
))