40 | 40 | ) |
41 | 41 | from executorch.exir.dialects._ops import ops as exir_ops |
42 | 42 | from executorch.exir.dialects.edge._ops import EdgeOpOverload |
43 | | -from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue |
44 | | -from torch.fx.node import Argument |
| 43 | +from executorch.exir.pass_base import ExportPass, PassResult |
45 | 44 |
46 | 45 | # A map to represent ops that: |
47 | 46 | # (a) are functionally equivalent; and |
@@ -1023,131 +1022,197 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int: |
1023 | 1022 | return dim |
1024 | 1023 |
1025 | 1024 |
1026 | | -class ExportPassWithTransposeHelper(ExportPass): |
1027 | | - def transpose_dims( |
1028 | | - self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int |
1029 | | - ) -> ProxyValue: |
1030 | | - """Helper function to transpose dims of a `proxy` with given `meta`.""" |
1031 | | - shape = proxy.data.shape |
| 1025 | +@register_cadence_pass(CadencePassAttribute(opt_level=3)) |
| 1026 | +class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface): |
| 1027 | + """ |
| 1028 | + Replace NCHW convolutions with NHWC (channel-last) convolutions by adding |
| 1029 | + transpose operations before and after the convolution. |
| 1030 | + """ |
| 1031 | + |
| 1032 | + @property |
| 1033 | + def targets(self) -> list[EdgeOpOverload]: |
| 1034 | + return [ |
| 1035 | + exir_ops.edge.cadence.conv1d.default, |
| 1036 | + exir_ops.edge.cadence.conv2d.default, |
| 1037 | + exir_ops.edge.cadence.conv3d.default, |
| 1038 | + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, |
| 1039 | + ] |
| 1040 | + |
| 1041 | + def _transpose_dims( |
| 1042 | + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int |
| 1043 | + ) -> torch.fx.Node: |
 | 1044 | +        """Insert a transpose_copy node swapping `dim0` and `dim1` of `node`.""" |
| 1045 | + shape = node.meta["val"].shape |
1032 | 1046 | dim0, dim1 = ( |
1033 | 1047 | canonicalize_transposed_dim(dim0, shape), |
1034 | 1048 | canonicalize_transposed_dim(dim1, shape), |
1035 | 1049 | ) |
1036 | 1050 | dim0, dim1 = min(dim0, dim1), max(dim0, dim1) |
1037 | | - return super().call_operator( |
1038 | | - exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta |
| 1051 | + transpose_node = graph.call_function( |
| 1052 | + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} |
1039 | 1053 | ) |
1040 | | - |
1041 | | - |
1042 | | -@register_cadence_pass(CadencePassAttribute(opt_level=3)) |
1043 | | -class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper): |
1044 | | - def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: |
1045 | | - shape = proxy.to_tensor().shape |
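| | +        # NOTE: meta is copied from the source node, so its `val` still has the |
| | +        # pre-transpose shape; the helpers in this pass only rely on its rank. |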
| 1054 | + transpose_node.meta = node.meta |
| 1055 | + return transpose_node |
| 1056 | + |
| 1057 | + def _change_nchw_to_nhwc( |
| 1058 | + self, graph: torch.fx.Graph, node: torch.fx.Node |
| 1059 | + ) -> torch.fx.Node: |
 | 1060 | +        """Insert a node converting `node` from NCHW to NHWC layout.""" |
| 1061 | + shape = node.meta["val"].shape |
1046 | 1062 | if len(shape) == 3: |
1047 | | - return self.transpose_dims(proxy, meta, 1, -1) |
| 1063 | + return self._transpose_dims(graph, node, 1, -1) |
1048 | 1064 | indices = list(range(len(shape))) |
1049 | 1065 | permute_indices = [indices[0]] + indices[2:] + [indices[1]] |
1050 | | - return super().call_operator( |
1051 | | - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta |
| 1066 | + permute_node = graph.call_function( |
| 1067 | + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} |
1052 | 1068 | ) |
1053 | | - |
1054 | | - def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: |
1055 | | - shape = proxy.to_tensor().shape |
| 1069 | + permute_node.meta = node.meta |
| 1070 | + return permute_node |
| 1071 | + |
| 1072 | + def _change_nhwc_to_nchw( |
| 1073 | + self, graph: torch.fx.Graph, node: torch.fx.Node |
| 1074 | + ) -> torch.fx.Node: |
 | 1075 | +        """Insert a node converting `node` from NHWC back to NCHW layout.""" |
| 1076 | + shape = node.meta["val"].shape |
1056 | 1077 | if len(shape) == 3: |
1057 | | - return self.transpose_dims(proxy, meta, 1, -1) |
| 1078 | + return self._transpose_dims(graph, node, 1, -1) |
1058 | 1079 | indices = list(range(len(shape))) |
1059 | 1080 | permute_indices = [indices[0], indices[-1]] + indices[1:-1] |
1060 | | - return super().call_operator( |
1061 | | - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta |
| 1081 | + permute_node = graph.call_function( |
| 1082 | + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} |
1062 | 1083 | ) |
| 1084 | + permute_node.meta = node.meta |
| 1085 | + return permute_node |
1063 | 1086 |
1064 | | - def call_operator( |
1065 | | - self, |
1066 | | - op, |
1067 | | - args: tuple[Argument, ...], |
1068 | | - kwargs: dict[str, Argument], |
1069 | | - meta: NodeMetadata, |
1070 | | - ) -> ProxyValue: |
1071 | | - if op not in { |
1072 | | - exir_ops.edge.cadence.conv1d.default, |
1073 | | - exir_ops.edge.cadence.conv2d.default, |
1074 | | - exir_ops.edge.cadence.conv3d.default, |
1075 | | - exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, |
1076 | | - }: |
1077 | | - return super().call_operator(op, args, kwargs, meta) |
1078 | | - |
1079 | | - quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor |
| 1087 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 1088 | + assert isinstance(node.target, EdgeOpOverload) |
 | 1089 | +        quantized_op = ( |
 | | +            node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor |
 | | +        ) |
1080 | 1090 |
1081 | | - if not quantized_op and len(args) == 8 and args[-1] is True: |
1082 | | - # Already in NHWC layout. |
1083 | | - return super().call_operator(op, args, kwargs, meta) |
| 1091 | + # Check if already in NHWC layout |
| 1092 | + if not quantized_op and len(node.args) == 8 and node.args[-1] is True: |
| 1093 | + return False |
1084 | 1094 |
| 1095 | + # Determine the new op target |
1085 | 1096 | if quantized_op: |
1086 | 1097 | new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor |
1087 | 1098 | else: |
1088 | | - # Determine if 1D or 2D convolution based on op |
1089 | | - new_op = op |
| 1099 | + new_op = node.target |
1090 | 1100 |
1091 | | - input_proxy = cast(ProxyValue, args[0]) |
1092 | | - weight_proxy = cast(ProxyValue, args[1]) |
1093 | | - input_proxy = self.change_nchw_to_nhwc(input_proxy, meta) |
1094 | | - weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta) |
| 1101 | + graph = node.graph |
1095 | 1102 |
1096 | | - # Non-quantized ops still need to set the last optional argument to True. |
1097 | | - channel_last_arg = [] if quantized_op else [True] |
| 1103 | + # Get input and weight nodes |
| 1104 | + input_node = cast(torch.fx.Node, node.args[0]) |
| 1105 | + weight_node = cast(torch.fx.Node, node.args[1]) |
1098 | 1106 |
1099 | | - new_args = ( |
1100 | | - # Transposed input/weights. |
1101 | | - (input_proxy, weight_proxy) |
1102 | | - # All other args (bias, quant params, etc) |
1103 | | - + tuple(args[2:]) |
1104 | | - + tuple(channel_last_arg) |
1105 | | - ) |
1106 | | - output_proxy = super().call_operator(new_op, new_args, kwargs, meta) |
1107 | | - nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta) |
1108 | | - return nchw_proxy |
| 1107 | + # Insert transpose operations before the node |
| 1108 | + with graph.inserting_before(node): |
| 1109 | + # Convert input from NCHW to NHWC |
| 1110 | + input_nhwc = self._change_nchw_to_nhwc(graph, input_node) |
| 1111 | + # Convert weight from NCHW to NHWC |
| 1112 | + weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node) |
| 1113 | + |
 | 1114 | +            # Non-quantized ops must set the trailing channel_last argument |
 | 1115 | +            # to True; drop an existing (False) flag so it is not duplicated. |
 | 1116 | +            remaining_args = tuple(node.args[2:]) |
 | 1117 | +            if not quantized_op and len(node.args) == 8: |
 | 1118 | +                remaining_args = remaining_args[:-1] |
 | 1119 | +            channel_last_arg = () if quantized_op else (True,) |
 | 1120 | + |
 | 1121 | +            # Create new args with the transposed input/weights |
 | 1122 | +            new_args = (input_nhwc, weight_nhwc) + remaining_args + channel_last_arg |
| 1123 | + |
| 1124 | + # Create the new conv operation |
| 1125 | + new_conv = graph.call_function(new_op, new_args, node.kwargs) |
| 1126 | + new_conv.meta = node.meta |
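| | +            # The copied meta still describes the NCHW output; its `val` is |
| | +            # presumably corrected by a later shape-propagation step. |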
| 1127 | + |
| 1128 | + # Convert output back from NHWC to NCHW |
| 1129 | + nchw_output = self._change_nhwc_to_nchw(graph, new_conv) |
| 1130 | + |
| 1131 | + # Replace all uses with the final output |
| 1132 | + node.replace_all_uses_with(nchw_output) |
| 1133 | + return True |
1109 | 1134 |
1110 | 1135 |
1111 | 1136 | @register_cadence_pass(CadencePassAttribute(opt_level=3)) |
1112 | | -class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper): |
1113 | | - def call_operator( |
1114 | | - self, |
1115 | | - op, |
1116 | | - args: tuple[Argument, ...], |
1117 | | - kwargs: dict[str, Argument], |
1118 | | - meta: NodeMetadata, |
1119 | | - ) -> ProxyValue: |
1120 | | - if op not in { |
| 1137 | +class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface): |
| 1138 | + """ |
| 1139 | + Make the slice/cat dimension the outermost dimension by adding transpose |
| 1140 | + operations before and after the slice/cat operation. |
| 1141 | + """ |
| 1142 | + |
| 1143 | + @property |
| 1144 | + def targets(self) -> list[EdgeOpOverload]: |
| 1145 | + return [ |
1121 | 1146 | exir_ops.edge.aten.cat.default, |
1122 | 1147 | exir_ops.edge.aten.slice_copy.Tensor, |
1123 | | - }: |
1124 | | - return super().call_operator(op, args, kwargs, meta) |
1125 | | - dim = cast(int, args[1]) if len(args) > 1 else 0 |
1126 | | - output_shape = meta["val"].shape |
| 1148 | + ] |
| 1149 | + |
| 1150 | + def _transpose_dims( |
| 1151 | + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int |
| 1152 | + ) -> torch.fx.Node: |
 | 1153 | +        """Insert a transpose_copy node swapping `dim0` and `dim1` of `node`.""" |
| 1154 | + shape = node.meta["val"].shape |
| 1155 | + dim0, dim1 = ( |
| 1156 | + canonicalize_transposed_dim(dim0, shape), |
| 1157 | + canonicalize_transposed_dim(dim1, shape), |
| 1158 | + ) |
| 1159 | + dim0, dim1 = min(dim0, dim1), max(dim0, dim1) |
| 1160 | + transpose_node = graph.call_function( |
| 1161 | + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} |
| 1162 | + ) |
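| | +        # As in the conv pass helper above, the copied meta keeps the |
| | +        # pre-transpose `val` shape; only its rank is consumed here. |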
| 1163 | + transpose_node.meta = node.meta |
| 1164 | + return transpose_node |
| 1165 | + |
| 1166 | + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
| 1167 | + # Get the dimension argument |
| 1168 | + dim = cast(int, node.args[1]) if len(node.args) > 1 else 0 |
| 1169 | + output_shape = node.meta["val"].shape |
| 1170 | + |
| 1171 | + # Canonicalize dim to be positive |
1127 | 1172 | if dim < 0: |
1128 | | - # Keep dim positive. |
1129 | 1173 | dim += len(output_shape) |
1130 | 1174 |
| 1175 | + # Not needed if dim is already outermost or all dims before it are 1 |
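| | +        # (e.g. slicing dim=2 of a [1, 1, 8, 16] tensor: prod of leading dims is 1) |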
1131 | 1176 | if dim == 0 or math.prod(output_shape[:dim]) == 1: |
1132 | | - # Not needed if dim is already outermost or all dims before it are 1. |
1133 | | - return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta) |
1134 | | - |
1135 | | - if op == exir_ops.edge.aten.slice_copy.Tensor: |
1136 | | - # Transpose -> slice. |
1137 | | - slice_args = ( |
1138 | | - self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0), |
1139 | | - 0, |
1140 | | - ) + args[2:] |
1141 | | - new_op = super().call_operator(op, slice_args, kwargs, meta) |
1142 | | - else: |
1143 | | - # (Transpose input0, Transpose input1, ...) -> cat. |
1144 | | - cat_in_tensors = [ |
1145 | | - self.transpose_dims(t, meta, dim, 0) |
1146 | | - for t in cast(list[ProxyValue], args[0]) |
1147 | | - ] |
1148 | | - new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta) |
1149 | | - # slice/cat -> transpose. |
1150 | | - return self.transpose_dims(new_op, meta, 0, dim) |
| 1177 | + return False |
| 1178 | + |
| 1179 | + graph = node.graph |
| 1180 | + |
| 1181 | + with graph.inserting_before(node): |
| 1182 | + if node.target == exir_ops.edge.aten.slice_copy.Tensor: |
| 1183 | + # Transpose input -> slice with dim=0 -> transpose back |
| 1184 | + input_node = cast(torch.fx.Node, node.args[0]) |
| 1185 | + transposed_input = self._transpose_dims(graph, input_node, dim, 0) |
| 1186 | + |
| 1187 | + # Create slice operation with dim=0 |
| 1188 | + slice_args = (transposed_input, 0) + node.args[2:] |
| 1189 | + sliced = graph.call_function( |
| 1190 | + exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs |
| 1191 | + ) |
| 1192 | + sliced.meta = node.meta |
| 1193 | + |
| 1194 | + # Transpose back |
| 1195 | + result = self._transpose_dims(graph, sliced, 0, dim) |
| 1196 | + else: |
| 1197 | + # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back |
| 1198 | + cat_inputs = cast(list[torch.fx.Node], node.args[0]) |
| 1199 | + transposed_inputs = [ |
| 1200 | + self._transpose_dims(graph, t, dim, 0) |
| 1201 | + for t in cat_inputs |
| 1202 | + ] |
| 1203 | + |
| 1204 | + # Create cat operation with dim=0 |
| 1205 | + catted = graph.call_function( |
| 1206 | + exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs |
| 1207 | + ) |
| 1208 | + catted.meta = node.meta |
| 1209 | + |
| 1210 | + # Transpose back |
| 1211 | + result = self._transpose_dims(graph, catted, 0, dim) |
| 1212 | + |
| 1213 | + # Replace all uses with the final result |
| 1214 | + node.replace_all_uses_with(result) |
| 1215 | + return True |
1151 | 1216 |
1152 | 1217 |
1153 | 1218 | @register_cadence_pass(CadencePassAttribute(opt_level=2)) |