
Add grouped linear layer with strided BMM optimization #263

Open

z52527 wants to merge 8 commits into NVIDIA:main from z52527:fea-optimized-strided-BMM

Conversation

z52527 (Collaborator) commented Jan 6, 2026

Problem

Apply num_groups different linear transformations to the corresponding slices of the input:

Input:  x of shape (B * num_groups, input_dim)
Output: y of shape (B * num_groups, output_dim)

After reshaping x and y to (B, num_groups, ...), each group n computes: y[b, n, :] = x[b, n, :] @ W[n, :, :]
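
For intuition, the same operation can be written as a single batched contraction. A minimal sketch with arbitrary sizes (the einsum form is only for illustration and is not part of this PR):

import torch

B, num_groups, D_in, D_out = 4, 12, 64, 128
x = torch.randn(B, num_groups, D_in)
W = torch.randn(num_groups, D_in, D_out)

# y[b, n, :] = x[b, n, :] @ W[n, :, :] for every batch row b and group n
y = torch.einsum("bnd,nde->bne", x, W)    # (B, num_groups, D_out)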

Reference Implementation

The straightforward approach loops over the groups in Python, launching one small GEMM per group:

x = x.reshape(B, num_groups, D_in)
x_split = torch.split(x, 1, dim=1)

out_list = []
for i in range(num_groups):
    x_i = x_split[i].squeeze(1)           # (B, D_in)
    out_i = linear_layers[i](x_i)         # (B, D_out)
    out_list.append(out_i)

output = torch.stack(out_list, dim=1).reshape(-1, D_out)   # (B * num_groups, D_out)

Optimized Implementation

Use torch.bmm with strided output to fuse all GEMMs into one kernel:

x = x.reshape(B, num_groups, D_in)
output = torch.empty(B, num_groups, D_out, ...)   # pre-allocate final layout
torch.bmm(x.permute(1,0,2), weight,
          out=output.permute(1,0,2))              # cuBLAS writes to strided memory
return output.view(-1, D_out)                     # O(1) view, no copy

Key feature: cuBLAS strided batched GEMM supports strided outputs via the ldc/strideC parameters, so the kernel can write directly into the transposed memory layout and no extra copy or transpose kernel is needed.
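
A minimal end-to-end sketch of the idea (arbitrary sizes and loose tolerances; it assumes a CUDA device and mirrors the snippets above rather than the files in this PR):

import torch

B, num_groups, D_in, D_out = 256, 12, 1024, 3072
device, dtype = "cuda", torch.bfloat16

x = torch.randn(B * num_groups, D_in, device=device, dtype=dtype)
weight = torch.randn(num_groups, D_in, D_out, device=device, dtype=dtype)

# Reference: one GEMM per group
x3 = x.view(B, num_groups, D_in)
ref = torch.stack([x3[:, n, :] @ weight[n] for n in range(num_groups)], dim=1)
ref = ref.reshape(-1, D_out)

# Optimized: a single strided batched GEMM that writes into the final layout
out = torch.empty(B, num_groups, D_out, device=device, dtype=dtype)
torch.bmm(x3.permute(1, 0, 2), weight, out=out.permute(1, 0, 2))
opt = out.view(-1, D_out)

torch.testing.assert_close(opt, ref, rtol=1e-2, atol=1e-2)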

Performance Results

Config: batch_size=2560, num_groups=12, input_dim=1024, output_dim=3072, dtype=bf16

Device: NVIDIA H100
  Forward:            1.46x speedup
  Forward + Backward: 1.41x speedup

Device: NVIDIA A100
  Forward:            1.67x speedup, 246.7 TFLOPS
  Forward + Backward: 1.34x speedup, 238.0 TFLOPS

@z52527 z52527 self-assigned this Jan 6, 2026
JacoCheung (Collaborator) commented Jan 7, 2026

@z52527 ,
Could you generalize the BmmImpl so that it can handle activations laid out as either [batch_count, batch_size, input_dim] or [batch_size, batch_count, input_dim]? Even though the input arrives as [batch_count * batch_size, input_dim], your implementation currently assumes it is [batch_size, batch_count, input_dim].
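
For illustration only, one way such a layout flag could look; grouped_linear and groups_first are hypothetical names, not the BmmImpl in this PR:

import torch

def grouped_linear(x, weight, groups_first: bool):
    # weight: (num_groups, D_in, D_out); batch_count == num_groups
    num_groups, D_in, D_out = weight.shape
    if groups_first:
        # x laid out as [batch_count, batch_size, input_dim] flattened:
        # already in bmm's batch-major order, so no strided output is needed
        x3 = x.view(num_groups, -1, D_in)
        out = torch.bmm(x3, weight)                        # (num_groups, B, D_out)
    else:
        # x laid out as [batch_size, batch_count, input_dim] flattened:
        # the strided-output path from this PR
        x3 = x.view(-1, num_groups, D_in)
        out = torch.empty(x3.size(0), num_groups, D_out,
                          device=x.device, dtype=x.dtype)
        torch.bmm(x3.permute(1, 0, 2), weight, out=out.permute(1, 0, 2))
    # flattened output keeps the same group/batch ordering as the input
    return out.reshape(-1, D_out)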

greptile-apps bot commented Feb 13, 2026

Greptile Overview

Greptile Summary

This PR adds a high-performance grouped linear layer implementation using strided batched matrix multiplication (BMM). The optimization achieves 1.46x forward speedup and 1.41x forward+backward speedup on H100 by fusing multiple group-wise GEMM operations into a single cuBLAS kernel call.

Key changes:

  • GroupedLinear_example.py: Standalone example demonstrating the strided BMM optimization for grouped linear layers with comprehensive benchmarking and correctness verification
  • grouped_mlp_customop.py: PyTorch custom ops implementation with Triton kernels for SwiGLU activation, including strided_bmm and silu_mul ops
  • GroupedMLP_example.py: Extended benchmark comparing reference implementation against the optimized approach with detailed performance metrics

Technical approach:
The optimization leverages cuBLAS's support for strided output tensors, allowing direct writes to transposed memory layouts without additional copy operations. The grouped MLP implementation composes three strided BMMs (gate, up, down projections) with a fused Triton kernel for the SiLU activation.
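
Roughly, that data flow looks like the following sketch (torch.nn.functional.silu stands in for the fused Triton kernel, and the weight names are illustrative rather than taken from the PR):

import torch
import torch.nn.functional as F

def grouped_mlp(x, gate_w, up_w, down_w):
    # x: (B * num_groups, D_in)
    # gate_w, up_w: (num_groups, D_in, D_hidden); down_w: (num_groups, D_hidden, D_out)
    num_groups, D_in, _ = gate_w.shape
    x3 = x.view(-1, num_groups, D_in).permute(1, 0, 2)     # (num_groups, B, D_in)
    gate = torch.bmm(x3, gate_w)                           # (num_groups, B, D_hidden)
    up = torch.bmm(x3, up_w)                               # (num_groups, B, D_hidden)
    hidden = F.silu(gate) * up                             # fused into one Triton kernel in the PR
    out = torch.empty(x3.size(1), num_groups, down_w.size(2),
                      device=x.device, dtype=x.dtype)
    torch.bmm(hidden, down_w, out=out.permute(1, 0, 2))    # strided write into the final layout
    return out.view(-1, down_w.size(2))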

Critical issue:
Two files contain hardcoded absolute paths (/home/scratch.runchuz_gpu/repos-github/...) that must be replaced with relative paths before merging.

Confidence Score: 4/5

  • Safe to merge after fixing the hardcoded path issues in two files
  • The implementation is solid with proper correctness checks, benchmarking, and well-documented code. The main issue is hardcoded absolute paths that will break for other users. This is a critical but easy fix.
  • grouped_mlp_customop.py and GroupedMLP_example.py need the hardcoded path fixed before deployment

Important Files Changed

  • examples/commons/ops/GroupedLinear_example.py: New example demonstrating the grouped linear layer with strided BMM optimization. Clean implementation with proper benchmarking and correctness checks.
  • examples/commons/ops/grouped_mlp_customop.py: Custom PyTorch ops for the grouped MLP with Triton kernels. Contains a hardcoded path that will break for other users.
  • examples/commons/ops/GroupedMLP_example.py: Benchmark comparing the reference vs. optimized grouped MLP implementation. Contains a hardcoded path that will break for other users.

Sequence Diagram

sequenceDiagram
    participant Input as Input Tensor<br/>(B*N, D_in)
    participant Reshape as Reshape Layer
    participant BMM as Strided BMM
    participant Triton as Triton Kernel<br/>(SiLU*Up)
    participant Output as Output Tensor<br/>(B*N, D_out)

    Note over Input,Output: Grouped Linear (Simple)
    Input->>Reshape: (B*N, D_in)
    Reshape->>BMM: (B, N, D_in)
    BMM->>BMM: x @ W[n] for each group
    Note over BMM: Single cuBLAS kernel<br/>with strided output
    BMM->>Output: (B*N, D_out)

    Note over Input,Output: Grouped MLP (Gated)
    Input->>Reshape: (B*N, D_in)
    Reshape->>BMM: (B, N, D_in)
    
    par Parallel Gate & Up Projections
        BMM->>BMM: gate = x @ gate_W[n]
    and
        BMM->>BMM: up = x @ up_W[n]
    end
    
    BMM->>Triton: gate, up tensors
    Triton->>Triton: silu(gate) * up
    Note over Triton: Fused activation<br/>in single kernel
    Triton->>BMM: hidden
    BMM->>BMM: output = hidden @ down_W[n]
    BMM->>Output: (B*N, D_out)

Last reviewed commit: e2bcf47

greptile-apps bot left a comment

3 files reviewed, 2 comments

Comment on lines +15 to +16

sys.path.insert(
    0, "/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu"

Hardcoded absolute path to a user's home directory will break for other users or environments.

Suggested change

# Calculate relative path from this file to examples/hstu
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../hstu"))

Comment on lines +24 to +25

sys.path.insert(
    0, "/home/scratch.runchuz_gpu/repos-github/recsys-examples/examples/hstu"

Hardcoded absolute path to a user's home directory will break for other users or environments.

Suggested change

# Calculate relative path from this file to examples/hstu
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../hstu"))
