Conversation

@erfanzar

Summary

Fixes #1169

This PR fixes two issues in the ragged-paged attention v3 kernels:

  • kernel_hd64.py (h64): Added the missing sliding window mask in the kernel. The
    original code only skipped fetching KV blocks outside the window, but didn't apply
    token-level masking within partially covered blocks (a short sketch follows the
    Changes table below).

  • kernel.py (h128): Added attention_sink support following the same pattern as the
    h64 kernel. Attention sinks allow the model to "dump" attention to a virtual token that
    doesn't contribute to the output.

Changes

| File | Change |
| --- | --- |
| kernel_hd64.py | Added sliding window mask in flash_attention |
| kernel.py | Added attention_sink parameter to all functions |
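
To illustrate the kernel_hd64.py fix, here is a minimal sketch of token-level sliding-window masking for one attention block, written in plain JAX rather than the actual Pallas kernel; the names `mask_block_scores`, `q_ids`, and `kv_ids` are illustrative, not identifiers from the kernel.

```python
import jax.numpy as jnp

def mask_block_scores(scores, q_ids, kv_ids, sliding_window=None, mask_value=-1e30):
    """Apply causal + sliding-window masking to one block of attention scores.

    scores: [num_q, num_kv] raw logits for the current block.
    q_ids / kv_ids: absolute token positions of the block's query / KV rows.
    """
    # Causal part: a query may only attend to KV positions at or before it.
    mask = kv_ids[None, :] <= q_ids[:, None]
    if sliding_window is not None:
        # Token-level window mask. Skipping whole KV blocks outside the window
        # is not enough: a block that straddles the window boundary is still
        # fetched, and its out-of-window tokens must be masked here.
        mask = jnp.logical_and(mask, kv_ids[None, :] > q_ids[:, None] - sliding_window)
    return jnp.where(mask, scores, mask_value)
```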

… to h128

Fixes vllm-project#1169

This PR fixes two issues in the ragged paged attention v3 kernels:

1. **kernel_hd64.py (h64)**: Added missing sliding window mask in the kernel.
   The original code only skipped fetching KV blocks outside the window but
   didn't apply token-level masking within partially-covered blocks.

2. **kernel.py (h128)**: Added attention_sink support following the same
   pattern as the h64 kernel. Attention sinks allow the model to "dump"
   attention to a virtual token that doesn't contribute to the output.
   Uses LEFT concatenation semantics where sink logits are prepended
   before softmax, then removed after (a sketch follows this list).
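
As a rough illustration of the LEFT concatenation semantics (not the kernel's actual code), here is a plain-JAX sketch assuming a single learned sink logit per head; `softmax_with_sink` and `sink_logit` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def softmax_with_sink(scores, sink_logit):
    """Softmax over [sink | scores] for one head: the sink absorbs probability
    mass but, having no value vector, never contributes to the output."""
    num_q = scores.shape[0]
    sink_col = jnp.full((num_q, 1), sink_logit, dtype=scores.dtype)
    # Prepend the sink logit as a virtual KV column on the LEFT ...
    logits = jnp.concatenate([sink_col, scores], axis=-1)
    probs = jax.nn.softmax(logits, axis=-1)
    # ... then drop that column again before the probs @ V matmul.
    return probs[:, 1:]
```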

Changes:
- kernel_hd64.py: Added `if sliding_window is not None` mask in flash_attention
- kernel.py: Added attention_sink parameter to all functions (ref impl, kernel,
  prepare_inputs, validation, main function)
- kernel.py: Initialize m_prev with sink values and l_prev with 1.0 for
  proper online softmax tracking across blocks when using attention_sink
  (see the sketch below)
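
A sketch of how that initialization folds the sink into standard online softmax, in plain JAX with hypothetical names; the real kernel tracks this state per head inside the Pallas grid loop, which is assumed here rather than shown.

```python
import jax.numpy as jnp

def init_with_sink(sink_logit, num_q, head_dim):
    # Treat the sink as a virtual first block: the running max starts at the
    # sink logit, the running denominator at exp(sink - sink) = 1.0, and the
    # output accumulator stays zero because the sink carries no value vector.
    m_prev = jnp.full((num_q, 1), sink_logit)
    l_prev = jnp.ones((num_q, 1))
    acc = jnp.zeros((num_q, head_dim))
    return m_prev, l_prev, acc

def online_softmax_step(m_prev, l_prev, acc, scores, v):
    # Standard online-softmax update for one KV block
    # (scores: [num_q, num_kv], v: [num_kv, head_dim]).
    m_cur = jnp.maximum(m_prev, scores.max(axis=-1, keepdims=True))
    alpha = jnp.exp(m_prev - m_cur)
    p = jnp.exp(scores - m_cur)
    l_cur = alpha * l_prev + p.sum(axis=-1, keepdims=True)
    acc_cur = alpha * acc + p @ v
    return m_cur, l_cur, acc_cur
```

After the last block the output is acc / l, so the sink only inflates the denominator and never adds to the accumulated values.
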
@kyuyeunk
Collaborator

Hi @erfanzar, thanks for the quick PR. Here are a few comments:

  • The sliding-window mask issue in the hd64 variant is already being taken care of in this PR: [RPA][Kernel] Update hd64 variant sliding window code #1180
  • Adding the attention sink feature only to kernel_hd64.py and not to kernel.py was a deliberate choice to keep the kernel.py codebase streamlined, since the feature currently seems to be used only by gpt-oss, which has a head dim of 64. We will reevaluate this if a model with a head dim other than 64 also requires attention sink.

Development

Successfully merging this pull request may close these issues.

[Bug]: Sliding mask in v3 rpa might be wrong
