
Conversation

@yiz-liu
Contributor

@yiz-liu yiz-liu commented Oct 23, 2025

Purpose

After carefully reading the code, I found a potential edge case: when execute_dummy_batch runs, dummy attention metadata isn't created even if cudagraph_runtime_mode is later resolved to CUDAGraphMode.FULL. That's odd, because attention normally requires metadata; without it, the run may raise an error or produce incorrect output.

The only explanation I can think of is that we skip metadata creation for dummy batches to save a little work, since we don't care about their output. Can anyone elaborate on this? Thanks.

I also propose a potential fix: moving the CUDA graph dispatch logic earlier. This ensures metadata is built when replaying a CUDA graph, and the performance impact should be negligible.
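
To make the concern concrete, here is a minimal, self-contained sketch of the ordering I'm describing; the names (dummy_run_old, requested_mode, the inline dispatch step) are illustrative stand-ins, not the actual _dummy_run code:

from enum import Enum, auto

class CUDAGraphMode(Enum):
    NONE = auto()
    PIECEWISE = auto()
    FULL = auto()

def dummy_run_old(requested_mode: CUDAGraphMode, force_attention: bool = False):
    # 1) The metadata decision is made against the *requested* mode, which is
    #    typically NONE when execute_dummy_batch triggers the dummy run.
    attn_metadata = None
    if force_attention or requested_mode == CUDAGraphMode.FULL:
        attn_metadata = "dummy attention metadata"

    # 2) Only afterwards does the dispatcher resolve the actual runtime mode,
    #    which may upgrade the run to FULL (e.g. a captured graph is replayed)...
    cudagraph_runtime_mode = CUDAGraphMode.FULL

    # ...so the FULL replay proceeds with attn_metadata still None.
    return cudagraph_runtime_mode, attn_metadata

print(dummy_run_old(CUDAGraphMode.NONE))  # (<CUDAGraphMode.FULL: 3>, None)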

Test Plan

None.

Test Result

None.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the v1 label Oct 23, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly identifies and fixes a potential bug where dummy attention metadata might not be created when cudagraph_runtime_mode is set to FULL. By deferring the metadata creation until after cudagraph_runtime_mode is determined, the change ensures correctness. The implementation is a straightforward move of a code block, and it looks correct. I've added one suggestion to improve maintainability by refactoring a large block of duplicated code.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_after_padding <= self.max_num_tokens
model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:

P1: Populate seq_lens before pooling model kwargs

_dummy_run now invokes _init_model_kwargs before seq_lens is filled for the current dummy batch; the lengths are only written later, when force_attention is true or the dispatcher returns CUDAGraphMode.FULL. _init_model_kwargs reads self.seq_lens to build token_type_ids for pooling models, so a dummy run that captures attention metadata on a pooling model will use stale lengths from a previous call, yielding token_type_ids whose size no longer matches the current num_tokens_after_padding and causing incorrect inputs during CUDA graph warmup. Ensure the seq_lens tensor is populated before calling _init_model_kwargs, or defer that call until after the lengths are updated.
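
A minimal sketch of the suggested ordering, using hypothetical stand-in names (DummyRunnerSketch, dummy_run_fixed) rather than the real runner API: seq_lens is written for the current dummy batch before _init_model_kwargs derives token_type_ids from it.

import torch

class DummyRunnerSketch:
    def __init__(self, max_num_tokens: int):
        self.max_num_tokens = max_num_tokens
        self.seq_lens = torch.zeros(max_num_tokens, dtype=torch.int32)

    def _init_model_kwargs(self, num_tokens: int) -> dict:
        # For pooling models, token_type_ids are derived from self.seq_lens,
        # so stale lengths produce a tensor sized for a previous batch.
        return {"token_type_ids": torch.zeros(int(self.seq_lens.sum().item()))}

    def dummy_run_fixed(self, num_tokens_after_padding: int) -> dict:
        # Write the lengths for the *current* dummy batch first
        # (here each dummy request is given length 1 for simplicity)...
        self.seq_lens[:num_tokens_after_padding] = 1
        self.seq_lens[num_tokens_after_padding:] = 0
        # ...and only then build model kwargs from them.
        return self._init_model_kwargs(num_tokens_after_padding)

runner = DummyRunnerSketch(max_num_tokens=8)
kwargs = runner.dummy_run_fixed(num_tokens_after_padding=4)
assert kwargs["token_type_ids"].numel() == 4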


@yiz-liu
Contributor Author

yiz-liu commented Oct 30, 2025

@WoosukKwon Could you please take a look at this? Thanks!

@fhl2000
Contributor

fhl2000 commented Nov 3, 2025

Yep, it is a good catch. The FlashInfer backend may hang if it runs full cudagraph without preparing attn_metadata. Could you please move the attn_metadata building code back out of the with context, and instead move the cudagraph dispatching parts above it?

@yiz-liu
Contributor Author

yiz-liu commented Nov 12, 2025

Yep, it is a good catch. The FlashInfer backend may hang if it runs full cudagraph without preparing attn_metadata. Could you please move the attn_metadata building code back out of the with context, and instead move the cudagraph dispatching parts above it?

OK, will do.

Moves the CUDA graph dispatch logic to execute before the attention metadata is calculated within the dummy run.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@yiz-liu yiz-liu changed the title from "Refactor: Defer dummy attention metadata creation" to "Refactor: Move CUDA graph dispatch logic earlier" Nov 14, 2025
@mergify mergify bot added the nvidia label Nov 14, 2025
Contributor

@fhl2000 fhl2000 left a comment


Sorry for the late review. I think there is no harm in moving this logic earlier in dummy_run.
cc @LucasWilkinson: it is also closer to your idea, in that the padding logic (and cudagraph mode) verified by the cudagraph dispatcher is handled before attention metadata building.

@fhl2000
Contributor

fhl2000 commented Nov 19, 2025

cc @ProExpertProg

Collaborator

@ProExpertProg ProExpertProg left a comment


LGTM, @LucasWilkinson is this ok with you?

@github-project-automation github-project-automation bot moved this to In review in NVIDIA Nov 19, 2025
@ProExpertProg ProExpertProg added the ready label (ONLY add when PR is ready to merge/full CI is needed) Nov 19, 2025
Collaborator

@LucasWilkinson LucasWilkinson left a comment


LGTM; this will be fixed by #28579, but we can take this in the interim.

Copilot AI review requested due to automatic review settings November 21, 2025 23:37
Copilot finished reviewing on behalf of yiz-liu November 21, 2025 23:53
Contributor

Copilot AI left a comment


Pull request overview

This PR refactors the _dummy_run method in the GPU model runner by moving the CUDA graph dispatch logic to execute earlier in the function flow. The change ensures that cudagraph_runtime_mode is determined before the attention metadata creation decision, addressing a potential edge case where attention metadata might not be created when replaying a CUDA graph in FULL mode.

Key Changes

  • CUDA graph dispatch logic moved from inside the LoRA context (after intermediate tensor setup) to immediately after num_tokens_after_padding calculation
  • This ensures cudagraph_runtime_mode is set before the condition if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: is evaluated for attention metadata creation (a simplified sketch of the new ordering follows below)
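
A minimal sketch of that reordering, again using illustrative stand-in names (dispatch, dummy_run_new) rather than the literal gpu_model_runner code:

from enum import Enum, auto

class CUDAGraphMode(Enum):
    NONE = auto()
    PIECEWISE = auto()
    FULL = auto()

def dispatch(num_tokens_after_padding: int) -> CUDAGraphMode:
    # Stand-in for the cudagraph dispatcher: assume a full graph was captured
    # for this padded size, so the runtime mode is upgraded to FULL.
    return CUDAGraphMode.FULL

def dummy_run_new(num_tokens_after_padding: int, force_attention: bool = False):
    # 1) Resolve the runtime mode right after padding is computed.
    cudagraph_runtime_mode = dispatch(num_tokens_after_padding)

    # 2) The metadata decision now sees the resolved mode, so FULL replays
    #    always get dummy attention metadata.
    attn_metadata = None
    if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
        attn_metadata = "dummy attention metadata"
    return cudagraph_runtime_mode, attn_metadata

print(dummy_run_new(num_tokens_after_padding=512))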


@ProExpertProg ProExpertProg merged commit df78aee into vllm-project:main Nov 22, 2025
45 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Nov 22, 2025
ywang96 pushed a commit to ywang96/vllm that referenced this pull request Nov 23, 2025
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
lpapavassiliou pushed a commit to lpapavassiliou/vllm that referenced this pull request Nov 24, 2025
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
RunkaiTao pushed a commit to RunkaiTao/vllm that referenced this pull request Nov 24, 2025
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
charlotte12l pushed a commit to charlotte12l/vllm that referenced this pull request Dec 5, 2025
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Zhathw pushed a commit to Zhathw/vllm that referenced this pull request Dec 6, 2025
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>

Labels

nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

Status: Done


4 participants