Add sdpa sparse-attention backend (run without flash_attn/xformers)#357
Open
hanvansolo wants to merge 1 commit into
Open
Add sdpa sparse-attention backend (run without flash_attn/xformers)#357hanvansolo wants to merge 1 commit into
sdpa sparse-attention backend (run without flash_attn/xformers)#357hanvansolo wants to merge 1 commit into
Conversation
The sparse attention modules currently hard-require flash_attn or xformers, both of which ship CUDA-only kernels — so inference can't run on AMD/Intel GPUs or CPU. This adds `sdpa` as a third ATTN_BACKEND option, implementing the `full` and `windowed` sparse attention paths with torch.nn.functional.scaled_dot_product_attention (built into PyTorch, runs on every backend). Purely additive: flash_attn/xformers stay the default on NVIDIA. Verified numerically equal to a naive softmax reference (~1e-6) and end-to-end on an AMD RX 6800 (ROCm, Windows). export ATTN_BACKEND=sdpa Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
@hanvansolo please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.
Contributor License AgreementContribution License AgreementThis Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Adds
sdpaas a third option forATTN_BACKEND/SPARSE_ATTN_BACKEND, alongside the existingflash_attnandxformers. It implements the sparse full and windowed attention paths withtorch.nn.functional.scaled_dot_product_attention, which ships with PyTorch and runs on every backend torch supports.Why
Today the sparse attention modules hard-require
flash_attnorxformers:Both ship CUDA-only kernels, so they don't install on AMD/Intel GPUs (or on CPU), which blocks TRELLIS inference on that hardware. SDPA is a built-in, backend-agnostic fallback — it lets the image→mesh path run on ROCm / XPU / MPS / CPU with no extra dependency and no build step. On NVIDIA nothing changes:
flash_attn/xformersstay the default and remain selectable; this is purely additive.The image→mesh path uses only the
full(DiTs) andwindowed(mesh decoder,attn_mode="swin") modes, so those are what's implemented.serialized(which needsvox2seq) is left untouched — its module just imports cleanly under the new backend.How
trellis/modules/sparse/__init__.py: accept'sdpa'for the attention backend env var and inset_attn.trellis/modules/sparse/attention/full_attn.py: SDPA branch. The packed variable-length sequences have no varlen SDPA kernel, so it runs one attention per sequence (a single call at batch size 1 — the inference case).trellis/modules/sparse/attention/windowed_attn.py: SDPA branch for both the uniform-window (batched SDPA) and ragged-window (per-window) cases.serialized_attn.py: import guard toleratessdpa.Correctness
Verified numerically equal to a naïve softmax-attention reference (max-abs error ~1e-6 for both the full varlen and windowed uniform cases).
Usage
export ATTN_BACKEND=sdpaTested
torchsparse-compatible shim — which can follow as a separate change if there's interest).Notes
The pure-torch path trades peak throughput for portability; on NVIDIA the native kernels remain the faster default. This PR only adds the option.