
Commit 59621ef

Fused INT4 weight-only quantized matmul pass for CUDA backend
Add fusion pass that combines multiple int4pack_mm operations sharing the same input tensor into a single fused operation, reducing kernel launch overhead for LLM attention (Q/K/V) and MLP (Gate/Up) projections.

Key changes:
- Add FuseInt4WeightOnlyQuantMatmulPass in backends/cuda/passes/
- Add CSEPass before fusion to merge duplicate preprocessing chains
- Fix AotiBackend.preprocess to properly handle PassResult from passes that return new graph_modules (using _update_exported_program_graph_module)
- Add comprehensive tests for the fusion pass
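The effect of the fusion can be pictured with a small conceptual sketch (illustrative only, using plain matmuls; the actual pass rewrites int4pack_mm nodes in the FX graph and must also account for the packed INT4 weight and scale/zero-point layouts):

import torch

def qkv_unfused(x, w_q, w_k, w_v):
    # Three matmuls sharing the same input: three kernel launches.
    return x @ w_q, x @ w_k, x @ w_v

def qkv_fused(x, w_q, w_k, w_v):
    # One matmul over the weights concatenated along the output dimension,
    # followed by a cheap split to recover Q, K, and V.
    w_qkv = torch.cat([w_q, w_k, w_v], dim=1)
    qkv = x @ w_qkv
    return qkv.split([w_q.shape[1], w_k.shape[1], w_v.shape[1]], dim=-1)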
1 parent 0e13ae6

6 files changed (+1563, −3 lines)


backends/aoti/aoti_backend.py
Lines changed: 36 additions & 1 deletion

@@ -19,8 +19,10 @@
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_details import ExportedProgram, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.program._program import _update_exported_program_graph_module
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch.export.passes import move_to_device_pass
+from torch.fx.passes.infra.pass_base import PassResult


 class COMPILE_SPEC_KEYS(Enum):
@@ -156,7 +158,40 @@ def preprocess(
         # Apply custom backend-specific passes
         custom_passes = cls.get_custom_passes(compile_specs)
         for custom_pass in custom_passes:
-            custom_pass(device_edge_program.graph_module)
+            result = custom_pass(device_edge_program.graph_module)
+            # Handle passes that return PassResult with a new graph_module
+            if isinstance(result, PassResult) and result.modified:
+                # Use a permissive verifier that allows all operator types including
+                # edge ops and custom triton ops. The default verifier only allows
+                # torch._ops.OpOverload and HigherOrderOperator, but edge dialect
+                # uses EdgeOpOverload, triton ops may use OpOverloadPacket or CustomOpDef.
+                from executorch.exir.dialects.edge._ops import EdgeOpOverload
+                from torch._export.verifier import Verifier
+                from torch._library.custom_ops import CustomOpDef
+                from torch._ops import (
+                    HigherOrderOperator,
+                    OpOverload,
+                    OpOverloadPacket,
+                )
+
+                class _PermissiveVerifier(Verifier):
+                    dialect = "EDGE"
+
+                    def allowed_op_types(self):
+                        return (
+                            OpOverload,
+                            OpOverloadPacket,
+                            HigherOrderOperator,
+                            EdgeOpOverload,
+                            CustomOpDef,
+                        )
+
+                    def check_valid_op(self, op):
+                        pass  # Allow all ops
+
+                device_edge_program = _update_exported_program_graph_module(
+                    device_edge_program, result.graph_module, override_verifiers=[_PermissiveVerifier]
+                )

         # Run decompositions if any
         if decomposition_table:
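For context, a minimal sketch of the pass contract this loop now supports (the pass name below is hypothetical; PassBase and PassResult are the torch.fx pass-infrastructure types imported above):

from torch.fx.passes.infra.pass_base import PassBase, PassResult

class ExampleRewritePass(PassBase):  # hypothetical, for illustration only
    def call(self, graph_module):
        # A real pass would rewrite graph_module (or build a new one) here.
        # Returning modified=True tells preprocess() to swap the returned
        # graph_module into the exported program via
        # _update_exported_program_graph_module.
        return PassResult(graph_module, modified=False)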

backends/cuda/cuda_backend.py
Lines changed: 21 additions & 2 deletions

@@ -10,10 +10,12 @@

 import torch
 from executorch.backends.aoti.aoti_backend import AotiBackend
+from executorch.backends.cuda.passes import FuseInt4WeightOnlyQuantMatmulPass
 from executorch.backends.cuda.triton.replacement_pass import (
     ReplaceEdgeOpWithTritonOpPass,
 )
 from executorch.exir._warnings import experimental
+from torch.fx.passes.dialect.common.cse_pass import CSEPass
 from executorch.exir.backend.backend_details import BackendDetails
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch._inductor.decomposition import conv1d_to_conv2d
@@ -50,12 +52,26 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
     @classmethod
     def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
         """
-        Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass.
+        Return CUDA-specific passes.
+
+        Passes include:
+        - CSEPass: Common subexpression elimination to merge duplicate preprocessing chains
+        - FuseInt4WeightOnlyQuantMatmulPass: Fuses INT4 matmul ops sharing the same input
+        - ReplaceEdgeOpWithTritonOpPass: Replaces edge ops with Triton kernels (optional)

         The Triton kernel replacement behavior can be controlled via compile_specs:
         - triton_kernel_mode="ON": Always use Triton kernels
         - triton_kernel_mode="OFF": Never use Triton kernels and fallback to other implementations like cuda or decomposed operator.
         """
+        # Start with CSE pass to merge duplicate preprocessing chains
+        # This enables Int4 fusion to group operations with identical preprocessing
+        passes: List[typing.Any] = [CSEPass()]
+
+        # Fuse INT4 weight-only quantized matmul operations
+        # Reduces kernel launch overhead by fusing Q/K/V and Gate/Up projections
+        # REQUIRES: CSE pass to have run first (to merge preprocessing)
+        passes.append(FuseInt4WeightOnlyQuantMatmulPass())
+
         # Parse compile_specs for triton_kernel_mode
         triton_kernel_mode = "ON"  # Default mode
         for spec in compile_specs:
@@ -68,7 +84,10 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
                 )
                 triton_kernel_mode = mode

-        return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
+        if triton_kernel_mode == "ON":
+            passes.append(ReplaceEdgeOpWithTritonOpPass())
+
+        return passes

     @classmethod
     def get_aoti_compile_options(
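A hedged usage sketch for the compile-spec switch described in the docstring (the backend class name, its module path, and the spec key are assumptions inferred from the file and docstring above, not confirmed by this diff):

from executorch.backends.cuda.cuda_backend import CudaBackend  # assumed path
from executorch.exir.backend.compile_spec_schema import CompileSpec

# Assuming the spec key is "triton_kernel_mode" and values are encoded as bytes:
specs = [CompileSpec("triton_kernel_mode", b"OFF")]
passes = CudaBackend.get_custom_passes(specs)
# Expected: [CSEPass(), FuseInt4WeightOnlyQuantMatmulPass()], with no Triton pass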

backends/cuda/passes/__init__.py
Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""CUDA backend optimization passes."""
+
+from .fuse_int4_quant_matmul import FuseInt4WeightOnlyQuantMatmulPass
+
+__all__ = ["FuseInt4WeightOnlyQuantMatmulPass"]
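A sketch of how the exported pass composes with CSE outside the backend (assumed driver code; the no-argument constructor matches the usage in cuda_backend.py above):

from executorch.backends.cuda.passes import FuseInt4WeightOnlyQuantMatmulPass
from torch.fx.passes.dialect.common.cse_pass import CSEPass

def run_cuda_graph_passes(graph_module):
    # CSE must run first so duplicate preprocessing chains collapse and the
    # fusion pass can group int4pack_mm nodes by their shared input.
    for p in (CSEPass(), FuseInt4WeightOnlyQuantMatmulPass()):
        result = p(graph_module)
        if result is not None and result.modified:
            graph_module = result.graph_module
    return graph_module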
