153153from vllm .v1 .worker .ubatch_utils import (
154154 UBatchSlices ,
155155 check_ubatch_thresholds ,
156+ maybe_create_ubatch_slices ,
156157)
157158from vllm .v1 .worker .utils import is_residual_scattered_for_sp
158159
@@ -2743,7 +2744,7 @@ def _determine_batch_execution_and_padding(
27432744 ) -> tuple [
27442745 CUDAGraphMode ,
27452746 BatchDescriptor ,
2746- UBatchSlices | None ,
2747+ bool ,
27472748 torch .Tensor | None ,
27482749 CUDAGraphStat | None ,
27492750 ]:
@@ -2779,7 +2780,7 @@ def _determine_batch_execution_and_padding(
27792780
27802781 # Extra coordination when running data-parallel since we need to coordinate
27812782 # across ranks
2782- ubatch_slices , num_tokens_across_dp = None , None
2783+ should_ubatch , num_tokens_across_dp = False , None
27832784 if self .vllm_config .parallel_config .data_parallel_size > 1 :
27842785 # Disable DP padding when running eager to avoid excessive padding when
27852786 # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
@@ -2789,8 +2790,8 @@ def _determine_batch_execution_and_padding(
27892790 self .compilation_config .cudagraph_mode != CUDAGraphMode .NONE
27902791 )
27912792
2792- ubatch_slices , num_tokens_across_dp = coordinate_batch_across_dp (
2793- num_tokens_unpadded = num_tokens_padded ,
2793+ should_ubatch , num_tokens_across_dp = coordinate_batch_across_dp (
2794+ num_tokens_unpadded = num_tokens ,
27942795 parallel_config = self .parallel_config ,
27952796 allow_microbatching = allow_microbatching ,
27962797 allow_dp_padding = allow_dp_padding ,
@@ -2822,7 +2823,7 @@ def _determine_batch_execution_and_padding(
28222823 return (
28232824 cudagraph_mode ,
28242825 batch_descriptor ,
2825- ubatch_slices ,
2826+ should_ubatch ,
28262827 num_tokens_across_dp ,
28272828 cudagraph_stats ,
28282829 )
@@ -2921,7 +2922,7 @@ def execute_model(
29212922 (
29222923 cudagraph_mode ,
29232924 batch_desc ,
2924- ubatch_slices ,
2925+ should_ubatch ,
29252926 num_tokens_across_dp ,
29262927 cudagraph_stats ,
29272928 ) = self ._determine_batch_execution_and_padding (
@@ -2934,29 +2935,37 @@ def execute_model(
29342935
29352936 logger .debug (
29362937 "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
2937- "ubatch_slices : %s, num_tokens_across_dp: %s" ,
2938+ "should_ubatch : %s, num_tokens_across_dp: %s" ,
29382939 cudagraph_mode ,
29392940 batch_desc ,
2940- ubatch_slices ,
2941+ should_ubatch ,
29412942 num_tokens_across_dp ,
29422943 )
29432944
29442945 num_tokens_padded = batch_desc .num_tokens
29452946 num_reqs_padded = (
29462947 batch_desc .num_reqs if batch_desc .num_reqs is not None else num_reqs
29472948 )
2949+ ubatch_slices , ubatch_slices_padded = maybe_create_ubatch_slices (
2950+ should_ubatch ,
2951+ num_scheduled_tokens_np ,
2952+ num_tokens_padded ,
2953+ num_reqs_padded ,
2954+ )
29482955
2949- use_spec_decode = len (scheduler_output .scheduled_spec_decode_tokens ) > 0
29502956 pad_attn = cudagraph_mode == CUDAGraphMode .FULL
29512957
2958+ use_spec_decode = len (scheduler_output .scheduled_spec_decode_tokens ) > 0
2959+ ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
2960+
29522961 (attn_metadata , spec_decode_common_attn_metadata ) = (
29532962 self ._build_attention_metadata (
29542963 num_tokens = num_tokens_unpadded ,
29552964 num_tokens_padded = num_tokens_padded if pad_attn else None ,
29562965 num_reqs = num_reqs ,
29572966 num_reqs_padded = num_reqs_padded if pad_attn else None ,
29582967 max_query_len = max_num_scheduled_tokens ,
2959- ubatch_slices = ubatch_slices ,
2968+ ubatch_slices = ubatch_slices_attn ,
29602969 logits_indices = logits_indices ,
29612970 use_spec_decode = use_spec_decode ,
29622971 num_scheduled_tokens = scheduler_output .num_scheduled_tokens ,
@@ -2993,7 +3002,7 @@ def execute_model(
29933002 num_tokens_across_dp = num_tokens_across_dp ,
29943003 cudagraph_runtime_mode = cudagraph_mode ,
29953004 batch_descriptor = batch_desc ,
2996- ubatch_slices = ubatch_slices ,
3005+ ubatch_slices = ubatch_slices_padded ,
29973006 ),
29983007 record_function_or_nullcontext ("gpu_model_runner: forward" ),
29993008 self .maybe_get_kv_connector_output (scheduler_output ) as kv_connector_output ,
@@ -3945,7 +3954,7 @@ def _dummy_run(
39453954
39463955 num_sampled_tokens = np .ones (num_reqs , dtype = np .int32 )
39473956
3948- _cudagraph_mode , batch_desc , ubatch_slices , num_tokens_across_dp , _ = (
3957+ _cudagraph_mode , batch_desc , should_ubatch , num_tokens_across_dp , _ = (
39493958 self ._determine_batch_execution_and_padding (
39503959 num_tokens = num_tokens_unpadded ,
39513960 num_reqs = num_reqs ,
@@ -3979,6 +3988,9 @@ def _dummy_run(
39793988 num_reqs_padded = (
39803989 batch_desc .num_reqs if batch_desc .num_reqs is not None else num_reqs
39813990 )
3991+ ubatch_slices , ubatch_slices_padded = maybe_create_ubatch_slices (
3992+ should_ubatch , num_scheduled_tokens , num_tokens_padded , num_reqs_padded
3993+ )
39823994
39833995 attn_metadata : PerLayerAttnMetadata | None = None
39843996
@@ -4000,11 +4012,12 @@ def _dummy_run(
40004012 self .query_start_loc .np [1 : num_reqs + 1 ] = cum_num_tokens
40014013 self .query_start_loc .copy_to_gpu ()
40024014
4015+ pad_attn = cudagraph_runtime_mode == CUDAGraphMode .FULL
40034016 attn_metadata , _ = self ._build_attention_metadata (
40044017 num_tokens = num_tokens_unpadded ,
40054018 num_reqs = num_reqs_padded ,
40064019 max_query_len = max_query_len ,
4007- ubatch_slices = ubatch_slices ,
4020+ ubatch_slices = ubatch_slices_padded if pad_attn else ubatch_slices ,
40084021 for_cudagraph_capture = is_graph_capturing ,
40094022 )
40104023
@@ -4056,11 +4069,11 @@ def _dummy_run(
40564069 num_tokens_padded , None , False
40574070 )
40584071
4059- if ubatch_slices is not None :
4072+ if ubatch_slices_padded is not None :
40604073 # Adjust values to reflect a single ubatch.
40614074 # TODO(sage,lucas): this is cruft that should be addressed in
40624075 # the padding refactor.
4063- num_tokens_padded = ubatch_slices [0 ].num_tokens
4076+ num_tokens_padded = ubatch_slices_padded [0 ].num_tokens
40644077 if num_tokens_across_dp is not None :
40654078 num_tokens_across_dp [:] = num_tokens_padded
40664079
@@ -4073,7 +4086,7 @@ def _dummy_run(
40734086 num_tokens_across_dp = num_tokens_across_dp ,
40744087 cudagraph_runtime_mode = cudagraph_runtime_mode ,
40754088 batch_descriptor = batch_desc ,
4076- ubatch_slices = ubatch_slices ,
4089+ ubatch_slices = ubatch_slices_padded ,
40774090 ),
40784091 ):
40794092 outputs = self .model (
0 commit comments