From 04d1c5cfa9e5aa1a1e267d4c190708b25ef9b6c5 Mon Sep 17 00:00:00 2001 From: Ed Deyzel Date: Sat, 6 Jun 2026 16:27:50 +0100 Subject: [PATCH] Add `sdpa` sparse-attention backend (run without flash_attn/xformers) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The sparse attention modules currently hard-require flash_attn or xformers, both of which ship CUDA-only kernels — so inference can't run on AMD/Intel GPUs or CPU. This adds `sdpa` as a third ATTN_BACKEND option, implementing the `full` and `windowed` sparse attention paths with torch.nn.functional.scaled_dot_product_attention (built into PyTorch, runs on every backend). Purely additive: flash_attn/xformers stay the default on NVIDIA. Verified numerically equal to a naive softmax reference (~1e-6) and end-to-end on an AMD RX 6800 (ROCm, Windows). export ATTN_BACKEND=sdpa Co-Authored-By: Claude Opus 4.8 --- trellis/modules/sparse/__init__.py | 4 +-- trellis/modules/sparse/attention/full_attn.py | 22 ++++++++++++++++ .../sparse/attention/serialized_attn.py | 6 +++++ .../modules/sparse/attention/windowed_attn.py | 25 +++++++++++++++++++ 4 files changed, 55 insertions(+), 2 deletions(-) diff --git a/trellis/modules/sparse/__init__.py b/trellis/modules/sparse/__init__.py index 726756c1..57787860 100755 --- a/trellis/modules/sparse/__init__.py +++ b/trellis/modules/sparse/__init__.py @@ -21,7 +21,7 @@ def __from_env(): BACKEND = env_sparse_backend if env_sparse_debug is not None: DEBUG = env_sparse_debug == '1' - if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn', 'sdpa']: ATTN = env_sparse_attn print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") @@ -38,7 +38,7 @@ def set_debug(debug: bool): global DEBUG DEBUG = debug -def set_attn(attn: Literal['xformers', 'flash_attn']): +def set_attn(attn: Literal['xformers', 'flash_attn', 'sdpa']): global ATTN ATTN = attn diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py index e9e27aeb..2f974524 100755 --- a/trellis/modules/sparse/attention/full_attn.py +++ b/trellis/modules/sparse/attention/full_attn.py @@ -7,6 +7,8 @@ import xformers.ops as xops elif ATTN == 'flash_attn': import flash_attn +elif ATTN == 'sdpa': + import torch.nn.functional as F else: raise ValueError(f"Unknown attention module: {ATTN}") @@ -206,6 +208,26 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) elif num_all_args == 3: out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif ATTN == 'sdpa': + # Backend-agnostic fallback (any GPU: CUDA / ROCm / Intel / CPU). The + # packed tensors hold several variable-length sequences concatenated on + # dim 0; torch SDPA has no varlen entry point, so run one attention per + # sequence. At inference batch size is usually 1 -> a single SDPA call. + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) # each [T, H, C] + elif num_all_args == 2: + k, v = kv.unbind(dim=1) # each [T_KV, H, C] + # num_all_args == 3: q, k, v already [T, H, C] + q_starts = torch.tensor([0] + q_seqlen, device=device).cumsum(0).tolist() + kv_starts = torch.tensor([0] + kv_seqlen, device=device).cumsum(0).tolist() + outs = [] + for i in range(len(q_seqlen)): + qi = q[q_starts[i]:q_starts[i + 1]].transpose(0, 1).unsqueeze(0) # [1, H, Sq, C] + ki = k[kv_starts[i]:kv_starts[i + 1]].transpose(0, 1).unsqueeze(0) # [1, H, Skv, C] + vi = v[kv_starts[i]:kv_starts[i + 1]].transpose(0, 1).unsqueeze(0) # [1, H, Skv, C] + oi = F.scaled_dot_product_attention(qi, ki, vi) # [1, H, Sq, C] + outs.append(oi.squeeze(0).transpose(0, 1)) # [Sq, H, C] + out = torch.cat(outs, dim=0) # [T_Q, H, C] else: raise ValueError(f"Unknown attention module: {ATTN}") diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py index 5950b75b..64431b1d 100755 --- a/trellis/modules/sparse/attention/serialized_attn.py +++ b/trellis/modules/sparse/attention/serialized_attn.py @@ -9,6 +9,12 @@ import xformers.ops as xops elif ATTN == 'flash_attn': import flash_attn +elif ATTN == 'sdpa': + # Serialized (space-filling-curve) attention is not on the image->mesh + # inference path, so this module only needs to import cleanly under the + # backend-agnostic SDPA backend; calling the serialized function with + # sdpa would still require vox2seq for the ordering. + import torch.nn.functional as F else: raise ValueError(f"Unknown attention module: {ATTN}") diff --git a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py index cd642c52..3e2a5b2c 100755 --- a/trellis/modules/sparse/attention/windowed_attn.py +++ b/trellis/modules/sparse/attention/windowed_attn.py @@ -8,6 +8,8 @@ import xformers.ops as xops elif ATTN == 'flash_attn': import flash_attn +elif ATTN == 'sdpa': + import torch.nn.functional as F else: raise ValueError(f"Unknown attention module: {ATTN}") @@ -110,6 +112,13 @@ def sparse_windowed_scaled_dot_product_self_attention( out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] elif ATTN == 'flash_attn': out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + elif ATTN == 'sdpa': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + q = q.permute(0, 2, 1, 3) # [B, H, N, C] + k = k.permute(0, 2, 1, 3) # [B, H, N, C] + v = v.permute(0, 2, 1, 3) # [B, H, N, C] + out = F.scaled_dot_product_attention(q, k, v) # [B, H, N, C] + out = out.permute(0, 2, 1, 3) # [B, N, H, C] else: raise ValueError(f"Unknown attention module: {ATTN}") out = out.reshape(B * N, H, C) # [M, H, C] @@ -125,6 +134,22 @@ def sparse_windowed_scaled_dot_product_self_attention( cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ .to(qkv.device).int() out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + elif ATTN == 'sdpa': + # Per-window SDPA: ragged window sizes, no varlen SDPA kernel, so + # attend within each window separately. Backend-agnostic (any GPU). + starts = torch.tensor([0] + list(seq_lens), device=qkv.device).cumsum(0).tolist() + outs = [] + for i in range(len(seq_lens)): + seg = qkv_feats[starts[i]:starts[i + 1]] # [Si, 3, H, C] + qi, ki, vi = seg.unbind(dim=1) # each [Si, H, C] + qi = qi.transpose(0, 1).unsqueeze(0) # [1, H, Si, C] + ki = ki.transpose(0, 1).unsqueeze(0) + vi = vi.transpose(0, 1).unsqueeze(0) + oi = F.scaled_dot_product_attention(qi, ki, vi) # [1, H, Si, C] + outs.append(oi.squeeze(0).transpose(0, 1)) # [Si, H, C] + out = torch.cat(outs, dim=0) # [M, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") out = out[bwd_indices] # [T, H, C]