
Commit 951f8ca

Andrew Grebenisan authored and facebook-github-bot committed
Update ReplaceConvWithChannelLastConvPass and MakeSliceAndCatDimOutermostPass to correctly set modified bit (#16185)
Summary: Update ReplaceConvWithChannelLastConvPass and MakeSliceAndCatDimOutermostPass to use new interface

Reviewed By: ethansfng

Differential Revision: D87880891
1 parent 1fbf951 commit 951f8ca
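
Both passes move from overriding ExportPass.call_operator to RemoveOrReplacePassInterface, whose maybe_remove_or_replace returns a bool per matched node. The base class itself is not shown in this diff, so the following is only a minimal sketch of how such an interface can fold those per-node booleans into the pass's modified bit; the class body, its call method, and the dead-code cleanup step are assumptions for illustration, not the actual executorch implementation.

from abc import ABC, abstractmethod

import torch
from executorch.exir.pass_base import PassResult


class RemoveOrReplacePassSketch(ABC):
    """Hypothetical stand-in for RemoveOrReplacePassInterface."""

    @property
    @abstractmethod
    def targets(self) -> list:
        """Op overloads this pass rewrites."""

    @abstractmethod
    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Rewrite `node`; return True iff the graph was changed."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        # Snapshot the node list, since rewrites mutate the graph.
        for node in list(graph_module.graph.nodes):
            if node.op == "call_function" and node.target in self.targets:
                modified |= self.maybe_remove_or_replace(node)
        if modified:
            # Drop the replaced nodes and refresh the generated code.
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        # This aggregated flag is the "modified bit" the commit title refers to.
        return PassResult(graph_module, modified)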

File tree

2 files changed: +238 −104 lines

backends/cadence/aot/replace_ops.py

Lines changed: 161 additions & 95 deletions
@@ -41,8 +41,7 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
-from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from torch.fx.node import Argument
+from executorch.exir.pass_base import ExportPass, PassResult
 
 # A map to represent ops that:
 # (a) are functionally equivalent; and
@@ -1017,131 +1016,198 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int:
     return dim
 
 
-class ExportPassWithTransposeHelper(ExportPass):
-    def transpose_dims(
-        self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int
-    ) -> ProxyValue:
-        """Helper function to transpose dims of a `proxy` with given `meta`."""
-        shape = proxy.data.shape
+@register_cadence_pass(CadencePassAttribute(opt_level=3))
+class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface):
+    """
+    Replace NCHW convolutions with NHWC (channel-last) convolutions by adding
+    transpose operations before and after the convolution.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            exir_ops.edge.cadence.conv1d.default,
+            exir_ops.edge.cadence.conv2d.default,
+            exir_ops.edge.cadence.conv3d.default,
+            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
         dim0, dim1 = (
             canonicalize_transposed_dim(dim0, shape),
             canonicalize_transposed_dim(dim1, shape),
         )
         dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
-        return super().call_operator(
-            exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
         )
-
-
-@register_cadence_pass(CadencePassAttribute(opt_level=3))
-class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
-    def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        transpose_node.meta = node.meta
+        return transpose_node
+
+    def _change_nchw_to_nhwc(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NCHW format to NHWC format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0]] + indices[2:] + [indices[1]]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
        )
-
-    def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        permute_node.meta = node.meta
+        return permute_node
+
+    def _change_nhwc_to_nchw(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NHWC format to NCHW format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0], indices[-1]] + indices[1:-1]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
         )
+        permute_node.meta = node.meta
+        return permute_node
 
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.cadence.conv1d.default,
-            exir_ops.edge.cadence.conv2d.default,
-            exir_ops.edge.cadence.conv3d.default,
-            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-
-        quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        assert isinstance(node.target, EdgeOpOverload)
+        quantized_op = (
+            node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
+        )
 
-        if not quantized_op and len(args) == 8 and args[-1] is True:
-            # Already in NHWC layout.
-            return super().call_operator(op, args, kwargs, meta)
+        # Check if already in NHWC layout
+        if not quantized_op and len(node.args) == 8 and node.args[-1] is True:
+            return False
 
+        # Determine the new op target
         if quantized_op:
             new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
         else:
-            # Determine if 1D or 2D convolution based on op
-            new_op = op
+            new_op = node.target
+
+        graph = node.graph
 
-        input_proxy = cast(ProxyValue, args[0])
-        weight_proxy = cast(ProxyValue, args[1])
-        input_proxy = self.change_nchw_to_nhwc(input_proxy, meta)
-        weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta)
+        # Get input and weight nodes
+        input_node = cast(torch.fx.Node, node.args[0])
+        weight_node = cast(torch.fx.Node, node.args[1])
 
-        # Non-quantized ops still need to set the last optional argument to True.
-        channel_last_arg = [] if quantized_op else [True]
+        # Insert transpose operations before the node
+        with graph.inserting_before(node):
+            # Convert input from NCHW to NHWC
+            input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
+            # Convert weight from NCHW to NHWC
+            weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node)
 
-        new_args = (
-            # Transposed input/weights.
-            (input_proxy, weight_proxy)
-            # All other args (bias, quant params, etc)
-            + tuple(args[2:])
-            + tuple(channel_last_arg)
-        )
-        output_proxy = super().call_operator(new_op, new_args, kwargs, meta)
-        nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta)
-        return nchw_proxy
+            # Non-quantized ops need to set the last optional argument to True
+            channel_last_arg = [] if quantized_op else [True]
+
+            # Create new args with transposed input/weights
+            new_args = (
+                (input_nhwc, weight_nhwc)
+                + tuple(node.args[2:])
+                + tuple(channel_last_arg)
+            )
+
+            # Create the new conv operation
+            new_conv = graph.call_function(new_op, new_args, node.kwargs)
+            new_conv.meta = node.meta
+
+            # Convert output back from NHWC to NCHW
+            nchw_output = self._change_nhwc_to_nchw(graph, new_conv)
+
+        # Replace all uses with the final output
+        node.replace_all_uses_with(nchw_output)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=3))
-class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper):
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
+class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface):
+    """
+    Make the slice/cat dimension the outermost dimension by adding transpose
+    operations before and after the slice/cat operation.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.slice_copy.Tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-        dim = cast(int, args[1]) if len(args) > 1 else 0
-        output_shape = meta["val"].shape
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
+        dim0, dim1 = (
+            canonicalize_transposed_dim(dim0, shape),
+            canonicalize_transposed_dim(dim1, shape),
+        )
+        dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
+        )
+        transpose_node.meta = node.meta
+        return transpose_node
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Get the dimension argument
+        dim = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        output_shape = node.meta["val"].shape
+
+        # Canonicalize dim to be positive
         if dim < 0:
-            # Keep dim positive.
             dim += len(output_shape)
 
+        # Not needed if dim is already outermost or all dims before it are 1
         if dim == 0 or math.prod(output_shape[:dim]) == 1:
-            # Not needed if dim is already outermost or all dims before it are 1.
-            return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta)
-
-        if op == exir_ops.edge.aten.slice_copy.Tensor:
-            # Transpose -> slice.
-            slice_args = (
-                self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0),
-                0,
-            ) + args[2:]
-            new_op = super().call_operator(op, slice_args, kwargs, meta)
-        else:
-            # (Transpose input0, Transpose input1, ...) -> cat.
-            cat_in_tensors = [
-                self.transpose_dims(t, meta, dim, 0)
-                for t in cast(list[ProxyValue], args[0])
-            ]
-            new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta)
-        # slice/cat -> transpose.
-        return self.transpose_dims(new_op, meta, 0, dim)
+            return False
+
+        graph = node.graph
+
+        with graph.inserting_before(node):
+            if node.target == exir_ops.edge.aten.slice_copy.Tensor:
+                # Transpose input -> slice with dim=0 -> transpose back
+                input_node = cast(torch.fx.Node, node.args[0])
+                transposed_input = self._transpose_dims(graph, input_node, dim, 0)
+
+                # Create slice operation with dim=0
+                slice_args = (transposed_input, 0) + node.args[2:]
+                sliced = graph.call_function(
+                    exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs
+                )
+                sliced.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, sliced, 0, dim)
+            else:
+                # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back
+                cat_inputs = cast(list[torch.fx.Node], node.args[0])
+                transposed_inputs = [
+                    self._transpose_dims(graph, t, dim, 0) for t in cat_inputs
+                ]
+
+                # Create cat operation with dim=0
+                catted = graph.call_function(
+                    exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs
+                )
+                catted.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, catted, 0, dim)
+
+        # Replace all uses with the final result
+        node.replace_all_uses_with(result)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=2))
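
The permute index arithmetic in _change_nchw_to_nhwc and _change_nhwc_to_nchw above builds a pair of inverse permutations; a quick standalone check of that arithmetic with plain torch (the shape is an arbitrary example):

import torch

x = torch.randn(1, 8, 16, 16)  # example NCHW tensor
indices = list(range(x.dim()))

# Same arithmetic as the pass: NCHW -> NHWC moves C to the end...
nchw_to_nhwc = [indices[0]] + indices[2:] + [indices[1]]    # [0, 2, 3, 1]
# ...and NHWC -> NCHW moves the last dim back to position 1.
nhwc_to_nchw = [indices[0], indices[-1]] + indices[1:-1]    # [0, 3, 1, 2]

nhwc = x.permute(nchw_to_nhwc)          # shape (1, 16, 16, 8)
roundtrip = nhwc.permute(nhwc_to_nchw)  # shape (1, 8, 16, 16)
assert torch.equal(roundtrip, x)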

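Likewise, MakeSliceAndCatDimOutermostPass relies on the identity that slicing or concatenating along dim d equals transposing dims 0 and d, slicing or concatenating along dim 0, and transposing back. A small sanity check of that identity, with arbitrary example shapes:

import torch

# Slice along an inner dim vs. transpose -> slice at dim 0 -> transpose back.
x = torch.randn(2, 3, 5)
dim, start, end = 2, 1, 4
direct = x[:, :, start:end]
via_outermost = x.transpose(0, dim)[start:end].transpose(0, dim)
assert torch.equal(direct, via_outermost)

# The same identity holds for cat.
a, b = torch.randn(2, 3, 5), torch.randn(2, 3, 4)
direct_cat = torch.cat([a, b], dim=dim)
via_outermost_cat = torch.cat(
    [a.transpose(0, dim), b.transpose(0, dim)], dim=0
).transpose(0, dim)
assert torch.equal(direct_cat, via_outermost_cat)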