Skip to content

Commit bcf13cb

Browse files
yiz-liu, charlotte12l
authored and committed
Refactor: Move CUDA graph dispatch logic earlier (vllm-project#27382)
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
1 parent a9682af commit bcf13cb

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 25 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -3740,6 +3740,31 @@ def _dummy_run(
37403740
dp_rank = self.parallel_config.data_parallel_rank
37413741
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
37423742

3743+
# filter out the valid batch descriptor
3744+
_cg_mode, batch_descriptor = (
3745+
self.cudagraph_dispatcher.dispatch(
3746+
BatchDescriptor(
3747+
num_tokens=num_tokens_after_padding,
3748+
uniform_decode=uniform_decode,
3749+
has_lora=activate_lora and self.lora_config is not None,
3750+
)
3751+
)
3752+
if not is_profile
3753+
else (CUDAGraphMode.NONE, None)
3754+
)
3755+
if cudagraph_runtime_mode is not None:
3756+
# we allow forcing NONE when the dispatcher disagrees to support
3757+
# warm ups for cudagraph capture
3758+
assert (
3759+
cudagraph_runtime_mode == CUDAGraphMode.NONE
3760+
or cudagraph_runtime_mode == _cg_mode
3761+
), (
3762+
f"Cudagraph runtime mode mismatch at dummy_run. "
3763+
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
3764+
)
3765+
else:
3766+
cudagraph_runtime_mode = _cg_mode
3767+
37433768
attn_metadata: PerLayerAttnMetadata | None = None
37443769

37453770
# If force_attention is True, we always capture attention. Otherwise,
@@ -3814,31 +3839,6 @@ def _dummy_run(
38143839
num_tokens_after_padding, None, False
38153840
)
38163841

3817-
# filter out the valid batch descriptor
3818-
_cg_mode, batch_descriptor = (
3819-
self.cudagraph_dispatcher.dispatch(
3820-
BatchDescriptor(
3821-
num_tokens=num_tokens_after_padding,
3822-
uniform_decode=uniform_decode,
3823-
has_lora=activate_lora and self.lora_config is not None,
3824-
)
3825-
)
3826-
if not is_profile
3827-
else (CUDAGraphMode.NONE, None)
3828-
)
3829-
if cudagraph_runtime_mode is not None:
3830-
# we allow forcing NONE when the dispatcher disagrees to support
3831-
# warm ups for cudagraph capture
3832-
assert (
3833-
cudagraph_runtime_mode == CUDAGraphMode.NONE
3834-
or cudagraph_runtime_mode == _cg_mode
3835-
), (
3836-
f"Cudagraph runtime mode mismatch at dummy_run. "
3837-
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
3838-
)
3839-
else:
3840-
cudagraph_runtime_mode = _cg_mode
3841-
38423842
if ubatch_slices is not None:
38433843
# Adjust values to reflect a single ubatch.
38443844
# TODO(sage,lucas): this is cruft that should be addressed in

0 commit comments

Comments
 (0)