cuDNN: add fused scaled-dot-product (flash) attention forward #3174

Draft
CarloLucibello wants to merge 2 commits into JuliaGPU:main from CarloLucibello:cl/cudnn-sdpa

@CarloLucibello CarloLucibello commented Jun 11, 2026

Contributor

Wrap cuDNN's modern fused Scaled Dot Product Attention (SDPA) via the backend graph API. The legacy cudnnMultiHeadAttnForward is the deprecated attention path; this adds the flash-attention kernel NVIDIA now recommends.

  • backend.jl: a thin typed layer over the cuDNN backend graph API (cudnnBackend*): descriptor wrapper, setattr!/getattr, and helpers to build an operation graph, run engine heuristics, finalize an execution plan, and execute a variant pack. No prior high-level wrapper used this API.
  • sdpa.jl: cudnnSDPAForward[!] driving the dedicated CUDNN_BACKEND_OPERATION_SDPA_FWD_DESCRIPTOR, with a per-shape execution-plan cache. Inputs are 4-D (head_dim, nheads, seq_len, batch) — NNlib's attention layout — so no permute is needed to interoperate.
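To make the layout concrete: Julia arrays are column-major, so in a dense (head_dim, nheads, seq_len, batch) array head_dim is the fastest-varying axis. The sketch below runs on plain CPU arrays; `dense_strides` is a hypothetical helper for illustration and is not a function from this PR.

```julia
# Column-major strides (in elements) of a dense array with the given size.
# `dense_strides` is a hypothetical helper, not part of the PR.
dense_strides(dims::NTuple{N,Int}) where {N} = (1, cumprod(dims[1:end-1])...)

dims = (64, 8, 128, 2)                      # (head_dim, nheads, seq_len, batch)
q = zeros(Float16, dims)
@assert strides(q) == dense_strides(dims)   # head_dim is the fastest-varying axis

# A strided view keeps the element type and rank but is no longer dense,
# so strides derived from its size alone would describe the wrong memory.
qv = view(q, :, :, 1:2:128, :)
@assert strides(qv) != dense_strides(size(qv))
```

Anything that derives strides from the array size alone implicitly assumes the first case.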

Scope: forward inference only, Float16/BFloat16 (cuDNN's fused engine does not support Float32/Float64). Verified vs a Float32 reference (relerr ~5e-4).
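For readers who want to reproduce this kind of check, here is a minimal CPU sketch of a dense reference in the same layout. It is not the PR's test code: `reference_sdpa` and `relerr` are hypothetical names, and the default scale of `1/sqrt(head_dim)` is the conventional choice, assumed here.

```julia
using LinearAlgebra: norm

# Dense reference attention, accumulating in Float32.
# Inputs use the (head_dim, nheads, seq_len, batch) layout.
# Hypothetical helper for illustration; not the PR's implementation.
function reference_sdpa(q::AbstractArray{T,4}, k::AbstractArray{T,4}, v::AbstractArray{T,4};
                        scale::Real = 1 / sqrt(size(q, 1))) where {T}
    d, h, lq, b = size(q)
    lk = size(k, 3)
    size(k) == (d, h, lk, b) || throw(DimensionMismatch("k does not match q"))
    size(v)[2:4] == (h, lk, b) || throw(DimensionMismatch("v does not match k"))
    out = zeros(Float32, size(v, 1), h, lq, b)
    for n in 1:b, j in 1:h
        Q = Float32.(q[:, j, :, n])          # (head_dim, lq)
        K = Float32.(k[:, j, :, n])          # (head_dim, lk)
        V = Float32.(v[:, j, :, n])          # (value_dim, lk)
        S = (K' * Q) .* Float32(scale)       # (lk, lq): one column of scores per query
        S = S .- maximum(S; dims = 1)        # subtract the column max for stability
        P = exp.(S)
        P = P ./ sum(P; dims = 1)            # softmax over keys
        out[:, j, :, n] = V * P              # (value_dim, lq)
    end
    return out
end

# Relative error in the Frobenius norm.
relerr(a, b) = norm(Float32.(a) .- Float32.(b)) / norm(Float32.(b))

q = randn(Float16, 64, 4, 32, 2)
k = randn(Float16, 64, 4, 48, 2)
v = randn(Float16, 64, 4, 48, 2)
ref = reference_sdpa(q, k, v)
```

A fused result copied back to the host could then be compared with `relerr(result, ref)`.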

Not yet supported (both are blocked on cuDNN <= 9.20 and documented inline):

  • causal masking: the SDPA score-modifier subgraph needs cuDNN >= 9.21 (no CUDNN_jll yet); block-mask is block-sparse; the primitive matmul->softmax->matmul graph yields no fused engine from raw backend calls.
  • backward: cudnnSDPABackward is a documented placeholder; the dedicated SDPA_BWD descriptor does not finalize on 9.20 and the supported path needs the same 9.21 subgraph mechanism (forward stats output already verified to work).
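For context on the missing causal path, this is what a causal mask does to a dense score matrix. The sketch is a CPU illustration only and says nothing about how cuDNN expresses the mask. `causal_mask!` and `softmax_cols` are hypothetical names, and the top-left alignment (query t sees keys 1..t) is an assumption.

```julia
# Mask a (keys, queries) score matrix in place so query t only attends to keys 1..t.
# Hypothetical helper; top-left alignment is assumed.
function causal_mask!(S::AbstractMatrix{<:AbstractFloat})
    for t in axes(S, 2), i in axes(S, 1)
        if i > t
            S[i, t] = -Inf      # exp(-Inf) == 0, so masked keys get zero weight
        end
    end
    return S
end

# Column-wise softmax: each query's weights over the keys sum to 1.
function softmax_cols(S::AbstractMatrix)
    E = exp.(S .- maximum(S; dims = 1))
    return E ./ sum(E; dims = 1)
end

P = softmax_cols(causal_mask!(zeros(Float32, 4, 4)))
```

With all-zero scores, each query spreads its weight uniformly over the keys it is allowed to see.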

Tests cover the forward against a dense reference (Float16, several shapes, custom scale) plus in-place agreement; gated on compute capability >= 8.0.

Related to #2266. We put in place some basic infrastructure for the graph API, while focusing on the attention operator.

This PR has been mostly AI-generated and AI-reviewed, with tests passing locally on an RTX 5090.
No existing functionality is affected.

@CarloLucibello CarloLucibello marked this pull request as draft June 11, 2026 13:21
CarloLucibello and others added 2 commits June 11, 2026 15:28
Wrap cuDNN's modern fused SDPA via the backend graph API. The legacy
`cudnnMultiHeadAttnForward` is the deprecated attention path; this adds the
flash-attention kernel NVIDIA now recommends.

- backend.jl: a thin typed layer over the cuDNN backend graph API
  (`cudnnBackend*`): descriptor wrapper, setattr!/getattr, and helpers to build
  an operation graph, run engine heuristics, finalize an execution plan, and
  execute a variant pack. No prior high-level wrapper used this API.
- sdpa.jl: `cudnnSDPAForward[!]` driving the dedicated
  `CUDNN_BACKEND_OPERATION_SDPA_FWD_DESCRIPTOR`, with a per-shape execution-plan
  cache. Inputs are 4-D (head_dim, nheads, seq_len, batch) — NNlib's attention
  layout — so no permute is needed to interoperate.

Scope: forward inference only, Float16/BFloat16 (cuDNN's fused engine does not
support Float32/Float64). Verified vs a Float32 reference (relerr ~5e-4).

Not yet supported, both blocked on cuDNN <= 9.20 and documented inline:
- causal masking: the SDPA score-modifier subgraph needs cuDNN >= 9.21 (no
  CUDNN_jll yet); block-mask is block-sparse; the primitive matmul->softmax
  ->matmul graph yields no fused engine from raw backend calls.
- backward: `cudnnSDPABackward` is a documented placeholder; the dedicated
  SDPA_BWD descriptor does not finalize on 9.20 and the supported path needs the
  same 9.21 subgraph mechanism (forward stats output already verified to work).

Tests cover the forward against a dense reference (Float16, several shapes,
custom scale) plus in-place agreement; gated on compute capability >= 8.0.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- Constrain cudnnSDPAForward[!] inputs to DenseCuArray{T,4}: the cached plan
  bakes in dense column-major strides, so non-contiguous views (or host Arrays,
  whose Ptr/CuPtr distinction the variant pack erases) previously passed all
  size checks and produced silently wrong results or illegal accesses.
- Key the plan cache on the current context: execution plans are finalized
  against a specific device's handle, so a plan built on one GPU must not be
  executed on another.
- Only swallow the CUDNN_STATUS_NOT_SUPPORTED family (3000s) in
  try_execution_plan; BAD_PARAM/INTERNAL_ERROR now propagate instead of being
  misreported as "no supported engine". The terminal error also distinguishes
  "heuristic returned no configs" from "N configs failed to finalize".
- Build plans outside the cache lock (matching the descriptors.jl pattern) so
  concurrent calls don't serialize behind a multi-millisecond plan build; a
  racing duplicate build is benign and resolved by get! on insert.
- Use with_workspace for the execute workspace, freeing it eagerly instead of
  leaving it to GC (house style, cf. convolution.jl/reduce.jl).
- Test that views, host arrays, and Float32 inputs are rejected.
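
The caching points above (keying on the context, building outside the lock, resolving races with `get!`) can be sketched in plain Julia. This is an illustration under stated assumptions, not the PR's code: `PlanKey`, `PLAN_CACHE`, `PLAN_LOCK`, and `cached_plan` are hypothetical names, the context is stood in for by a `UInt`, and the plan builder is passed in as an ordinary function.

```julia
# Cache key: the context is part of the key, so a plan built for one device
# is never returned for another. All names here are hypothetical.
struct PlanKey
    context::UInt             # stand-in for the current GPU context
    eltype::DataType
    qsize::NTuple{4,Int}
    kvsize::NTuple{4,Int}
end

const PLAN_CACHE = Dict{PlanKey,Any}()
const PLAN_LOCK = ReentrantLock()

function cached_plan(build_plan, key::PlanKey)
    # Fast path: a short critical section that only does a lookup.
    plan = lock(PLAN_LOCK) do
        get(PLAN_CACHE, key, nothing)
    end
    plan === nothing || return plan

    # Slow path: build without holding the lock, so other callers are not
    # serialized behind an expensive build.
    fresh = build_plan(key)

    # Insert under the lock. If another task inserted first, get! returns
    # that plan and the duplicate built here is simply dropped.
    return lock(PLAN_LOCK) do
        get!(PLAN_CACHE, key, fresh)
    end
end

# Usage with a dummy builder that counts how often it runs.
builds = Ref(0)
builder = key -> (builds[] += 1; (:plan, key.context))
k1 = PlanKey(UInt(1), Float16, (64, 4, 32, 2), (64, 4, 32, 2))
p1 = cached_plan(builder, k1)
p2 = cached_plan(builder, k1)     # served from the cache, no second build
```

The trade-off is that two racing callers may both pay for a build; the sketch accepts that in exchange for never holding the lock during the build.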

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>