Skip to content

Commit ff4d31f

Browse files
committed
debug: add jax block
Signed-off-by: Juncheng Gu <jcgu@google.com>
1 parent 6f8ae20 commit ff4d31f

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tpu_inference/distributed/offload/tpu_offload_connector.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1375,6 +1375,7 @@ def _bucketed_swap_out_fn(
1375 1375
# Fast path: handle bucket-sized transfers
1376 1376
if num_blocks in BLOCK_SIZE_BUCKETS:
1377 1377
flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu)
1378 +
jax.block_until_ready(flat_kv_caches_cpu)
1378 1379
split_size_list = [self.block_size] * num_blocks
1379 1380
return [
1380 1381
jax.lax.split(flat_layer_cache, split_size_list, axis=0)
@@ -1405,6 +1406,7 @@ def _bucketed_swap_out_fn(
1405 1406
# Swap the bucket to CPU, result is a flat tensor for this bucket. We are doing the chunking inside this function to avoid returning any jnp.concatenate
1406 1407
# of kv cache for the the bucketed blocks
1407 1408
cpu_chunk_flat_per_layer = self.swap_out_fn(tpu_chunk)
1409 +
jax.block_until_ready(cpu_chunk_flat_per_layer)
1408 1410
# Split the flat bucket tensor into block-sized chunks and append
1409 1411
split_size_list = [self.block_size] * decomposed_block_size
1410 1412
for i, layer_cache in enumerate(cpu_chunk_flat_per_layer):

0 commit comments

Comments
 (0)