Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions trellis/modules/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __from_env():
BACKEND = env_sparse_backend
if env_sparse_debug is not None:
DEBUG = env_sparse_debug == '1'
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn', 'sdpa']:
ATTN = env_sparse_attn

print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
Expand All @@ -38,7 +38,7 @@ def set_debug(debug: bool):
global DEBUG
DEBUG = debug

def set_attn(attn: Literal['xformers', 'flash_attn']):
def set_attn(attn: Literal['xformers', 'flash_attn', 'sdpa']):
global ATTN
ATTN = attn

Expand Down
22 changes: 22 additions & 0 deletions trellis/modules/sparse/attention/full_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import xformers.ops as xops
elif ATTN == 'flash_attn':
import flash_attn
elif ATTN == 'sdpa':
import torch.nn.functional as F
else:
raise ValueError(f"Unknown attention module: {ATTN}")

Expand Down Expand Up @@ -206,6 +208,26 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif num_all_args == 3:
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif ATTN == 'sdpa':
# Backend-agnostic fallback (any device: CUDA / ROCm / Intel / CPU). The
# packed tensors hold several variable-length sequences concatenated on
# dim 0; torch SDPA has no varlen entry point, so run one attention call per
# sequence. At inference the batch size is usually 1, so this is a single SDPA call.
if num_all_args == 1:
q, k, v = qkv.unbind(dim=1) # each [T, H, C]
elif num_all_args == 2:
k, v = kv.unbind(dim=1) # each [T_KV, H, C]
# num_all_args == 3: q, k, v already [T, H, C]
q_starts = torch.tensor([0] + q_seqlen, device=device).cumsum(0).tolist()
kv_starts = torch.tensor([0] + kv_seqlen, device=device).cumsum(0).tolist()
outs = []
for i in range(len(q_seqlen)):
qi = q[q_starts[i]:q_starts[i + 1]].transpose(0, 1).unsqueeze(0) # [1, H, Sq, C]
ki = k[kv_starts[i]:kv_starts[i + 1]].transpose(0, 1).unsqueeze(0) # [1, H, Skv, C]
vi = v[kv_starts[i]:kv_starts[i + 1]].transpose(0, 1).unsqueeze(0) # [1, H, Skv, C]
oi = F.scaled_dot_product_attention(qi, ki, vi) # [1, H, Sq, C]
outs.append(oi.squeeze(0).transpose(0, 1)) # [Sq, H, C]
out = torch.cat(outs, dim=0) # [T_Q, H, C]
else:
raise ValueError(f"Unknown attention module: {ATTN}")

Expand Down
6 changes: 6 additions & 0 deletions trellis/modules/sparse/attention/serialized_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
import xformers.ops as xops
elif ATTN == 'flash_attn':
import flash_attn
elif ATTN == 'sdpa':
# Serialized (space-filling-curve) attention is not on the image->mesh
# inference path, so with ATTN == 'sdpa' this module only needs to import
# cleanly; actually calling the serialized function with sdpa would still
# require vox2seq for the ordering.
import torch.nn.functional as F
else:
raise ValueError(f"Unknown attention module: {ATTN}")

Expand Down
25 changes: 25 additions & 0 deletions trellis/modules/sparse/attention/windowed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import xformers.ops as xops
elif ATTN == 'flash_attn':
import flash_attn
elif ATTN == 'sdpa':
import torch.nn.functional as F
else:
raise ValueError(f"Unknown attention module: {ATTN}")

Expand Down Expand Up @@ -110,6 +112,13 @@ def sparse_windowed_scaled_dot_product_self_attention(
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
elif ATTN == 'flash_attn':
out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
elif ATTN == 'sdpa':
q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
q = q.permute(0, 2, 1, 3) # [B, H, N, C]
k = k.permute(0, 2, 1, 3) # [B, H, N, C]
v = v.permute(0, 2, 1, 3) # [B, H, N, C]
out = F.scaled_dot_product_attention(q, k, v) # [B, H, N, C]
out = out.permute(0, 2, 1, 3) # [B, N, H, C]
else:
raise ValueError(f"Unknown attention module: {ATTN}")
out = out.reshape(B * N, H, C) # [M, H, C]
Expand All @@ -125,6 +134,22 @@ def sparse_windowed_scaled_dot_product_self_attention(
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
.to(qkv.device).int()
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
elif ATTN == 'sdpa':
# Per-window SDPA: window sizes are ragged and torch SDPA has no varlen
# kernel, so attend within each window separately. Backend-agnostic (any device).
starts = torch.tensor([0] + list(seq_lens), device=qkv.device).cumsum(0).tolist()
outs = []
for i in range(len(seq_lens)):
seg = qkv_feats[starts[i]:starts[i + 1]] # [Si, 3, H, C]
qi, ki, vi = seg.unbind(dim=1) # each [Si, H, C]
qi = qi.transpose(0, 1).unsqueeze(0) # [1, H, Si, C]
ki = ki.transpose(0, 1).unsqueeze(0)
vi = vi.transpose(0, 1).unsqueeze(0)
oi = F.scaled_dot_product_attention(qi, ki, vi) # [1, H, Si, C]
outs.append(oi.squeeze(0).transpose(0, 1)) # [Si, H, C]
out = torch.cat(outs, dim=0) # [M, H, C]
else:
raise ValueError(f"Unknown attention module: {ATTN}")

out = out[bwd_indices] # [T, H, C]

Expand Down