@@ -3306,80 +3306,6 @@ def _dummy_run(
33063306 dp_rank = self .parallel_config .data_parallel_rank
33073307 num_tokens_after_padding = int (num_tokens_across_dp [dp_rank ])
33083308
3309- attn_metadata : PerLayerAttnMetadata | None = None
3310-
3311- # If force_attention is True, we always capture attention. Otherwise,
3312- # it only happens for cudagraph_runtime_mode=FULL.
3313- if force_attention or cudagraph_runtime_mode == CUDAGraphMode .FULL :
3314- attn_metadata = {}
3315- if ubatch_slices is not None :
3316- attn_metadata = [dict () for _ in range (len (ubatch_slices ))]
3317-
3318- if create_mixed_batch :
3319- # In the mixed batch mode (used for FI warmup), we use
3320- # shorter sequence lengths to run faster.
3321- # TODO(luka) better system for describing dummy batches
3322- seq_lens = [1 ] * num_decode_tokens + [num_prefill_tokens + 1 ]
3323- else :
3324- seq_lens = max_query_len
3325- self .seq_lens .np [:num_reqs ] = seq_lens
3326- self .seq_lens .np [num_reqs :] = 0
3327- self .seq_lens .copy_to_gpu ()
3328-
3329- cum_num_tokens , _ = self ._get_cumsum_and_arange (num_scheduled_tokens )
3330- self .query_start_loc .np [1 : num_reqs + 1 ] = cum_num_tokens
3331- self .query_start_loc .copy_to_gpu ()
3332-
3333- for kv_cache_group_id , kv_cache_group_spec in enumerate (
3334- self .kv_cache_config .kv_cache_groups
3335- ):
3336- common_attn_metadata = CommonAttentionMetadata (
3337- query_start_loc = self .query_start_loc .gpu [: num_reqs + 1 ],
3338- query_start_loc_cpu = self .query_start_loc .cpu [: num_reqs + 1 ],
3339- seq_lens = self .seq_lens .gpu [:num_reqs ],
3340- seq_lens_cpu = self .seq_lens .cpu [:num_reqs ],
3341- num_computed_tokens_cpu = self .input_batch .num_computed_tokens_cpu_tensor [
3342- :num_reqs
3343- ],
3344- num_reqs = num_reqs ,
3345- num_actual_tokens = num_tokens ,
3346- max_query_len = max_query_len ,
3347- max_seq_len = self .max_model_len ,
3348- block_table_tensor = self .input_batch .block_table [
3349- kv_cache_group_id
3350- ].get_device_tensor (num_reqs ),
3351- slot_mapping = self .input_batch .block_table [
3352- kv_cache_group_id
3353- ].slot_mapping .gpu [:num_tokens ],
3354- causal = True ,
3355- dcp_local_seq_lens = self .dcp_local_seq_lens .gpu [:num_reqs ]
3356- if self .dcp_world_size > 1
3357- else None ,
3358- )
3359- for attn_group in self .attn_groups [kv_cache_group_id ]:
3360- if ubatch_slices is not None :
3361- common_attn_metadata_list = split_attn_metadata (
3362- ubatch_slices , common_attn_metadata
3363- )
3364- for ubid , common_attn_metadata in enumerate (
3365- common_attn_metadata_list
3366- ):
3367- assert common_attn_metadata .max_query_len == 1
3368- attn_metadata_i = attn_group .get_metadata_builder (
3369- ubatch_id = ubid
3370- ).build_for_cudagraph_capture (common_attn_metadata )
3371- for layer_name in attn_group .layer_names :
3372- assert type (attn_metadata ) is list
3373- attn_metadata [ubid ][layer_name ] = attn_metadata_i
3374- else :
3375- assert type (attn_metadata ) is dict
3376- metadata_builder = attn_group .get_metadata_builder ()
3377- attn_metadata_i = metadata_builder .build_for_cudagraph_capture (
3378- common_attn_metadata
3379- )
3380- for layer_name in attn_group .layer_names :
3381- attn_metadata [layer_name ] = attn_metadata_i
3382-
33833309 with self .maybe_dummy_run_with_lora (
33843310 self .lora_config , num_scheduled_tokens , activate_lora , remove_lora
33853311 ):
@@ -3447,6 +3373,80 @@ def _dummy_run(
34473373 else :
34483374 cudagraph_runtime_mode = _cg_mode
34493375
3376+ attn_metadata : PerLayerAttnMetadata | None = None
3377+
3378+ # If force_attention is True, we always capture attention. Otherwise,
3379+ # it only happens for cudagraph_runtime_mode=FULL.
3380+ if force_attention or cudagraph_runtime_mode == CUDAGraphMode .FULL :
3381+ attn_metadata = {}
3382+ if ubatch_slices is not None :
3383+ attn_metadata = [dict () for _ in range (len (ubatch_slices ))]
3384+
3385+ if create_mixed_batch :
3386+ # In the mixed batch mode (used for FI warmup), we use
3387+ # shorter sequence lengths to run faster.
3388+ # TODO(luka) better system for describing dummy batches
3389+ seq_lens = [1 ] * num_decode_tokens + [num_prefill_tokens + 1 ]
3390+ else :
3391+ seq_lens = max_query_len
3392+ self .seq_lens .np [:num_reqs ] = seq_lens
3393+ self .seq_lens .np [num_reqs :] = 0
3394+ self .seq_lens .copy_to_gpu ()
3395+
3396+ cum_num_tokens , _ = self ._get_cumsum_and_arange (num_scheduled_tokens )
3397+ self .query_start_loc .np [1 : num_reqs + 1 ] = cum_num_tokens
3398+ self .query_start_loc .copy_to_gpu ()
3399+
3400+ for kv_cache_group_id , kv_cache_group_spec in enumerate (
3401+ self .kv_cache_config .kv_cache_groups
3402+ ):
3403+ common_attn_metadata = CommonAttentionMetadata (
3404+ query_start_loc = self .query_start_loc .gpu [: num_reqs + 1 ],
3405+ query_start_loc_cpu = self .query_start_loc .cpu [: num_reqs + 1 ],
3406+ seq_lens = self .seq_lens .gpu [:num_reqs ],
3407+ seq_lens_cpu = self .seq_lens .cpu [:num_reqs ],
3408+ num_computed_tokens_cpu = self .input_batch .num_computed_tokens_cpu_tensor [
3409+ :num_reqs
3410+ ],
3411+ num_reqs = num_reqs ,
3412+ num_actual_tokens = num_tokens ,
3413+ max_query_len = max_query_len ,
3414+ max_seq_len = self .max_model_len ,
3415+ block_table_tensor = self .input_batch .block_table [
3416+ kv_cache_group_id
3417+ ].get_device_tensor (num_reqs ),
3418+ slot_mapping = self .input_batch .block_table [
3419+ kv_cache_group_id
3420+ ].slot_mapping .gpu [:num_tokens ],
3421+ causal = True ,
3422+ dcp_local_seq_lens = self .dcp_local_seq_lens .gpu [:num_reqs ]
3423+ if self .dcp_world_size > 1
3424+ else None ,
3425+ )
3426+ for attn_group in self .attn_groups [kv_cache_group_id ]:
3427+ if ubatch_slices is not None :
3428+ common_attn_metadata_list = split_attn_metadata (
3429+ ubatch_slices , common_attn_metadata
3430+ )
3431+ for ubid , common_attn_metadata in enumerate (
3432+ common_attn_metadata_list
3433+ ):
3434+ assert common_attn_metadata .max_query_len == 1
3435+ attn_metadata_i = attn_group .get_metadata_builder (
3436+ ubatch_id = ubid
3437+ ).build_for_cudagraph_capture (common_attn_metadata )
3438+ for layer_name in attn_group .layer_names :
3439+ assert type (attn_metadata ) is list
3440+ attn_metadata [ubid ][layer_name ] = attn_metadata_i
3441+ else :
3442+ assert type (attn_metadata ) is dict
3443+ metadata_builder = attn_group .get_metadata_builder ()
3444+ attn_metadata_i = metadata_builder .build_for_cudagraph_capture (
3445+ common_attn_metadata
3446+ )
3447+ for layer_name in attn_group .layer_names :
3448+ attn_metadata [layer_name ] = attn_metadata_i
3449+
34503450 if ubatch_slices is not None :
34513451 # Adjust values to reflect a single ubatch.
34523452 # TODO(sage,lucas): this is cruft that should be addressed in
0 commit comments