@@ -3521,6 +3521,31 @@ def _dummy_run(
35213521 dp_rank = self .parallel_config .data_parallel_rank
35223522 num_tokens_after_padding = int (num_tokens_across_dp [dp_rank ])
35233523
3524+ # filter out the valid batch descriptor
3525+ _cg_mode , batch_descriptor = (
3526+ self .cudagraph_dispatcher .dispatch (
3527+ BatchDescriptor (
3528+ num_tokens = num_tokens_after_padding ,
3529+ uniform_decode = uniform_decode ,
3530+ has_lora = activate_lora and self .lora_config is not None ,
3531+ )
3532+ )
3533+ if not is_profile
3534+ else (CUDAGraphMode .NONE , None )
3535+ )
3536+ if cudagraph_runtime_mode is not None :
3537+ # we allow forcing NONE when the dispatcher disagrees to support
3538+ # warm ups for cudagraph capture
3539+ assert (
3540+ cudagraph_runtime_mode == CUDAGraphMode .NONE
3541+ or cudagraph_runtime_mode == _cg_mode
3542+ ), (
3543+ f"Cudagraph runtime mode mismatch at dummy_run. "
3544+ f"Expected { _cg_mode } , but got { cudagraph_runtime_mode } ."
3545+ )
3546+ else :
3547+ cudagraph_runtime_mode = _cg_mode
3548+
35243549 attn_metadata : PerLayerAttnMetadata | None = None
35253550
35263551 # If force_attention is True, we always capture attention. Otherwise,
@@ -3595,31 +3620,6 @@ def _dummy_run(
35953620 num_tokens_after_padding , None , False
35963621 )
35973622
3598- # filter out the valid batch descriptor
3599- _cg_mode , batch_descriptor = (
3600- self .cudagraph_dispatcher .dispatch (
3601- BatchDescriptor (
3602- num_tokens = num_tokens_after_padding ,
3603- uniform_decode = uniform_decode ,
3604- has_lora = activate_lora and self .lora_config is not None ,
3605- )
3606- )
3607- if not is_profile
3608- else (CUDAGraphMode .NONE , None )
3609- )
3610- if cudagraph_runtime_mode is not None :
3611- # we allow forcing NONE when the dispatcher disagrees to support
3612- # warm ups for cudagraph capture
3613- assert (
3614- cudagraph_runtime_mode == CUDAGraphMode .NONE
3615- or cudagraph_runtime_mode == _cg_mode
3616- ), (
3617- f"Cudagraph runtime mode mismatch at dummy_run. "
3618- f"Expected { _cg_mode } , but got { cudagraph_runtime_mode } ."
3619- )
3620- else :
3621- cudagraph_runtime_mode = _cg_mode
3622-
36233623 if ubatch_slices is not None :
36243624 # Adjust values to reflect a single ubatch.
36253625 # TODO(sage,lucas): this is cruft that should be addressed in
0 commit comments