
Commit 7bb909c

Andrew Grebenisan authored and facebook-github-bot committed

Update ReplaceConvWithChannelLastConvPass and MakeSliceAndCatDimOutermostPass to correctly set modified bit

Summary: As titled

Differential Revision: D87880891

1 parent 874cb34 commit 7bb909c

2 files changed: +239 −110 lines changed

backends/cadence/aot/replace_ops.py (160 additions, 95 deletions)
@@ -40,8 +40,7 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
-from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from torch.fx.node import Argument
+from executorch.exir.pass_base import ExportPass, PassResult
 
 # A map to represent ops that:
 # (a) are functionally equivalent; and
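
Both passes in this diff move off the ExportPass.call_operator override style and onto RemoveOrReplacePassInterface, which is what makes the "correctly set modified bit" of the commit title possible: call_operator fires for every node and gives the pass no natural way to report that nothing changed. The interface itself is defined elsewhere in the Cadence AOT code and is not shown here; the sketch below only illustrates its presumed contract, and every name in it is an assumption, not the real implementation.

import torch
from executorch.exir.pass_base import PassResult

class RemoveOrReplacePassInterfaceSketch:
    """Hypothetical stand-in for the real RemoveOrReplacePassInterface."""

    @property
    def targets(self) -> list:
        raise NotImplementedError  # edge ops this pass rewrites

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        raise NotImplementedError  # True only if the graph actually changed

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        for node in list(graph_module.graph.nodes):
            if node.op == "call_function" and node.target in self.targets:
                # OR together per-node results, so `modified` is set only
                # when at least one node was rewritten.
                modified |= self.maybe_remove_or_replace(node)
        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified)
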
@@ -1023,131 +1022,197 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int:
     return dim
 
 
-class ExportPassWithTransposeHelper(ExportPass):
-    def transpose_dims(
-        self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int
-    ) -> ProxyValue:
-        """Helper function to transpose dims of a `proxy` with given `meta`."""
-        shape = proxy.data.shape
+@register_cadence_pass(CadencePassAttribute(opt_level=3))
+class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface):
+    """
+    Replace NCHW convolutions with NHWC (channel-last) convolutions by adding
+    transpose operations before and after the convolution.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            exir_ops.edge.cadence.conv1d.default,
+            exir_ops.edge.cadence.conv2d.default,
+            exir_ops.edge.cadence.conv3d.default,
+            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
         dim0, dim1 = (
             canonicalize_transposed_dim(dim0, shape),
             canonicalize_transposed_dim(dim1, shape),
         )
         dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
-        return super().call_operator(
-            exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
         )
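
Note the normalization above: both dims are made positive and sorted before aten.transpose_copy.int is emitted. A hypothetical worked example (the inline canonicalize below stands in for canonicalize_transposed_dim, defined earlier in this file, and is an assumption about its behavior):

def canonicalize(dim: int, shape) -> int:
    # Assumed behavior: wrap negative dims into range, like list indexing.
    return dim if dim >= 0 else dim + len(shape)

shape = (1, 16, 32, 32)  # N, C, H, W
dim0, dim1 = canonicalize(1, shape), canonicalize(-1, shape)
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
assert (dim0, dim1) == (1, 3)  # the emitted op always sees (lo, hi)
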
-
-
-@register_cadence_pass(CadencePassAttribute(opt_level=3))
-class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
-    def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        transpose_node.meta = node.meta
+        return transpose_node
+
+    def _change_nchw_to_nhwc(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NCHW format to NHWC format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0]] + indices[2:] + [indices[1]]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
         )
-
-    def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        permute_node.meta = node.meta
+        return permute_node
+
+    def _change_nhwc_to_nchw(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NHWC format to NCHW format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0], indices[-1]] + indices[1:-1]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
         )
+        permute_node.meta = node.meta
+        return permute_node
 
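
The permute-index arithmetic in these two helpers can be checked in isolation: for a 4-D NCHW tensor they produce [0, 2, 3, 1] (to NHWC) and [0, 3, 1, 2] (back to NCHW), which are inverse permutations, while the 3-D case degenerates to a single transpose of dims 1 and -1. A standalone check with plain torch ops:

import torch

x = torch.randn(2, 3, 4, 5)  # NCHW
indices = list(range(x.dim()))
nchw_to_nhwc = [indices[0]] + indices[2:] + [indices[1]]
nhwc_to_nchw = [indices[0], indices[-1]] + indices[1:-1]
assert nchw_to_nhwc == [0, 2, 3, 1]
assert nhwc_to_nchw == [0, 3, 1, 2]
# The two permutes are inverses, so the round trip is the identity.
assert torch.equal(x.permute(nchw_to_nhwc).permute(nhwc_to_nchw), x)
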
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.cadence.conv1d.default,
-            exir_ops.edge.cadence.conv2d.default,
-            exir_ops.edge.cadence.conv3d.default,
-            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-
-        quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        assert isinstance(node.target, EdgeOpOverload)
+        quantized_op = node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
 
-        if not quantized_op and len(args) == 8 and args[-1] is True:
-            # Already in NHWC layout.
-            return super().call_operator(op, args, kwargs, meta)
+        # Check if already in NHWC layout
+        if not quantized_op and len(node.args) == 8 and node.args[-1] is True:
+            return False
 
+        # Determine the new op target
         if quantized_op:
             new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
         else:
-            # Determine if 1D or 2D convolution based on op
-            new_op = op
+            new_op = node.target
 
-        input_proxy = cast(ProxyValue, args[0])
-        weight_proxy = cast(ProxyValue, args[1])
-        input_proxy = self.change_nchw_to_nhwc(input_proxy, meta)
-        weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta)
+        graph = node.graph
 
-        # Non-quantized ops still need to set the last optional argument to True.
-        channel_last_arg = [] if quantized_op else [True]
+        # Get input and weight nodes
+        input_node = cast(torch.fx.Node, node.args[0])
+        weight_node = cast(torch.fx.Node, node.args[1])
 
-        new_args = (
-            # Transposed input/weights.
-            (input_proxy, weight_proxy)
-            # All other args (bias, quant params, etc)
-            + tuple(args[2:])
-            + tuple(channel_last_arg)
-        )
-        output_proxy = super().call_operator(new_op, new_args, kwargs, meta)
-        nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta)
-        return nchw_proxy
+        # Insert transpose operations before the node
+        with graph.inserting_before(node):
+            # Convert input from NCHW to NHWC
+            input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
+            # Convert weight from NCHW to NHWC
+            weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node)
+
+            # Non-quantized ops need to set the last optional argument to True
+            channel_last_arg = [] if quantized_op else [True]
+
+            # Create new args with transposed input/weights
+            new_args = (
+                (input_nhwc, weight_nhwc)
+                + tuple(node.args[2:])
+                + tuple(channel_last_arg)
+            )
+
+            # Create the new conv operation
+            new_conv = graph.call_function(new_op, new_args, node.kwargs)
+            new_conv.meta = node.meta
+
+            # Convert output back from NHWC to NCHW
+            nchw_output = self._change_nhwc_to_nchw(graph, new_conv)
+
+        # Replace all uses with the final output
+        node.replace_all_uses_with(nchw_output)
+        return True
 
 
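
The new body is built from three torch.fx graph-surgery primitives: Graph.inserting_before to position new nodes, Graph.call_function to create them, and Node.replace_all_uses_with to repoint consumers. The self-contained toy below exercises the same idiom on plain Python-level torch ops rather than Cadence edge ops; unlike the pass, it erases the dead node itself instead of leaving cleanup to the surrounding pass infrastructure:

import torch

def f(x):
    return torch.relu(x)

gm = torch.fx.symbolic_trace(f)
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        with gm.graph.inserting_before(node):
            # New nodes land immediately before the node being replaced.
            doubled = gm.graph.call_function(torch.mul, (node.args[0], 2.0))
            new_relu = gm.graph.call_function(torch.relu, (doubled,))
        node.replace_all_uses_with(new_relu)
        gm.graph.erase_node(node)
gm.recompile()
assert torch.equal(gm(torch.tensor([-1.0, 2.0])), torch.tensor([0.0, 4.0]))
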
 @register_cadence_pass(CadencePassAttribute(opt_level=3))
-class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper):
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
+class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface):
+    """
+    Make the slice/cat dimension the outermost dimension by adding transpose
+    operations before and after the slice/cat operation.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.slice_copy.Tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-        dim = cast(int, args[1]) if len(args) > 1 else 0
-        output_shape = meta["val"].shape
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
+        dim0, dim1 = (
+            canonicalize_transposed_dim(dim0, shape),
+            canonicalize_transposed_dim(dim1, shape),
+        )
+        dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
+        )
+        transpose_node.meta = node.meta
+        return transpose_node
+
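
The transformation this pass performs is sound because slicing or concatenating along dim d is equivalent to transposing d to the front, doing the operation on dim 0, and transposing back. A quick numerical check of both identities with plain torch ops:

import torch

x = torch.randn(2, 5, 3)
d, start, end = 1, 1, 4
direct = x[:, start:end, :]  # slice along dim 1
via_outermost = x.transpose(0, d)[start:end].transpose(0, d)
assert torch.equal(direct, via_outermost)

a, b = torch.randn(2, 5, 3), torch.randn(2, 4, 3)
direct_cat = torch.cat([a, b], dim=d)
via_outermost_cat = torch.cat(
    [a.transpose(0, d), b.transpose(0, d)], dim=0
).transpose(0, d)
assert torch.equal(direct_cat, via_outermost_cat)
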
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Get the dimension argument
+        dim = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        output_shape = node.meta["val"].shape
+
+        # Canonicalize dim to be positive
         if dim < 0:
-            # Keep dim positive.
             dim += len(output_shape)
 
+        # Not needed if dim is already outermost or all dims before it are 1
         if dim == 0 or math.prod(output_shape[:dim]) == 1:
-            # Not needed if dim is already outermost or all dims before it are 1.
-            return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta)
-
-        if op == exir_ops.edge.aten.slice_copy.Tensor:
-            # Transpose -> slice.
-            slice_args = (
-                self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0),
-                0,
-            ) + args[2:]
-            new_op = super().call_operator(op, slice_args, kwargs, meta)
-        else:
-            # (Transpose input0, Transpose input1, ...) -> cat.
-            cat_in_tensors = [
-                self.transpose_dims(t, meta, dim, 0)
-                for t in cast(list[ProxyValue], args[0])
-            ]
-            new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta)
-        # slice/cat -> transpose.
-        return self.transpose_dims(new_op, meta, 0, dim)
+            return False
+
+        graph = node.graph
+
+        with graph.inserting_before(node):
+            if node.target == exir_ops.edge.aten.slice_copy.Tensor:
+                # Transpose input -> slice with dim=0 -> transpose back
+                input_node = cast(torch.fx.Node, node.args[0])
+                transposed_input = self._transpose_dims(graph, input_node, dim, 0)
+
+                # Create slice operation with dim=0
+                slice_args = (transposed_input, 0) + node.args[2:]
+                sliced = graph.call_function(
+                    exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs
+                )
+                sliced.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, sliced, 0, dim)
+            else:
+                # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back
+                cat_inputs = cast(list[torch.fx.Node], node.args[0])
+                transposed_inputs = [
+                    self._transpose_dims(graph, t, dim, 0)
+                    for t in cat_inputs
+                ]
+
+                # Create cat operation with dim=0
+                catted = graph.call_function(
+                    exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs
+                )
+                catted.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, catted, 0, dim)
+
+        # Replace all uses with the final result
+        node.replace_all_uses_with(result)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=2))
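
Both rewritten passes now return False when there is nothing to do (the conv is already NHWC; the slice/cat dim is already outermost) instead of unconditionally re-emitting ops, so the modified bit they report is accurate. How that bit is consumed is outside this diff; the toy below uses entirely hypothetical names and only mirrors one common reason the bit must be accurate: fixpoint-style pass scheduling can terminate only if an unchanged graph reports modified=False.

from dataclasses import dataclass

@dataclass
class Result:
    value: list
    modified: bool

def make_dim_outermost(ops: list) -> Result:
    """Hypothetical stand-in pass: rewrite any op whose dim is not 0."""
    changed = False
    out = []
    for name, dim in ops:
        if dim != 0:
            out.append((name, 0))  # pretend dim was transposed to the front
            changed = True
        else:
            out.append((name, dim))
    return Result(out, changed)

ops = [("slice_copy", 2), ("cat", 0)]
iterations = 0
while True:
    result = make_dim_outermost(ops)
    ops = result.value
    iterations += 1
    if not result.modified:  # a pass that always said True would spin forever
        break
assert iterations == 2 and ops == [("slice_copy", 0), ("cat", 0)]
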
