41 | 41 | ) |
42 | 42 | from executorch.exir.dialects._ops import ops as exir_ops |
43 | 43 | from executorch.exir.dialects.edge._ops import EdgeOpOverload |
44 | | -from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue |
45 | | -from torch.fx.node import Argument |
| 44 | +from executorch.exir.pass_base import ExportPass, PassResult |
46 | 45 |
47 | 46 | # A map to represent ops that: |
48 | 47 | # (a) are functionally equivalent; and |
@@ -1017,131 +1016,198 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int: |
1017 | 1016 | return dim |
1018 | 1017 |
1019 | 1018 |
1020 | | -class ExportPassWithTransposeHelper(ExportPass): |
1021 | | - def transpose_dims( |
1022 | | - self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int |
1023 | | - ) -> ProxyValue: |
1024 | | - """Helper function to transpose dims of a `proxy` with given `meta`.""" |
1025 | | - shape = proxy.data.shape |
| 1019 | +@register_cadence_pass(CadencePassAttribute(opt_level=3)) |
| 1020 | +class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface): |
| 1021 | + """ |
| 1022 | + Replace NCHW convolutions with NHWC (channel-last) convolutions by adding |
| 1023 | + transpose operations before and after the convolution. |
| 1024 | + """ |
| 1025 | + |
| 1026 | + @property |
| 1027 | + def targets(self) -> list[EdgeOpOverload]: |
| 1028 | + return [ |
| 1029 | + exir_ops.edge.cadence.conv1d.default, |
| 1030 | + exir_ops.edge.cadence.conv2d.default, |
| 1031 | + exir_ops.edge.cadence.conv3d.default, |
| 1032 | + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, |
| 1033 | + ] |
| 1034 | + |
| 1035 | + def _transpose_dims( |
| 1036 | + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int |
| 1037 | + ) -> torch.fx.Node: |
| 1038 | + """Helper function to transpose dims of a node.""" |
| 1039 | + shape = node.meta["val"].shape |
1026 | 1040 | dim0, dim1 = ( |
1027 | 1041 | canonicalize_transposed_dim(dim0, shape), |
1028 | 1042 | canonicalize_transposed_dim(dim1, shape), |
1029 | 1043 | ) |
1030 | 1044 | dim0, dim1 = min(dim0, dim1), max(dim0, dim1) |
1031 | | - return super().call_operator( |
1032 | | - exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta |
| 1045 | + transpose_node = graph.call_function( |
| 1046 | + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} |
1033 | 1047 | ) |
1034 | | - |
1035 | | - |
1036 | | -@register_cadence_pass(CadencePassAttribute(opt_level=3)) |
1037 | | -class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper): |
1038 | | - def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: |
1039 | | - shape = proxy.to_tensor().shape |
| 1048 | + transpose_node.meta = node.meta |
| 1049 | + return transpose_node |
| 1050 | + |
| 1051 | + def _change_nchw_to_nhwc( |
| 1052 | + self, graph: torch.fx.Graph, node: torch.fx.Node |
| 1053 | + ) -> torch.fx.Node: |
| 1054 | + """Convert NCHW format to NHWC format.""" |
| 1055 | + shape = node.meta["val"].shape |
1040 | 1056 | if len(shape) == 3: |
1041 | | - return self.transpose_dims(proxy, meta, 1, -1) |
| 1057 | + return self._transpose_dims(graph, node, 1, -1) |
1042 | 1058 | indices = list(range(len(shape))) |
1043 | 1059 | permute_indices = [indices[0]] + indices[2:] + [indices[1]] |
1044 | | - return super().call_operator( |
1045 | | - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta |
| 1060 | + permute_node = graph.call_function( |
| 1061 | + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} |
1046 | 1062 | ) |
1047 | | - |
1048 | | - def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: |
1049 | | - shape = proxy.to_tensor().shape |
| 1063 | + permute_node.meta = node.meta |
| 1064 | + return permute_node |
| 1065 | + |
| 1066 | + def _change_nhwc_to_nchw( |
| 1067 | + self, graph: torch.fx.Graph, node: torch.fx.Node |
| 1068 | + ) -> torch.fx.Node: |
| 1069 | + """Convert NHWC format to NCHW format.""" |
| 1070 | + shape = node.meta["val"].shape |
1050 | 1071 | if len(shape) == 3: |
1051 | | - return self.transpose_dims(proxy, meta, 1, -1) |
| 1072 | + return self._transpose_dims(graph, node, 1, -1) |
1052 | 1073 | indices = list(range(len(shape))) |
1053 | 1074 | permute_indices = [indices[0], indices[-1]] + indices[1:-1] |
1054 | | - return super().call_operator( |
1055 | | - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta |
| 1075 | + permute_node = graph.call_function( |
| 1076 | + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} |
1056 | 1077 | ) |
| 1078 | + permute_node.meta = node.meta |
| 1079 | + return permute_node |
1057 | 1080 |
1058 | | - def call_operator( |
1059 | | - self, |
1060 | | - op, |
1061 | | - args: tuple[Argument, ...], |
1062 | | - kwargs: dict[str, Argument], |
1063 | | - meta: NodeMetadata, |
1064 | | - ) -> ProxyValue: |
1065 | | - if op not in { |
1066 | | - exir_ops.edge.cadence.conv1d.default, |
1067 | | - exir_ops.edge.cadence.conv2d.default, |
1068 | | - exir_ops.edge.cadence.conv3d.default, |
1069 | | - exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, |
1070 | | - }: |
1071 | | - return super().call_operator(op, args, kwargs, meta) |
1072 | | - |
1073 | | - quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor |
| 1081 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 1082 | + assert isinstance(node.target, EdgeOpOverload) |
| 1083 | + quantized_op = ( |
| 1084 | + node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor |
| 1085 | + ) |
1074 | 1086 |
1075 | | - if not quantized_op and len(args) == 8 and args[-1] is True: |
1076 | | - # Already in NHWC layout. |
1077 | | - return super().call_operator(op, args, kwargs, meta) |
| 1087 | + # Check if already in NHWC layout |
| 1088 | + if not quantized_op and len(node.args) == 8 and node.args[-1] is True: |
| 1089 | + return False |
1078 | 1090 |
| 1091 | + # Determine the new op target |
1079 | 1092 | if quantized_op: |
1080 | 1093 | new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor |
1081 | 1094 | else: |
1082 | | - # Determine if 1D or 2D convolution based on op |
1083 | | - new_op = op |
| 1095 | + new_op = node.target |
| 1096 | + |
| 1097 | + graph = node.graph |
1084 | 1098 |
1085 | | - input_proxy = cast(ProxyValue, args[0]) |
1086 | | - weight_proxy = cast(ProxyValue, args[1]) |
1087 | | - input_proxy = self.change_nchw_to_nhwc(input_proxy, meta) |
1088 | | - weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta) |
| 1099 | + # Get input and weight nodes |
| 1100 | + input_node = cast(torch.fx.Node, node.args[0]) |
| 1101 | + weight_node = cast(torch.fx.Node, node.args[1]) |
1089 | 1102 |
1090 | | - # Non-quantized ops still need to set the last optional argument to True. |
1091 | | - channel_last_arg = [] if quantized_op else [True] |
| 1103 | + # Insert transpose operations before the node |
| 1104 | + with graph.inserting_before(node): |
| 1105 | + # Convert input from NCHW to NHWC |
| 1106 | + input_nhwc = self._change_nchw_to_nhwc(graph, input_node) |
| 1107 | + # Convert weight from NCHW to NHWC |
| 1108 | + weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node) |
1092 | 1109 |
1093 | | - new_args = ( |
1094 | | - # Transposed input/weights. |
1095 | | - (input_proxy, weight_proxy) |
1096 | | - # All other args (bias, quant params, etc) |
1097 | | - + tuple(args[2:]) |
1098 | | - + tuple(channel_last_arg) |
1099 | | - ) |
1100 | | - output_proxy = super().call_operator(new_op, new_args, kwargs, meta) |
1101 | | - nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta) |
1102 | | - return nchw_proxy |
| 1110 | + # Non-quantized ops need to set the last optional argument to True |
| 1111 | + channel_last_arg = [] if quantized_op else [True] |
| 1112 | + |
| 1113 | + # Create new args with transposed input/weights |
| 1114 | + new_args = ( |
| 1115 | + (input_nhwc, weight_nhwc) |
| 1116 | + + tuple(node.args[2:]) |
| 1117 | + + tuple(channel_last_arg) |
| 1118 | + ) |
| 1119 | + |
| 1120 | + # Create the new conv operation |
| 1121 | + new_conv = graph.call_function(new_op, new_args, node.kwargs) |
| 1122 | + new_conv.meta = node.meta |
| 1123 | + |
| 1124 | + # Convert output back from NHWC to NCHW |
| 1125 | + nchw_output = self._change_nhwc_to_nchw(graph, new_conv) |
| 1126 | + |
| 1127 | + # Replace all uses with the final output |
| 1128 | + node.replace_all_uses_with(nchw_output) |
| 1129 | + return True |
1103 | 1130 |
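Sanity check (not part of the diff): the permute-index arithmetic used by `_change_nchw_to_nhwc` / `_change_nhwc_to_nchw` above can be exercised with plain `torch.permute`. The helper names below are illustrative only.

```python
import torch

def nchw_to_nhwc_indices(rank: int) -> list[int]:
    # Mirrors _change_nchw_to_nhwc: keep the batch dim, move channels
    # (dim 1) to the innermost position.
    indices = list(range(rank))
    return [indices[0]] + indices[2:] + [indices[1]]

def nhwc_to_nchw_indices(rank: int) -> list[int]:
    # Mirrors _change_nhwc_to_nchw: move channels (last dim) back to dim 1.
    indices = list(range(rank))
    return [indices[0], indices[-1]] + indices[1:-1]

x = torch.randn(2, 3, 4, 5)                       # NCHW
nhwc = x.permute(nchw_to_nhwc_indices(x.dim()))   # shape (2, 4, 5, 3)
back = nhwc.permute(nhwc_to_nchw_indices(x.dim()))
assert torch.equal(x, back)  # the two permutes are inverses
```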
1104 | 1131 |
1105 | 1132 | @register_cadence_pass(CadencePassAttribute(opt_level=3)) |
1106 | | -class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper): |
1107 | | - def call_operator( |
1108 | | - self, |
1109 | | - op, |
1110 | | - args: tuple[Argument, ...], |
1111 | | - kwargs: dict[str, Argument], |
1112 | | - meta: NodeMetadata, |
1113 | | - ) -> ProxyValue: |
1114 | | - if op not in { |
| 1133 | +class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface): |
| 1134 | + """ |
| 1135 | + Make the slice/cat dimension the outermost dimension by adding transpose |
| 1136 | + operations before and after the slice/cat operation. |
| 1137 | + """ |
| 1138 | + |
| 1139 | + @property |
| 1140 | + def targets(self) -> list[EdgeOpOverload]: |
| 1141 | + return [ |
1115 | 1142 | exir_ops.edge.aten.cat.default, |
1116 | 1143 | exir_ops.edge.aten.slice_copy.Tensor, |
1117 | | - }: |
1118 | | - return super().call_operator(op, args, kwargs, meta) |
1119 | | - dim = cast(int, args[1]) if len(args) > 1 else 0 |
1120 | | - output_shape = meta["val"].shape |
| 1144 | + ] |
| 1145 | + |
| 1146 | + def _transpose_dims( |
| 1147 | + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int |
| 1148 | + ) -> torch.fx.Node: |
| 1149 | + """Helper function to transpose dims of a node.""" |
| 1150 | + shape = node.meta["val"].shape |
| 1151 | + dim0, dim1 = ( |
| 1152 | + canonicalize_transposed_dim(dim0, shape), |
| 1153 | + canonicalize_transposed_dim(dim1, shape), |
| 1154 | + ) |
| 1155 | + dim0, dim1 = min(dim0, dim1), max(dim0, dim1) |
| 1156 | + transpose_node = graph.call_function( |
| 1157 | + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} |
| 1158 | + ) |
| 1159 | + transpose_node.meta = node.meta |
| 1160 | + return transpose_node |
| 1161 | + |
| 1162 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 1163 | + # Get the dimension argument |
| 1164 | + dim = cast(int, node.args[1]) if len(node.args) > 1 else 0 |
| 1165 | + output_shape = node.meta["val"].shape |
| 1166 | + |
| 1167 | + # Canonicalize dim to be positive |
1121 | 1168 | if dim < 0: |
1122 | | - # Keep dim positive. |
1123 | 1169 | dim += len(output_shape) |
1124 | 1170 |
| 1171 | + # Not needed if dim is already outermost or all dims before it are 1 |
1125 | 1172 | if dim == 0 or math.prod(output_shape[:dim]) == 1: |
1126 | | - # Not needed if dim is already outermost or all dims before it are 1. |
1127 | | - return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta) |
1128 | | - |
1129 | | - if op == exir_ops.edge.aten.slice_copy.Tensor: |
1130 | | - # Transpose -> slice. |
1131 | | - slice_args = ( |
1132 | | - self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0), |
1133 | | - 0, |
1134 | | - ) + args[2:] |
1135 | | - new_op = super().call_operator(op, slice_args, kwargs, meta) |
1136 | | - else: |
1137 | | - # (Transpose input0, Transpose input1, ...) -> cat. |
1138 | | - cat_in_tensors = [ |
1139 | | - self.transpose_dims(t, meta, dim, 0) |
1140 | | - for t in cast(list[ProxyValue], args[0]) |
1141 | | - ] |
1142 | | - new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta) |
1143 | | - # slice/cat -> transpose. |
1144 | | - return self.transpose_dims(new_op, meta, 0, dim) |
| 1173 | + return False |
| 1174 | + |
| 1175 | + graph = node.graph |
| 1176 | + |
| 1177 | + with graph.inserting_before(node): |
| 1178 | + if node.target == exir_ops.edge.aten.slice_copy.Tensor: |
| 1179 | + # Transpose input -> slice with dim=0 -> transpose back |
| 1180 | + input_node = cast(torch.fx.Node, node.args[0]) |
| 1181 | + transposed_input = self._transpose_dims(graph, input_node, dim, 0) |
| 1182 | + |
| 1183 | + # Create slice operation with dim=0 |
| 1184 | + slice_args = (transposed_input, 0) + node.args[2:] |
| 1185 | + sliced = graph.call_function( |
| 1186 | + exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs |
| 1187 | + ) |
| 1188 | + sliced.meta = node.meta |
| 1189 | + |
| 1190 | + # Transpose back |
| 1191 | + result = self._transpose_dims(graph, sliced, 0, dim) |
| 1192 | + else: |
| 1193 | + # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back |
| 1194 | + cat_inputs = cast(list[torch.fx.Node], node.args[0]) |
| 1195 | + transposed_inputs = [ |
| 1196 | + self._transpose_dims(graph, t, dim, 0) for t in cat_inputs |
| 1197 | + ] |
| 1198 | + |
| 1199 | + # Create cat operation with dim=0 |
| 1200 | + catted = graph.call_function( |
| 1201 | + exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs |
| 1202 | + ) |
| 1203 | + catted.meta = node.meta |
| 1204 | + |
| 1205 | + # Transpose back |
| 1206 | + result = self._transpose_dims(graph, catted, 0, dim) |
| 1207 | + |
| 1208 | + # Replace all uses with the final result |
| 1209 | + node.replace_all_uses_with(result) |
| 1210 | + return True |
1145 | 1211 |
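The rewrite relies on transpose -> slice/cat at dim 0 -> transpose-back being equivalent to the original op along `dim`. A standalone check of that equivalence (not part of the diff), using `narrow` as a stand-in for `slice_copy`:

```python
import torch

x = torch.randn(4, 6, 8)
dim = 2

# Direct slice along `dim` (narrow(dim, start, length) stands in for
# slice_copy(x, dim, start, end)).
direct = x.narrow(dim, 1, 3)

# The pass's rewrite: move `dim` to the outermost position, slice at
# dim 0, then transpose back.
rewritten = x.transpose(0, dim).narrow(0, 1, 3).transpose(0, dim)
assert torch.equal(direct, rewritten)

# The same holds for cat: transpose each input, concatenate at dim 0,
# then transpose back.
a, b = torch.randn(4, 6, 8), torch.randn(4, 6, 5)
direct_cat = torch.cat([a, b], dim=dim)
rewritten_cat = torch.cat(
    [t.transpose(0, dim) for t in (a, b)], dim=0
).transpose(0, dim)
assert torch.equal(direct_cat, rewritten_cat)
```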
1146 | 1212 |
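Both passes follow the same `torch.fx` rewrite pattern: build replacement nodes under `graph.inserting_before(node)`, then redirect users with `replace_all_uses_with`. A minimal self-contained sketch of that pattern on a toy graph (illustrative only; not part of the diff):

```python
import torch
import torch.fx

def f(x):
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(f)

# Insert replacement nodes before the matched node, redirect its users,
# then erase the original node.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        with gm.graph.inserting_before(node):
            zero = gm.graph.call_function(torch.zeros_like, (node.args[0],))
            new = gm.graph.call_function(torch.maximum, (node.args[0], zero))
        node.replace_all_uses_with(new)
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
x = torch.randn(3)
assert torch.equal(gm(x), f(x))
```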
1147 | 1213 | @register_cadence_pass(CadencePassAttribute(opt_level=2)) |