Skip to content

Commit bf2feee

Browse files
committed
Refactor: Move CUDA graph dispatch logic earlier
Moves the CUDA graph dispatch logic to execute before the attention metadata is calculated within the dummy run. Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent ecf8230 commit bf2feee

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 25 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -3521,6 +3521,31 @@ def _dummy_run(
35213521
dp_rank = self.parallel_config.data_parallel_rank
35223522
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
35233523

3524+
# filter out the valid batch descriptor
3525+
_cg_mode, batch_descriptor = (
3526+
self.cudagraph_dispatcher.dispatch(
3527+
BatchDescriptor(
3528+
num_tokens=num_tokens_after_padding,
3529+
uniform_decode=uniform_decode,
3530+
has_lora=activate_lora and self.lora_config is not None,
3531+
)
3532+
)
3533+
if not is_profile
3534+
else (CUDAGraphMode.NONE, None)
3535+
)
3536+
if cudagraph_runtime_mode is not None:
3537+
# we allow forcing NONE when the dispatcher disagrees to support
3538+
# warm ups for cudagraph capture
3539+
assert (
3540+
cudagraph_runtime_mode == CUDAGraphMode.NONE
3541+
or cudagraph_runtime_mode == _cg_mode
3542+
), (
3543+
f"Cudagraph runtime mode mismatch at dummy_run. "
3544+
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
3545+
)
3546+
else:
3547+
cudagraph_runtime_mode = _cg_mode
3548+
35243549
attn_metadata: PerLayerAttnMetadata | None = None
35253550

35263551
# If force_attention is True, we always capture attention. Otherwise,
@@ -3595,31 +3620,6 @@ def _dummy_run(
35953620
num_tokens_after_padding, None, False
35963621
)
35973622

3598-
# filter out the valid batch descriptor
3599-
_cg_mode, batch_descriptor = (
3600-
self.cudagraph_dispatcher.dispatch(
3601-
BatchDescriptor(
3602-
num_tokens=num_tokens_after_padding,
3603-
uniform_decode=uniform_decode,
3604-
has_lora=activate_lora and self.lora_config is not None,
3605-
)
3606-
)
3607-
if not is_profile
3608-
else (CUDAGraphMode.NONE, None)
3609-
)
3610-
if cudagraph_runtime_mode is not None:
3611-
# we allow forcing NONE when the dispatcher disagrees to support
3612-
# warm ups for cudagraph capture
3613-
assert (
3614-
cudagraph_runtime_mode == CUDAGraphMode.NONE
3615-
or cudagraph_runtime_mode == _cg_mode
3616-
), (
3617-
f"Cudagraph runtime mode mismatch at dummy_run. "
3618-
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
3619-
)
3620-
else:
3621-
cudagraph_runtime_mode = _cg_mode
3622-
36233623
if ubatch_slices is not None:
36243624
# Adjust values to reflect a single ubatch.
36253625
# TODO(sage,lucas): this is cruft that should be addressed in

0 commit comments

Comments (0)