cuDNN: add fused scaled-dot-product (flash) attention forward#3174
Draft
CarloLucibello wants to merge 2 commits into
Draft
cuDNN: add fused scaled-dot-product (flash) attention forward#3174CarloLucibello wants to merge 2 commits into
CarloLucibello wants to merge 2 commits into
Conversation
Wrap cuDNN's modern fused SDPA via the backend graph API. The legacy `cudnnMultiHeadAttnForward` is the deprecated attention path; this adds the flash-attention kernel NVIDIA now recommends. - backend.jl: a thin typed layer over the cuDNN backend graph API (`cudnnBackend*`): descriptor wrapper, setattr!/getattr, and helpers to build an operation graph, run engine heuristics, finalize an execution plan, and execute a variant pack. No prior high-level wrapper used this API. - sdpa.jl: `cudnnSDPAForward[!]` driving the dedicated `CUDNN_BACKEND_OPERATION_SDPA_FWD_DESCRIPTOR`, with a per-shape execution-plan cache. Inputs are 4-D (head_dim, nheads, seq_len, batch) — NNlib's attention layout — so no permute is needed to interoperate. Scope: forward inference only, Float16/BFloat16 (cuDNN's fused engine does not support Float32/Float64). Verified vs a Float32 reference (relerr ~5e-4). Not yet supported, both blocked on cuDNN <= 9.20 and documented inline: - causal masking: the SDPA score-modifier subgraph needs cuDNN >= 9.21 (no CUDNN_jll yet); block-mask is block-sparse; the primitive matmul->softmax ->matmul graph yields no fused engine from raw backend calls. - backward: `cudnnSDPABackward` is a documented placeholder; the dedicated SDPA_BWD descriptor does not finalize on 9.20 and the supported path needs the same 9.21 subgraph mechanism (forward stats output already verified to work). Tests cover the forward against a dense reference (Float16, several shapes, custom scale) plus in-place agreement; gated on compute capability >= 8.0. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- Constrain cudnnSDPAForward[!] inputs to DenseCuArray{T,4}: the cached plan
bakes in dense column-major strides, so non-contiguous views (or host Arrays,
whose Ptr/CuPtr distinction the variant pack erases) previously passed all
size checks and produced silently wrong results or illegal accesses.
- Key the plan cache on the current context: execution plans are finalized
against a specific device's handle, so a plan built on one GPU must not be
executed on another.
- Only swallow the CUDNN_STATUS_NOT_SUPPORTED family (3000s) in
try_execution_plan; BAD_PARAM/INTERNAL_ERROR now propagate instead of being
misreported as "no supported engine". The terminal error also distinguishes
"heuristic returned no configs" from "N configs failed to finalize".
- Build plans outside the cache lock (matching the descriptors.jl pattern) so
concurrent calls don't serialize behind a multi-millisecond plan build; a
racing duplicate build is benign and resolved by get! on insert.
- Use with_workspace for the execute workspace, freeing it eagerly instead of
leaving it to GC (house style, cf. convolution.jl/reduce.jl).
- Test that views, host arrays, and Float32 inputs are rejected.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
36ba1ff to
65ed835
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Wrap cuDNN's modern fused Scaled Dot Product Attention (SDPA) via the backend graph API. The legacy
cudnnMultiHeadAttnForwardis the deprecated attention path; this adds the flash-attention kernel NVIDIA now recommends.cudnnBackend*): descriptor wrapper, setattr!/getattr, and helpers to build an operation graph, run engine heuristics, finalize an execution plan, and execute a variant pack. No prior high-level wrapper used this API.cudnnSDPAForward[!]driving the dedicatedCUDNN_BACKEND_OPERATION_SDPA_FWD_DESCRIPTOR, with a per-shape execution-plan cache. Inputs are 4-D (head_dim, nheads, seq_len, batch) — NNlib's attention layout — so no permute is needed to interoperate.Scope: forward inference only, Float16/BFloat16 (cuDNN's fused engine does not support Float32/Float64). Verified vs a Float32 reference (relerr ~5e-4).
Not yet supported, both blocked on cuDNN <= 9.20 and documented inline:
cudnnSDPABackwardis a documented placeholder; the dedicated SDPA_BWD descriptor does not finalize on 9.20 and the supported path needs the same 9.21 subgraph mechanism (forward stats output already verified to work).Tests cover the forward against a dense reference (Float16, several shapes, custom scale) plus in-place agreement; gated on compute capability >= 8.0.
Related to #2266. We put in place some basic infrastructure for the graph API, while focusing on the attention operator.
This PR has been mostly AI-generated and AI-reviewed, with tests passing locally on a RTX 5090.
No existing functionality is affected.