Skip to content

Commit f849175

Browse files
test fix
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 354676c commit f849175

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

vllm/v1/cudagraph_dispatcher.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ class CudagraphDispatcher:
3131
def __init__(self, vllm_config: VllmConfig):
3232
self.vllm_config = vllm_config
3333
self.compilation_config = vllm_config.compilation_config
34-
self.cudagraph_mode = self.compilation_config.cudagraph_mode
3534
self.uniform_decode_query_len = (
3635
1
3736
if not self.vllm_config.speculative_config
@@ -44,12 +43,8 @@ def __init__(self, vllm_config: VllmConfig):
4443
CUDAGraphMode.FULL: set(),
4544
}
4645

47-
not_use_piecewise_compilation = (
48-
not self.cudagraph_mode.requires_piecewise_compilation()
49-
)
50-
5146
assert (
52-
not_use_piecewise_compilation
47+
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
5348
or self.compilation_config.is_attention_compiled_piecewise()
5449
), (
5550
"Compilation mode should be CompilationMode.VLLM_COMPILE when "
@@ -75,6 +70,7 @@ def _create_padded_batch_descriptor(
7570
assert num_tokens_padded % uniform_decode_query_len == 0
7671
assert num_reqs <= max_num_seqs
7772
else:
73+
uniform_decode = False
7874
num_reqs = min(num_tokens_padded, max_num_seqs)
7975

8076
return BatchDescriptor(
@@ -95,7 +91,9 @@ def add_cudagraph_key(
9591
def initialize_cudagraph_keys(
9692
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
9793
):
98-
# This should be called only after attention backend is initialized.
94+
# This should be called only after the attention backend is initialized, so that
95+
# we get the correct cudagraph mode after backend support is resolved.
96+
self.cudagraph_mode = cudagraph_mode
9997

10098
# LoRA activation cases to specialize the cuda graphs on
10199
if self.vllm_config.lora_config:

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4611,8 +4611,7 @@ def _check_and_update_cudagraph_mode(
46114611

46124612
# Trigger cudagraph dispatching keys initialization after
46134613
# resolved cudagraph mode.
4614-
cudagraph_mode = self.compilation_config.cudagraph_mode
4615-
assert cudagraph_mode is not None
4614+
self.compilation_config.cudagraph_mode = cudagraph_mode
46164615
self.cudagraph_dispatcher.initialize_cudagraph_keys(
46174616
cudagraph_mode, self.uniform_decode_query_len
46184617
)

0 commit comments

Comments
 (0)