@@ -70,17 +70,12 @@ def _create_padded_batch_descriptor(
7070 uniform_decode_query_len = self .uniform_decode_query_len
7171 num_tokens_padded = self .vllm_config .pad_for_cudagraph (num_tokens )
7272
73- if uniform_decode :
73+ if uniform_decode and self . cudagraph_mode . has_mode ( CUDAGraphMode . FULL ) :
7474 num_reqs = num_tokens_padded // uniform_decode_query_len
7575 assert num_tokens_padded % uniform_decode_query_len == 0
7676 assert num_reqs <= max_num_seqs
77- return BatchDescriptor (
78- num_tokens = num_tokens_padded ,
79- num_reqs = num_reqs ,
80- uniform = uniform_decode ,
81- has_lora = has_lora ,
82- )
83- num_reqs = min (num_tokens_padded , max_num_seqs )
77+ else :
78+ num_reqs = min (num_tokens_padded , max_num_seqs )
8479
8580 return BatchDescriptor (
8681 num_tokens = num_tokens_padded ,
@@ -168,24 +163,24 @@ def dispatch(
168163 ):
169164 return CUDAGraphMode .NONE , BatchDescriptor (num_tokens )
170165
171- batch_descriptor = self ._create_padded_batch_descriptor (
166+ batch_desc = self ._create_padded_batch_descriptor (
172167 num_tokens , uniform_decode , has_lora
173168 )
174- relaxed_batch_descriptor = batch_descriptor .relax_for_mixed_batch_cudagraphs ()
169+ relaxed_batch_desc = batch_desc .relax_for_mixed_batch_cudagraphs ()
175170
176171 if not use_cascade_attn :
177172 # check if key exists for full cudagraph
178- if batch_descriptor in self .cudagraph_keys [CUDAGraphMode .FULL ]:
179- return CUDAGraphMode .FULL , batch_descriptor
173+ if batch_desc in self .cudagraph_keys [CUDAGraphMode .FULL ]:
174+ return CUDAGraphMode .FULL , batch_desc
180175
181176 # otherwise, check if the relaxed key exists
182- if relaxed_batch_descriptor in self .cudagraph_keys [CUDAGraphMode .FULL ]:
183- return CUDAGraphMode .FULL , relaxed_batch_descriptor
177+ if relaxed_batch_desc in self .cudagraph_keys [CUDAGraphMode .FULL ]:
178+ return CUDAGraphMode .FULL , relaxed_batch_desc
184179
185180 # also check if the relaxed key exists for more "general"
186181 # piecewise cudagraph
187- if relaxed_batch_descriptor in self .cudagraph_keys [CUDAGraphMode .PIECEWISE ]:
188- return CUDAGraphMode .PIECEWISE , relaxed_batch_descriptor
182+ if relaxed_batch_desc in self .cudagraph_keys [CUDAGraphMode .PIECEWISE ]:
183+ return CUDAGraphMode .PIECEWISE , relaxed_batch_desc
189184
190185 # finally, just return no cudagraphs and a trivial batch descriptor
191186 return CUDAGraphMode .NONE , BatchDescriptor (num_tokens )
0 commit comments