@@ -3740,6 +3740,31 @@ def _dummy_run(
37403740 dp_rank = self .parallel_config .data_parallel_rank
37413741 num_tokens_after_padding = int (num_tokens_across_dp [dp_rank ])
37423742
3743+ # filter out the valid batch descriptor
3744+ _cg_mode , batch_descriptor = (
3745+ self .cudagraph_dispatcher .dispatch (
3746+ BatchDescriptor (
3747+ num_tokens = num_tokens_after_padding ,
3748+ uniform_decode = uniform_decode ,
3749+ has_lora = activate_lora and self .lora_config is not None ,
3750+ )
3751+ )
3752+ if not is_profile
3753+ else (CUDAGraphMode .NONE , None )
3754+ )
3755+ if cudagraph_runtime_mode is not None :
3756+ # we allow forcing NONE when the dispatcher disagrees to support
3757+ # warm ups for cudagraph capture
3758+ assert (
3759+ cudagraph_runtime_mode == CUDAGraphMode .NONE
3760+ or cudagraph_runtime_mode == _cg_mode
3761+ ), (
3762+ f"Cudagraph runtime mode mismatch at dummy_run. "
3763+ f"Expected { _cg_mode } , but got { cudagraph_runtime_mode } ."
3764+ )
3765+ else :
3766+ cudagraph_runtime_mode = _cg_mode
3767+
37433768 attn_metadata : PerLayerAttnMetadata | None = None
37443769
37453770 # If force_attention is True, we always capture attention. Otherwise,
@@ -3814,31 +3839,6 @@ def _dummy_run(
38143839 num_tokens_after_padding , None , False
38153840 )
38163841
3817- # filter out the valid batch descriptor
3818- _cg_mode , batch_descriptor = (
3819- self .cudagraph_dispatcher .dispatch (
3820- BatchDescriptor (
3821- num_tokens = num_tokens_after_padding ,
3822- uniform_decode = uniform_decode ,
3823- has_lora = activate_lora and self .lora_config is not None ,
3824- )
3825- )
3826- if not is_profile
3827- else (CUDAGraphMode .NONE , None )
3828- )
3829- if cudagraph_runtime_mode is not None :
3830- # we allow forcing NONE when the dispatcher disagrees to support
3831- # warm ups for cudagraph capture
3832- assert (
3833- cudagraph_runtime_mode == CUDAGraphMode .NONE
3834- or cudagraph_runtime_mode == _cg_mode
3835- ), (
3836- f"Cudagraph runtime mode mismatch at dummy_run. "
3837- f"Expected { _cg_mode } , but got { cudagraph_runtime_mode } ."
3838- )
3839- else :
3840- cudagraph_runtime_mode = _cg_mode
3841-
38423842 if ubatch_slices is not None :
38433843 # Adjust values to reflect a single ubatch.
38443844 # TODO(sage,lucas): this is cruft that should be addressed in
0 commit comments