Skip to content

Commit 354676c

Browse files
more test fixes
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent fd68758 commit 354676c

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

vllm/v1/cudagraph_dispatcher.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,12 @@ def _create_padded_batch_descriptor(
7070
uniform_decode_query_len = self.uniform_decode_query_len
7171
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
7272

73-
if uniform_decode:
73+
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
7474
num_reqs = num_tokens_padded // uniform_decode_query_len
7575
assert num_tokens_padded % uniform_decode_query_len == 0
7676
assert num_reqs <= max_num_seqs
77-
return BatchDescriptor(
78-
num_tokens=num_tokens_padded,
79-
num_reqs=num_reqs,
80-
uniform=uniform_decode,
81-
has_lora=has_lora,
82-
)
83-
num_reqs = min(num_tokens_padded, max_num_seqs)
77+
else:
78+
num_reqs = min(num_tokens_padded, max_num_seqs)
8479

8580
return BatchDescriptor(
8681
num_tokens=num_tokens_padded,
@@ -168,24 +163,24 @@ def dispatch(
168163
):
169164
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
170165

171-
batch_descriptor = self._create_padded_batch_descriptor(
166+
batch_desc = self._create_padded_batch_descriptor(
172167
num_tokens, uniform_decode, has_lora
173168
)
174-
relaxed_batch_descriptor = batch_descriptor.relax_for_mixed_batch_cudagraphs()
169+
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
175170

176171
if not use_cascade_attn:
177172
# check if key exists for full cudagraph
178-
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
179-
return CUDAGraphMode.FULL, batch_descriptor
173+
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
174+
return CUDAGraphMode.FULL, batch_desc
180175

181176
# otherwise, check if the relaxed key exists
182-
if relaxed_batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
183-
return CUDAGraphMode.FULL, relaxed_batch_descriptor
177+
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
178+
return CUDAGraphMode.FULL, relaxed_batch_desc
184179

185180
# also check if the relaxed key exists for more "general"
186181
# piecewise cudagraph
187-
if relaxed_batch_descriptor in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
188-
return CUDAGraphMode.PIECEWISE, relaxed_batch_descriptor
182+
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
183+
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
189184

190185
# finally, just return no cudagraphs and a trivial batch descriptor
191186
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

0 commit comments

Comments
 (0)