Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion backends/aoti/aoti_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_details import ExportedProgram, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.program._program import _update_exported_program_graph_module
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass
from torch.fx.passes.infra.pass_base import PassResult


class COMPILE_SPEC_KEYS(Enum):
Expand Down Expand Up @@ -156,7 +158,40 @@ def preprocess(
# Apply custom backend-specific passes
custom_passes = cls.get_custom_passes(compile_specs)
for custom_pass in custom_passes:
custom_pass(device_edge_program.graph_module)
result = custom_pass(device_edge_program.graph_module)
# Handle passes that return PassResult with a new graph_module
if isinstance(result, PassResult) and result.modified:
# Use a permissive verifier that allows all operator types including
# edge ops and custom triton ops. The default verifier only allows
# torch._ops.OpOverload and HigherOrderOperator, but edge dialect
# uses EdgeOpOverload, triton ops may use OpOverloadPacket or CustomOpDef.
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._export.verifier import Verifier
from torch._library.custom_ops import CustomOpDef
from torch._ops import (
HigherOrderOperator,
OpOverload,
OpOverloadPacket,
)

class _PermissiveVerifier(Verifier):
dialect = "EDGE"

def allowed_op_types(self):
return (
OpOverload,
OpOverloadPacket,
HigherOrderOperator,
EdgeOpOverload,
CustomOpDef,
)

def check_valid_op(self, op):
pass # Allow all ops

device_edge_program = _update_exported_program_graph_module(
device_edge_program, result.graph_module, override_verifiers=[_PermissiveVerifier]
)

# Run decompositions if any
if decomposition_table:
Expand Down
23 changes: 21 additions & 2 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@

import torch
from executorch.backends.aoti.aoti_backend import AotiBackend
from executorch.backends.cuda.passes import FuseInt4WeightOnlyQuantMatmulPass
from executorch.backends.cuda.triton.replacement_pass import (
ReplaceEdgeOpWithTritonOpPass,
)
from executorch.exir._warnings import experimental
from torch.fx.passes.dialect.common.cse_pass import CSEPass
from executorch.exir.backend.backend_details import BackendDetails
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.decomposition import conv1d_to_conv2d
Expand Down Expand Up @@ -50,12 +52,26 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
@classmethod
def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
"""
Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass.
Return CUDA-specific passes.

Passes include:
- CSEPass: Common subexpression elimination to merge duplicate preprocessing chains
- FuseInt4WeightOnlyQuantMatmulPass: Fuses INT4 matmul ops sharing the same input
- ReplaceEdgeOpWithTritonOpPass: Replaces edge ops with Triton kernels (optional)

The Triton kernel replacement behavior can be controlled via compile_specs:
- triton_kernel_mode="ON": Always use Triton kernels
- triton_kernel_mode="OFF": Never use Triton kernels and fallback to other implementations like cuda or decomposed operator.
"""
# Start with CSE pass to merge duplicate preprocessing chains
# This enables Int4 fusion to group operations with identical preprocessing
passes: List[typing.Any] = [CSEPass()]

# Fuse INT4 weight-only quantized matmul operations
# Reduces kernel launch overhead by fusing Q/K/V and Gate/Up projections
# REQUIRES: CSE pass to have run first (to merge preprocessing)
passes.append(FuseInt4WeightOnlyQuantMatmulPass())

# Parse compile_specs for triton_kernel_mode
triton_kernel_mode = "ON" # Default mode
for spec in compile_specs:
Expand All @@ -68,7 +84,10 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
)
triton_kernel_mode = mode

return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
if triton_kernel_mode == "ON":
passes.append(ReplaceEdgeOpWithTritonOpPass())

return passes

@classmethod
def get_aoti_compile_options(
Expand Down
11 changes: 11 additions & 0 deletions backends/cuda/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""CUDA backend optimization passes."""

from .fuse_int4_quant_matmul import FuseInt4WeightOnlyQuantMatmulPass

__all__ = ["FuseInt4WeightOnlyQuantMatmulPass"]
Loading
Loading