
Commit 4b696d1

Refactor: Defer dummy attention metadata creation
Moves the creation of attention metadata to after `cudagraph_runtime_mode` has been determined. This ensures that attention metadata is built when replaying a CUDA graph.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent: 2566dca
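For readers skimming the diff below, here is a minimal, self-contained sketch of the ordering change. The helper names (resolve_runtime_mode, build_dummy_attn_metadata) are hypothetical stand-ins, not vLLM's API; only the before/after ordering mirrors the actual change in `_dummy_run`.

# Hypothetical sketch of the reordering; names are illustrative, not vLLM's API.

from enum import Enum


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def resolve_runtime_mode(requested_mode):
    # Stand-in for the dispatcher step that finalizes the cudagraph mode.
    return CUDAGraphMode.FULL if requested_mode is None else requested_mode


def build_dummy_attn_metadata():
    # Stand-in for the per-layer dummy metadata built in the real code.
    return {"model.layers.0.attn": "dummy metadata"}


def dummy_run_old(requested_mode, force_attention=False):
    # Old ordering: metadata is decided from the *requested* mode, which may
    # still be None (or change) before the final runtime mode is known.
    attn_metadata = None
    if force_attention or requested_mode == CUDAGraphMode.FULL:
        attn_metadata = build_dummy_attn_metadata()
    runtime_mode = resolve_runtime_mode(requested_mode)
    return runtime_mode, attn_metadata


def dummy_run_new(requested_mode, force_attention=False):
    # New ordering: resolve the runtime mode first, then build metadata from
    # it, so a FULL-mode dummy run (e.g. a CUDA graph replay) gets metadata.
    runtime_mode = resolve_runtime_mode(requested_mode)
    attn_metadata = None
    if force_attention or runtime_mode == CUDAGraphMode.FULL:
        attn_metadata = build_dummy_attn_metadata()
    return runtime_mode, attn_metadata


if __name__ == "__main__":
    # With a requested mode of None that resolves to FULL, only the new
    # ordering produces attention metadata.
    print(dummy_run_old(None))  # (CUDAGraphMode.FULL, None)
    print(dummy_run_new(None))  # (CUDAGraphMode.FULL, {...})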

File tree: 1 file changed (+74 −74 lines)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 74 additions & 74 deletions
@@ -3306,80 +3306,6 @@ def _dummy_run(
         dp_rank = self.parallel_config.data_parallel_rank
         num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])

-        attn_metadata: PerLayerAttnMetadata | None = None
-
-        # If force_attention is True, we always capture attention. Otherwise,
-        # it only happens for cudagraph_runtime_mode=FULL.
-        if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            attn_metadata = {}
-            if ubatch_slices is not None:
-                attn_metadata = [dict() for _ in range(len(ubatch_slices))]
-
-            if create_mixed_batch:
-                # In the mixed batch mode (used for FI warmup), we use
-                # shorter sequence lengths to run faster.
-                # TODO(luka) better system for describing dummy batches
-                seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
-            else:
-                seq_lens = max_query_len
-            self.seq_lens.np[:num_reqs] = seq_lens
-            self.seq_lens.np[num_reqs:] = 0
-            self.seq_lens.copy_to_gpu()
-
-            cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
-            self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
-            self.query_start_loc.copy_to_gpu()
-
-            for kv_cache_group_id, kv_cache_group_spec in enumerate(
-                self.kv_cache_config.kv_cache_groups
-            ):
-                common_attn_metadata = CommonAttentionMetadata(
-                    query_start_loc=self.query_start_loc.gpu[: num_reqs + 1],
-                    query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1],
-                    seq_lens=self.seq_lens.gpu[:num_reqs],
-                    seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
-                    num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
-                        :num_reqs
-                    ],
-                    num_reqs=num_reqs,
-                    num_actual_tokens=num_tokens,
-                    max_query_len=max_query_len,
-                    max_seq_len=self.max_model_len,
-                    block_table_tensor=self.input_batch.block_table[
-                        kv_cache_group_id
-                    ].get_device_tensor(num_reqs),
-                    slot_mapping=self.input_batch.block_table[
-                        kv_cache_group_id
-                    ].slot_mapping.gpu[:num_tokens],
-                    causal=True,
-                    dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
-                    if self.dcp_world_size > 1
-                    else None,
-                )
-                for attn_group in self.attn_groups[kv_cache_group_id]:
-                    if ubatch_slices is not None:
-                        common_attn_metadata_list = split_attn_metadata(
-                            ubatch_slices, common_attn_metadata
-                        )
-                        for ubid, common_attn_metadata in enumerate(
-                            common_attn_metadata_list
-                        ):
-                            assert common_attn_metadata.max_query_len == 1
-                            attn_metadata_i = attn_group.get_metadata_builder(
-                                ubatch_id=ubid
-                            ).build_for_cudagraph_capture(common_attn_metadata)
-                            for layer_name in attn_group.layer_names:
-                                assert type(attn_metadata) is list
-                                attn_metadata[ubid][layer_name] = attn_metadata_i
-                    else:
-                        assert type(attn_metadata) is dict
-                        metadata_builder = attn_group.get_metadata_builder()
-                        attn_metadata_i = metadata_builder.build_for_cudagraph_capture(
-                            common_attn_metadata
-                        )
-                        for layer_name in attn_group.layer_names:
-                            attn_metadata[layer_name] = attn_metadata_i
-
         with self.maybe_dummy_run_with_lora(
             self.lora_config, num_scheduled_tokens, activate_lora, remove_lora
         ):
@@ -3447,6 +3373,80 @@ def _dummy_run(
                 else:
                     cudagraph_runtime_mode = _cg_mode

+            attn_metadata: PerLayerAttnMetadata | None = None
+
+            # If force_attention is True, we always capture attention. Otherwise,
+            # it only happens for cudagraph_runtime_mode=FULL.
+            if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
+                attn_metadata = {}
+                if ubatch_slices is not None:
+                    attn_metadata = [dict() for _ in range(len(ubatch_slices))]
+
+                if create_mixed_batch:
+                    # In the mixed batch mode (used for FI warmup), we use
+                    # shorter sequence lengths to run faster.
+                    # TODO(luka) better system for describing dummy batches
+                    seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
+                else:
+                    seq_lens = max_query_len
+                self.seq_lens.np[:num_reqs] = seq_lens
+                self.seq_lens.np[num_reqs:] = 0
+                self.seq_lens.copy_to_gpu()
+
+                cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
+                self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
+                self.query_start_loc.copy_to_gpu()
+
+                for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                    self.kv_cache_config.kv_cache_groups
+                ):
+                    common_attn_metadata = CommonAttentionMetadata(
+                        query_start_loc=self.query_start_loc.gpu[: num_reqs + 1],
+                        query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1],
+                        seq_lens=self.seq_lens.gpu[:num_reqs],
+                        seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
+                        num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
+                            :num_reqs
+                        ],
+                        num_reqs=num_reqs,
+                        num_actual_tokens=num_tokens,
+                        max_query_len=max_query_len,
+                        max_seq_len=self.max_model_len,
+                        block_table_tensor=self.input_batch.block_table[
+                            kv_cache_group_id
+                        ].get_device_tensor(num_reqs),
+                        slot_mapping=self.input_batch.block_table[
+                            kv_cache_group_id
+                        ].slot_mapping.gpu[:num_tokens],
+                        causal=True,
+                        dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
+                        if self.dcp_world_size > 1
+                        else None,
+                    )
+                    for attn_group in self.attn_groups[kv_cache_group_id]:
+                        if ubatch_slices is not None:
+                            common_attn_metadata_list = split_attn_metadata(
+                                ubatch_slices, common_attn_metadata
+                            )
+                            for ubid, common_attn_metadata in enumerate(
+                                common_attn_metadata_list
+                            ):
+                                assert common_attn_metadata.max_query_len == 1
+                                attn_metadata_i = attn_group.get_metadata_builder(
+                                    ubatch_id=ubid
+                                ).build_for_cudagraph_capture(common_attn_metadata)
+                                for layer_name in attn_group.layer_names:
+                                    assert type(attn_metadata) is list
+                                    attn_metadata[ubid][layer_name] = attn_metadata_i
+                        else:
+                            assert type(attn_metadata) is dict
+                            metadata_builder = attn_group.get_metadata_builder()
+                            attn_metadata_i = metadata_builder.build_for_cudagraph_capture(
+                                common_attn_metadata
+                            )
+                            for layer_name in attn_group.layer_names:
+                                attn_metadata[layer_name] = attn_metadata_i
+
             if ubatch_slices is not None:
                 # Adjust values to reflect a single ubatch.
                 # TODO(sage,lucas): this is cruft that should be addressed in

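For reference, a simplified, self-contained sketch of the result shape the moved block produces: a dict keyed by layer name, or a list of such dicts when ubatch slices are present. DummyBuilder, DummyMetadata, and build_per_layer_metadata are illustrative stand-ins, not vLLM classes; only the dict / list-of-dicts shape mirrors the diff above.

# Simplified stand-in for the per-layer attention metadata assembly shown in
# the diff above; these classes are illustrative, not vLLM's actual ones.

from dataclasses import dataclass


@dataclass
class DummyMetadata:
    max_query_len: int


class DummyBuilder:
    def build_for_cudagraph_capture(self, common):
        # The real builders return backend-specific metadata built from
        # CommonAttentionMetadata; here we just echo a small record.
        return DummyMetadata(max_query_len=common["max_query_len"])


def build_per_layer_metadata(layer_names, common, ubatch_slices=None):
    builder = DummyBuilder()
    if ubatch_slices is None:
        # Single batch: one metadata object shared by every layer in the group.
        metadata = builder.build_for_cudagraph_capture(common)
        return {name: metadata for name in layer_names}
    # Micro-batched: one dict per ubatch slice, each keyed by layer name.
    return [
        {name: builder.build_for_cudagraph_capture(common) for name in layer_names}
        for _ in ubatch_slices
    ]


if __name__ == "__main__":
    common = {"max_query_len": 1}
    print(build_per_layer_metadata(["model.layers.0.attn"], common))
    print(build_per_layer_metadata(["model.layers.0.attn"], common, ubatch_slices=[0, 1]))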