Add grouped linear layer with strided BMM optimization #263

z52527 wants to merge 8 commits into NVIDIA:main
Conversation
Greptile Summary

This PR adds a high-performance grouped linear layer implementation using strided batched matrix multiplication (BMM). The optimization achieves a 1.46x forward speedup and a 1.41x forward+backward speedup on H100 by fusing multiple group-wise GEMM operations into a single cuBLAS kernel call.

Key changes: a simple grouped linear path and a gated grouped MLP path (gate/up projections plus a fused SiLU*up Triton kernel), as laid out in the sequence diagram below.

Technical approach: reshape the input to (B, N, D_in) and run all group-wise GEMMs through one cuBLAS strided batched GEMM, writing directly into the target memory layout via strided output.

Critical issue: hardcoded absolute sys.path entries in the test files (see the review comments below).

Confidence Score: 4/5

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as Input Tensor<br/>(B*N, D_in)
    participant Reshape as Reshape Layer
    participant BMM as Strided BMM
    participant Triton as Triton Kernel<br/>(SiLU*Up)
    participant Output as Output Tensor<br/>(B*N, D_out)

    Note over Input,Output: Grouped Linear (Simple)
    Input->>Reshape: (B*N, D_in)
    Reshape->>BMM: (B, N, D_in)
    BMM->>BMM: x @ W[n] for each group
    Note over BMM: Single cuBLAS kernel<br/>with strided output
    BMM->>Output: (B*N, D_out)

    Note over Input,Output: Grouped MLP (Gated)
    Input->>Reshape: (B*N, D_in)
    Reshape->>BMM: (B, N, D_in)
    par Parallel Gate & Up Projections
        BMM->>BMM: gate = x @ gate_W[n]
    and
        BMM->>BMM: up = x @ up_W[n]
    end
    BMM->>Triton: gate, up tensors
    Triton->>Triton: silu(gate) * up
    Note over Triton: Fused activation<br/>in single kernel
    Triton->>BMM: hidden
    BMM->>BMM: output = hidden @ down_W[n]
    BMM->>Output: (B*N, D_out)
```
Last reviewed commit: e2bcf47
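The silu(gate) * up step in the diagram is a single elementwise Triton kernel. The PR's kernel is not reproduced in this excerpt; below is a minimal sketch of that fusion, with all names my assumptions, assuming contiguous gate/up tensors of equal shape:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _silu_mul_kernel(gate_ptr, up_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    g = tl.load(gate_ptr + offs, mask=mask).to(tl.float32)
    u = tl.load(up_ptr + offs, mask=mask).to(tl.float32)
    y = g * tl.sigmoid(g) * u  # silu(gate) * up in one pass, no intermediate tensor
    tl.store(out_ptr + offs, y.to(out_ptr.dtype.element_ty), mask=mask)

def silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(gate)
    n = gate.numel()
    _silu_mul_kernel[(triton.cdiv(n, 1024),)](gate, up, out, n, BLOCK=1024)
    return out
```

Compared with `torch.nn.functional.silu(gate) * up`, this avoids materializing the activation before the multiply, which is the "fused activation in single kernel" step in the diagram.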
Review comment (on the flagged lines):

```python
sys.path.insert(
    0, "/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu"
)
```

Hardcoded absolute path to the user's home directory will break for other users or environments.

Suggested change:

```diff
-sys.path.insert(
-    0, "/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu"
-)
+# Calculate relative path from this file to examples/hstu
+import os
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../hstu"))
```
Problem
Apply num_groups different linear transformations to corresponding slices of input:
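Concretely, with shape names that are my assumptions from the sequence diagram (N = num_groups, one contiguous slice of the input per group), the semantics can be stated in one einsum:

```python
import torch

B, N, D_in, D_out = 4, 12, 1024, 3072   # N = num_groups (assumed names)
x = torch.randn(B, N, D_in)             # one (B, D_in) slice per group
W = torch.randn(N, D_in, D_out)         # one weight matrix per group

# For every group n: out[:, n, :] = x[:, n, :] @ W[n]
out = torch.einsum("bnd,nde->bne", x, W)  # (B, N, D_out)
```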
Reference Implementation
The straightforward approach uses a loop over groups:
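The PR's reference code is not shown in this excerpt; a minimal sketch of the looped baseline under the same assumed shapes, costing one GEMM kernel launch per group:

```python
import torch

def grouped_linear_ref(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: (B, N, D_in), weight: (N, D_in, D_out); N = num_groups
    B, N, D_in = x.shape
    out = x.new_empty(B, N, weight.shape[-1])
    for n in range(N):  # N separate cuBLAS calls
        out[:, n, :] = x[:, n, :] @ weight[n]
    return out
```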
Optimized Implementation
Use torch.bmm with strided output to fuse all GEMMs into one kernel:
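A minimal sketch of the idea, as an assumption about the PR's internals rather than a copy of them. The transposed views are stride tricks, not copies; they carry the strides that cuBLAS consumes as leading-dimension and batch-stride parameters. If a given PyTorch build declines the strided `out=` view, it falls back to an extra copy rather than producing wrong results:

```python
import torch

def grouped_linear_bmm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: (B, N, D_in) contiguous, weight: (N, D_in, D_out); N = num_groups
    B, N, D_in = x.shape
    out = x.new_empty(B, N, weight.shape[-1])  # desired final layout
    # Batch over groups via views:
    #   (N, B, D_in) @ (N, D_in, D_out) -> (N, B, D_out),
    # written straight into the transposed view of `out`, i.e. one
    # strided batched GEMM instead of N separate kernel launches.
    torch.bmm(x.transpose(0, 1), weight, out=out.transpose(0, 1))
    return out
```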
Key feature: cuBLAS strided batched GEMM supports strided output via ldc/strideC parameters, allowing direct write to the transposed memory layout.
Performance Results

Config: batch_size=2560, num_groups=12, input_dim=1024, output_dim=3072, dtype=bf16

Device: NVIDIA H100: 1.46x forward speedup, 1.41x forward+backward speedup (per the review summary above)
Device: NVIDIA A100
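A hypothetical harness for reproducing the comparison (not the PR's benchmark script): it reuses the grouped_linear_ref / grouped_linear_bmm sketches above and assumes batch_size maps to B in the shapes used there:

```python
import torch
from torch.utils.benchmark import Timer

B, N, D_in, D_out = 2560, 12, 1024, 3072  # config from above
x = torch.randn(B, N, D_in, device="cuda", dtype=torch.bfloat16)
W = torch.randn(N, D_in, D_out, device="cuda", dtype=torch.bfloat16)

for fn in (grouped_linear_ref, grouped_linear_bmm):
    t = Timer("fn(x, W)", globals={"fn": fn, "x": x, "W": W})
    print(fn.__name__, t.timeit(100))  # Timer handles CUDA synchronization
```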