256 changes: 161 additions & 95 deletions backends/cadence/aot/replace_ops.py
@@ -41,8 +41,7 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
-from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from torch.fx.node import Argument
+from executorch.exir.pass_base import ExportPass, PassResult

 # A map to represent ops that:
 # (a) are functionally equivalent; and
@@ -1017,131 +1016,198 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int:
     return dim


-class ExportPassWithTransposeHelper(ExportPass):
-    def transpose_dims(
-        self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int
-    ) -> ProxyValue:
-        """Helper function to transpose dims of a `proxy` with given `meta`."""
-        shape = proxy.data.shape
+@register_cadence_pass(CadencePassAttribute(opt_level=3))
+class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface):
+    """
+    Replace NCHW convolutions with NHWC (channel-last) convolutions by adding
+    transpose operations before and after the convolution.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            exir_ops.edge.cadence.conv1d.default,
+            exir_ops.edge.cadence.conv2d.default,
+            exir_ops.edge.cadence.conv3d.default,
+            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
         dim0, dim1 = (
             canonicalize_transposed_dim(dim0, shape),
             canonicalize_transposed_dim(dim1, shape),
         )
         dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
-        return super().call_operator(
-            exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
         )
-
-
-@register_cadence_pass(CadencePassAttribute(opt_level=3))
-class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
-    def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        transpose_node.meta = node.meta
+        return transpose_node
+
+    def _change_nchw_to_nhwc(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NCHW format to NHWC format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0]] + indices[2:] + [indices[1]]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
         )
-
-    def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        permute_node.meta = node.meta
+        return permute_node
+
+    def _change_nhwc_to_nchw(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NHWC format to NCHW format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0], indices[-1]] + indices[1:-1]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
        )
+        permute_node.meta = node.meta
+        return permute_node

-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.cadence.conv1d.default,
-            exir_ops.edge.cadence.conv2d.default,
-            exir_ops.edge.cadence.conv3d.default,
-            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-
-        quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        assert isinstance(node.target, EdgeOpOverload)
+        quantized_op = (
+            node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
+        )

-        if not quantized_op and len(args) == 8 and args[-1] is True:
-            # Already in NHWC layout.
-            return super().call_operator(op, args, kwargs, meta)
+        # Check if already in NHWC layout
+        if not quantized_op and len(node.args) == 8 and node.args[-1] is True:
+            return False

+        # Determine the new op target
         if quantized_op:
             new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
         else:
-            # Determine if 1D or 2D convolution based on op
-            new_op = op
+            new_op = node.target

+        graph = node.graph
+
-        input_proxy = cast(ProxyValue, args[0])
-        weight_proxy = cast(ProxyValue, args[1])
-        input_proxy = self.change_nchw_to_nhwc(input_proxy, meta)
-        weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta)
+        # Get input and weight nodes
+        input_node = cast(torch.fx.Node, node.args[0])
+        weight_node = cast(torch.fx.Node, node.args[1])

-        # Non-quantized ops still need to set the last optional argument to True.
-        channel_last_arg = [] if quantized_op else [True]
+        # Insert transpose operations before the node
+        with graph.inserting_before(node):
+            # Convert input from NCHW to NHWC
+            input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
+            # Convert weight from NCHW to NHWC
+            weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node)

-        new_args = (
-            # Transposed input/weights.
-            (input_proxy, weight_proxy)
-            # All other args (bias, quant params, etc)
-            + tuple(args[2:])
-            + tuple(channel_last_arg)
-        )
-        output_proxy = super().call_operator(new_op, new_args, kwargs, meta)
-        nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta)
-        return nchw_proxy
+            # Non-quantized ops need to set the last optional argument to True
+            channel_last_arg = [] if quantized_op else [True]
+
+            # Create new args with transposed input/weights
+            new_args = (
+                (input_nhwc, weight_nhwc)
+                + tuple(node.args[2:])
+                + tuple(channel_last_arg)
+            )
+
+            # Create the new conv operation
+            new_conv = graph.call_function(new_op, new_args, node.kwargs)
+            new_conv.meta = node.meta
+
+            # Convert output back from NHWC to NCHW
+            nchw_output = self._change_nhwc_to_nchw(graph, new_conv)
+
+        # Replace all uses with the final output
+        node.replace_all_uses_with(nchw_output)
+        return True
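
[Note] The permute indices built by `_change_nchw_to_nhwc` / `_change_nhwc_to_nchw` generalize channels-last to any rank >= 3 (rank 3 falls back to a plain transpose of dims 1 and -1). Below is a minimal sketch of the round trip in plain PyTorch, independent of the pass infrastructure; the two helper function names are illustrative only:

```python
import torch

def nchw_to_nhwc_indices(rank: int) -> list[int]:
    # Channels (dim 1) move to the end: e.g. rank 4 -> [0, 2, 3, 1].
    indices = list(range(rank))
    return [indices[0]] + indices[2:] + [indices[1]]

def nhwc_to_nchw_indices(rank: int) -> list[int]:
    # Channels (last dim) move back to dim 1: e.g. rank 4 -> [0, 3, 1, 2].
    indices = list(range(rank))
    return [indices[0], indices[-1]] + indices[1:-1]

x = torch.randn(2, 3, 8, 8)  # NCHW
nhwc = x.permute(nchw_to_nhwc_indices(x.dim()))
assert nhwc.shape == (2, 8, 8, 3)  # NHWC
assert torch.equal(nhwc.permute(nhwc_to_nchw_indices(x.dim())), x)  # lossless round trip
```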


 @register_cadence_pass(CadencePassAttribute(opt_level=3))
-class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper):
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
+class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface):
+    """
+    Make the slice/cat dimension the outermost dimension by adding transpose
+    operations before and after the slice/cat operation.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.slice_copy.Tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-        dim = cast(int, args[1]) if len(args) > 1 else 0
-        output_shape = meta["val"].shape
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
+        dim0, dim1 = (
+            canonicalize_transposed_dim(dim0, shape),
+            canonicalize_transposed_dim(dim1, shape),
+        )
+        dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
+        )
+        transpose_node.meta = node.meta
+        return transpose_node
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Get the dimension argument
+        dim = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        output_shape = node.meta["val"].shape
+
+        # Canonicalize dim to be positive
         if dim < 0:
-            # Keep dim positive.
             dim += len(output_shape)

+        # Not needed if dim is already outermost or all dims before it are 1
         if dim == 0 or math.prod(output_shape[:dim]) == 1:
-            # Not needed if dim is already outermost or all dims before it are 1.
-            return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta)
-
-        if op == exir_ops.edge.aten.slice_copy.Tensor:
-            # Transpose -> slice.
-            slice_args = (
-                self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0),
-                0,
-            ) + args[2:]
-            new_op = super().call_operator(op, slice_args, kwargs, meta)
-        else:
-            # (Transpose input0, Transpose input1, ...) -> cat.
-            cat_in_tensors = [
-                self.transpose_dims(t, meta, dim, 0)
-                for t in cast(list[ProxyValue], args[0])
-            ]
-            new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta)
-        # slice/cat -> transpose.
-        return self.transpose_dims(new_op, meta, 0, dim)
+            return False
+
+        graph = node.graph
+
+        with graph.inserting_before(node):
+            if node.target == exir_ops.edge.aten.slice_copy.Tensor:
+                # Transpose input -> slice with dim=0 -> transpose back
+                input_node = cast(torch.fx.Node, node.args[0])
+                transposed_input = self._transpose_dims(graph, input_node, dim, 0)
+
+                # Create slice operation with dim=0
+                slice_args = (transposed_input, 0) + node.args[2:]
+                sliced = graph.call_function(
+                    exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs
+                )
+                sliced.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, sliced, 0, dim)
+            else:
+                # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back
+                cat_inputs = cast(list[torch.fx.Node], node.args[0])
+                transposed_inputs = [
+                    self._transpose_dims(graph, t, dim, 0) for t in cat_inputs
+                ]
+
+                # Create cat operation with dim=0
+                catted = graph.call_function(
+                    exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs
+                )
+                catted.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, catted, 0, dim)
+
+        # Replace all uses with the final result
+        node.replace_all_uses_with(result)
+        return True
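
[Note] The rewrite relies on slice/cat along dim `d` commuting with a transpose that moves `d` to the front. A quick plain-PyTorch check of that identity (shapes chosen arbitrarily for illustration):

```python
import torch

x = torch.randn(4, 6, 8)
dim, start, end = 1, 2, 5

# Direct slice along dim=1.
direct = x[:, start:end, :]
# Transpose dim to outermost -> slice along dim 0 -> transpose back.
roundtrip = x.transpose(dim, 0)[start:end].transpose(0, dim)
assert torch.equal(direct, roundtrip)

# Same identity for cat: transpose inputs, cat on dim 0, transpose back.
a, b = torch.randn(4, 3, 8), torch.randn(4, 5, 8)
direct_cat = torch.cat([a, b], dim=1)
outermost_cat = torch.cat(
    [a.transpose(1, 0), b.transpose(1, 0)], dim=0
).transpose(0, 1)
assert torch.equal(direct_cat, outermost_cat)
```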


 @register_cadence_pass(CadencePassAttribute(opt_level=2))