diff --git a/trellis/modules/sparse/__init__.py b/trellis/modules/sparse/__init__.py index 726756c1..57787860 100755 --- a/trellis/modules/sparse/__init__.py +++ b/trellis/modules/sparse/__init__.py @@ -21,7 +21,7 @@ def __from_env(): BACKEND = env_sparse_backend if env_sparse_debug is not None: DEBUG = env_sparse_debug == '1' - if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn', 'sdpa']: ATTN = env_sparse_attn print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") @@ -38,7 +38,7 @@ def set_debug(debug: bool): global DEBUG DEBUG = debug -def set_attn(attn: Literal['xformers', 'flash_attn']): +def set_attn(attn: Literal['xformers', 'flash_attn', 'sdpa']): global ATTN ATTN = attn diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py index e9e27aeb..2f974524 100755 --- a/trellis/modules/sparse/attention/full_attn.py +++ b/trellis/modules/sparse/attention/full_attn.py @@ -7,6 +7,8 @@ import xformers.ops as xops elif ATTN == 'flash_attn': import flash_attn +elif ATTN == 'sdpa': + import torch.nn.functional as F else: raise ValueError(f"Unknown attention module: {ATTN}") @@ -206,6 +208,26 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) elif num_all_args == 3: out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif ATTN == 'sdpa': + # Backend-agnostic fallback (any device: CUDA / ROCm / Intel / CPU). The + # packed tensors hold several variable-length sequences concatenated on + # dim 0; torch SDPA has no varlen entry point, so run one attention per + # sequence. At inference the batch size is usually 1, i.e. one SDPA call. 
+ if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) # each [T, H, C] + elif num_all_args == 2: + k, v = kv.unbind(dim=1) # each [T_KV, H, C] + # num_all_args == 3: q is already [T_Q, H, C]; k, v are [T_KV, H, C] + q_starts = torch.tensor([0] + q_seqlen, device=device).cumsum(0).tolist() + kv_starts = torch.tensor([0] + kv_seqlen, device=device).cumsum(0).tolist() + outs = [] + for i in range(len(q_seqlen)): + qi = q[q_starts[i]:q_starts[i + 1]].transpose(0, 1).unsqueeze(0) # [1, H, Sq, C] + ki = k[kv_starts[i]:kv_starts[i + 1]].transpose(0, 1).unsqueeze(0) # [1, H, Skv, C] + vi = v[kv_starts[i]:kv_starts[i + 1]].transpose(0, 1).unsqueeze(0) # [1, H, Skv, C] + oi = F.scaled_dot_product_attention(qi, ki, vi) # [1, H, Sq, C] + outs.append(oi.squeeze(0).transpose(0, 1)) # [Sq, H, C] + out = torch.cat(outs, dim=0) # [T_Q, H, C] else: raise ValueError(f"Unknown attention module: {ATTN}") diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py index 5950b75b..64431b1d 100755 --- a/trellis/modules/sparse/attention/serialized_attn.py +++ b/trellis/modules/sparse/attention/serialized_attn.py @@ -9,6 +9,12 @@ import xformers.ops as xops elif ATTN == 'flash_attn': import flash_attn +elif ATTN == 'sdpa': + # Serialized (space-filling-curve) attention is not on the image->mesh + # inference path, so under the backend-agnostic SDPA backend this module + # only needs to import cleanly. The serialized function itself is not + # wired up for sdpa and would still require vox2seq for the ordering. 
+ import torch.nn.functional as F else: raise ValueError(f"Unknown attention module: {ATTN}") diff --git a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py index cd642c52..3e2a5b2c 100755 --- a/trellis/modules/sparse/attention/windowed_attn.py +++ b/trellis/modules/sparse/attention/windowed_attn.py @@ -8,6 +8,8 @@ import xformers.ops as xops elif ATTN == 'flash_attn': import flash_attn +elif ATTN == 'sdpa': + import torch.nn.functional as F else: raise ValueError(f"Unknown attention module: {ATTN}") @@ -110,6 +112,13 @@ def sparse_windowed_scaled_dot_product_self_attention( out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] elif ATTN == 'flash_attn': out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + elif ATTN == 'sdpa': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + q = q.permute(0, 2, 1, 3) # [B, H, N, C] + k = k.permute(0, 2, 1, 3) # [B, H, N, C] + v = v.permute(0, 2, 1, 3) # [B, H, N, C] + out = F.scaled_dot_product_attention(q, k, v) # [B, H, N, C] + out = out.permute(0, 2, 1, 3) # [B, N, H, C] else: raise ValueError(f"Unknown attention module: {ATTN}") out = out.reshape(B * N, H, C) # [M, H, C] @@ -125,6 +134,22 @@ def sparse_windowed_scaled_dot_product_self_attention( cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ .to(qkv.device).int() out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + elif ATTN == 'sdpa': + # Per-window SDPA: ragged window sizes, no varlen SDPA kernel, so + # attend within each window separately. Backend-agnostic (any GPU). 
+ starts = torch.tensor([0] + list(seq_lens), device=qkv.device).cumsum(0).tolist() + outs = [] + for i in range(len(seq_lens)): + seg = qkv_feats[starts[i]:starts[i + 1]] # [Si, 3, H, C] + qi, ki, vi = seg.unbind(dim=1) # each [Si, H, C] + qi = qi.transpose(0, 1).unsqueeze(0) # [1, H, Si, C] + ki = ki.transpose(0, 1).unsqueeze(0) + vi = vi.transpose(0, 1).unsqueeze(0) + oi = F.scaled_dot_product_attention(qi, ki, vi) # [1, H, Si, C] + outs.append(oi.squeeze(0).transpose(0, 1)) # [Si, H, C] + out = torch.cat(outs, dim=0) # [M, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") out = out[bwd_indices] # [T, H, C]