@@ -31,7 +31,6 @@ class CudagraphDispatcher:
     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
-        self.cudagraph_mode = self.compilation_config.cudagraph_mode
         self.uniform_decode_query_len = (
             1
             if not self.vllm_config.speculative_config
@@ -44,12 +43,8 @@ def __init__(self, vllm_config: VllmConfig):
             CUDAGraphMode.FULL: set(),
         }

-        not_use_piecewise_compilation = (
-            not self.cudagraph_mode.requires_piecewise_compilation()
-        )
-
         assert (
-            not_use_piecewise_compilation
+            not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
             or self.compilation_config.is_attention_compiled_piecewise()
         ), (
             "Compilation mode should be CompilationMode.VLLM_COMPILE when "
@@ -75,6 +70,7 @@ def _create_padded_batch_descriptor(
             assert num_tokens_padded % uniform_decode_query_len == 0
             assert num_reqs <= max_num_seqs
         else:
+            uniform_decode = False
             num_reqs = min(num_tokens_padded, max_num_seqs)

         return BatchDescriptor(
@@ -95,7 +91,9 @@ def add_cudagraph_key(
     def initialize_cudagraph_keys(
         self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
     ):
-        # This should be called only after attention backend is initialized.
+        # This should be called only after the attention backend is initialized,
+        # so we get the correct cudagraph mode after backend support is resolved.
+        self.cudagraph_mode = cudagraph_mode

         # LoRA activation cases to specialize the cuda graphs on
         if self.vllm_config.lora_config:
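
Taken together, these hunks defer setting `self.cudagraph_mode` from `__init__` to `initialize_cudagraph_keys`, i.e. until after the attention backend has resolved which modes it actually supports. A minimal sketch of the implied call order, using the diff's `CudagraphDispatcher`; the import paths and the `resolve_mode_with_backend` helper are assumptions, not part of this diff:

```python
from vllm.config import VllmConfig
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


def setup_dispatcher(vllm_config: VllmConfig) -> CudagraphDispatcher:
    dispatcher = CudagraphDispatcher(vllm_config)
    # No dispatcher.cudagraph_mode yet: the configured mode may still be
    # downgraded once the attention backend reports what it supports.
    resolved_mode = resolve_mode_with_backend(vllm_config)  # hypothetical helper
    dispatcher.initialize_cudagraph_keys(
        resolved_mode, dispatcher.uniform_decode_query_len
    )
    # Only now does dispatcher.cudagraph_mode hold the backend-resolved mode.
    return dispatcher
```
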