
sage-attention: avoid torch custom_op for sm100 and add benchmark #452

Open
drbh wants to merge 1 commit into main from update-sage-attn-ops

Conversation

@drbh
Collaborator

@drbh drbh commented Mar 6, 2026

This PR

  • avoids the @torch.library.custom_op wrapper in sm100_compile.py, matching the approach already used in sm90_compile.py
  • adds a benchmark file
  • adds a README example with a simple validation script

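For context, the pattern being avoided is registering the kernel through torch.library.custom_op at import time, which can fail if the module is imported twice. A minimal sketch of the plain-wrapper alternative is below; _reference_attention is a hypothetical stand-in (plain SDPA) for the compiled SageAttention kernel, so this runs without the extension or a GPU:

```python
import torch
import torch.nn.functional as F


def _reference_attention(q, k, v):
    # Hypothetical stand-in for the compiled SageAttention kernel:
    # plain scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v)


# A plain Python wrapper with no torch.library registration: importing
# the module repeatedly cannot trigger a duplicate-operator error, unlike
# a module-level @torch.library.custom_op decorator.
def sageattn3_blackwell(q, k, v):
    return _reference_attention(q, k, v)


B, H, L, D = 1, 2, 16, 8
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
out = sageattn3_blackwell(q, k, v)
print(out.shape)  # torch.Size([1, 2, 16, 8])
```

The trade-off is that a plain function is opaque to torch.compile graph capture, whereas a registered custom op is traceable; the PR accepts that in exchange for safe re-import on sm100.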
@drbh drbh requested a review from danieldk as a code owner March 6, 2026 18:15
@drbh
Collaborator Author

drbh commented Mar 6, 2026

Note: these changes may resolve the double-registration issue seen on B200s.

Test with:

# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels",
# ]
# ///
import torch
from kernels import get_kernel

torch.manual_seed(42)
# Fetch the kernel from the Hugging Face Hub via the kernels library.
sage_attention = get_kernel("drbh/sage-attn-test", version=2)

device = "cuda"
# Batch size, number of heads, sequence length, head dimension.
B, H, L, D = 1, 8, 256, 64
q = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
k = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
v = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)

out = sage_attention.sageattn3_blackwell(q, k, v)
print(f"sageattn output shape: {out.shape}")
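Since the PR also adds a benchmark file, here is a hypothetical sketch of the kind of timing it might do, using torch.utils.benchmark with plain SDPA as a stand-in for the compiled kernel so it runs on CPU without the extension (the real benchmark file in the PR may differ):

```python
import torch
from torch.utils import benchmark


def run(q, k, v):
    # Stand-in workload: plain scaled dot-product attention.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


# Same shapes as the validation script above.
B, H, L, D = 1, 8, 256, 64
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Timer handles warmup and synchronization concerns for us.
t = benchmark.Timer(
    stmt="run(q, k, v)",
    globals={"run": run, "q": q, "k": k, "v": v},
)
m = t.timeit(10)
print(f"mean: {m.mean * 1e6:.1f} us")
```

On a GPU build, the tensors would be allocated on `device="cuda"` in bfloat16 and `run` would call the compiled kernel instead.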

