From 951f8ca9060c8dd635c4ccbbb61358b7ee3ae085 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Thu, 11 Dec 2025 18:30:40 -0800
Subject: [PATCH] Update ReplaceConvWithChannelLastConvPass and
 MakeSliceAndCatDimOutermostPass to correctly set modified bit (#16185)

Summary: Update ReplaceConvWithChannelLastConvPass and
MakeSliceAndCatDimOutermostPass to use new interface

Reviewed By: ethansfng

Differential Revision: D87880891
---
 backends/cadence/aot/replace_ops.py           | 256 +++++++++++-------
 .../aot/tests/test_replace_ops_passes.py      |  86 +++++-
 2 files changed, 238 insertions(+), 104 deletions(-)

diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index 8a78cea438c..caa3984d1af 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -41,8 +41,7 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
-from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from torch.fx.node import Argument
+from executorch.exir.pass_base import ExportPass, PassResult
 
 # A map to represent ops that:
 # (a) are functionally equivalent; and
@@ -1017,131 +1016,198 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int:
     return dim
 
 
-class ExportPassWithTransposeHelper(ExportPass):
-    def transpose_dims(
-        self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int
-    ) -> ProxyValue:
-        """Helper function to transpose dims of a `proxy` with given `meta`."""
-        shape = proxy.data.shape
+@register_cadence_pass(CadencePassAttribute(opt_level=3))
+class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface):
+    """
+    Replace NCHW convolutions with NHWC (channel-last) convolutions by adding
+    transpose operations before and after the convolution.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            exir_ops.edge.cadence.conv1d.default,
+            exir_ops.edge.cadence.conv2d.default,
+            exir_ops.edge.cadence.conv3d.default,
+            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
         dim0, dim1 = (
             canonicalize_transposed_dim(dim0, shape),
             canonicalize_transposed_dim(dim1, shape),
         )
         dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
-        return super().call_operator(
-            exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
         )
-
-
-@register_cadence_pass(CadencePassAttribute(opt_level=3))
-class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
-    def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        transpose_node.meta = node.meta
+        return transpose_node
+
+    def _change_nchw_to_nhwc(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NCHW format to NHWC format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0]] + indices[2:] + [indices[1]]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
         )
-
-    def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
-        shape = proxy.to_tensor().shape
+        permute_node.meta = node.meta
+        return permute_node
+
+    def _change_nhwc_to_nchw(
+        self, graph: torch.fx.Graph, node: torch.fx.Node
+    ) -> torch.fx.Node:
+        """Convert NHWC format to NCHW format."""
+        shape = node.meta["val"].shape
         if len(shape) == 3:
-            return self.transpose_dims(proxy, meta, 1, -1)
+            return self._transpose_dims(graph, node, 1, -1)
         indices = list(range(len(shape)))
         permute_indices = [indices[0], indices[-1]] + indices[1:-1]
-        return super().call_operator(
-            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
+        permute_node = graph.call_function(
+            exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
         )
+        permute_node.meta = node.meta
+        return permute_node
 
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.cadence.conv1d.default,
-            exir_ops.edge.cadence.conv2d.default,
-            exir_ops.edge.cadence.conv3d.default,
-            exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-
-        quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        assert isinstance(node.target, EdgeOpOverload)
+        quantized_op = (
+            node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
+        )
 
-        if not quantized_op and len(args) == 8 and args[-1] is True:
-            # Already in NHWC layout.
-            return super().call_operator(op, args, kwargs, meta)
+        # Check if already in NHWC layout
+        if not quantized_op and len(node.args) == 8 and node.args[-1] is True:
+            return False
 
+        # Determine the new op target
         if quantized_op:
             new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
         else:
-            # Determine if 1D or 2D convolution based on op
-            new_op = op
+            new_op = node.target
+
+        graph = node.graph
 
-        input_proxy = cast(ProxyValue, args[0])
-        weight_proxy = cast(ProxyValue, args[1])
-        input_proxy = self.change_nchw_to_nhwc(input_proxy, meta)
-        weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta)
+        # Get input and weight nodes
+        input_node = cast(torch.fx.Node, node.args[0])
+        weight_node = cast(torch.fx.Node, node.args[1])
 
-        # Non-quantized ops still need to set the last optional argument to True.
-        channel_last_arg = [] if quantized_op else [True]
+        # Insert transpose operations before the node
+        with graph.inserting_before(node):
+            # Convert input from NCHW to NHWC
+            input_nhwc = self._change_nchw_to_nhwc(graph, input_node)
+            # Convert weight from NCHW to NHWC
+            weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node)
 
-        new_args = (
-            # Transposed input/weights.
-            (input_proxy, weight_proxy)
-            # All other args (bias, quant params, etc)
-            + tuple(args[2:])
-            + tuple(channel_last_arg)
-        )
-        output_proxy = super().call_operator(new_op, new_args, kwargs, meta)
-        nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta)
-        return nchw_proxy
+            # Non-quantized ops need to set the last optional argument to True
+            channel_last_arg = [] if quantized_op else [True]
+
+            # Create new args with transposed input/weights
+            new_args = (
+                (input_nhwc, weight_nhwc)
+                + tuple(node.args[2:])
+                + tuple(channel_last_arg)
+            )
+
+            # Create the new conv operation
+            new_conv = graph.call_function(new_op, new_args, node.kwargs)
+            new_conv.meta = node.meta
+
+            # Convert output back from NHWC to NCHW
+            nchw_output = self._change_nhwc_to_nchw(graph, new_conv)
+
+        # Replace all uses with the final output
+        node.replace_all_uses_with(nchw_output)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=3))
-class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper):
-    def call_operator(
-        self,
-        op,
-        args: tuple[Argument, ...],
-        kwargs: dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        if op not in {
+class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface):
+    """
+    Make the slice/cat dimension the outermost dimension by adding transpose
+    operations before and after the slice/cat operation.
+    """
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
             exir_ops.edge.aten.cat.default,
             exir_ops.edge.aten.slice_copy.Tensor,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
-        dim = cast(int, args[1]) if len(args) > 1 else 0
-        output_shape = meta["val"].shape
+        ]
+
+    def _transpose_dims(
+        self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
+    ) -> torch.fx.Node:
+        """Helper function to transpose dims of a node."""
+        shape = node.meta["val"].shape
+        dim0, dim1 = (
+            canonicalize_transposed_dim(dim0, shape),
+            canonicalize_transposed_dim(dim1, shape),
+        )
+        dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
+        transpose_node = graph.call_function(
+            exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
+        )
+        transpose_node.meta = node.meta
+        return transpose_node
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Get the dimension argument
+        dim = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        output_shape = node.meta["val"].shape
+
+        # Canonicalize dim to be positive
         if dim < 0:
-            # Keep dim positive.
             dim += len(output_shape)
 
+        # Not needed if dim is already outermost or all dims before it are 1
         if dim == 0 or math.prod(output_shape[:dim]) == 1:
-            # Not needed if dim is already outermost or all dims before it are 1.
-            return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta)
-
-        if op == exir_ops.edge.aten.slice_copy.Tensor:
-            # Transpose -> slice.
-            slice_args = (
-                self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0),
-                0,
-            ) + args[2:]
-            new_op = super().call_operator(op, slice_args, kwargs, meta)
-        else:
-            # (Transpose input0, Transpose input1, ...) -> cat.
-            cat_in_tensors = [
-                self.transpose_dims(t, meta, dim, 0)
-                for t in cast(list[ProxyValue], args[0])
-            ]
-            new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta)
-        # slice/cat -> transpose.
-        return self.transpose_dims(new_op, meta, 0, dim)
+            return False
+
+        graph = node.graph
+
+        with graph.inserting_before(node):
+            if node.target == exir_ops.edge.aten.slice_copy.Tensor:
+                # Transpose input -> slice with dim=0 -> transpose back
+                input_node = cast(torch.fx.Node, node.args[0])
+                transposed_input = self._transpose_dims(graph, input_node, dim, 0)
+
+                # Create slice operation with dim=0
+                slice_args = (transposed_input, 0) + node.args[2:]
+                sliced = graph.call_function(
+                    exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs
+                )
+                sliced.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, sliced, 0, dim)
+            else:
+                # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back
+                cat_inputs = cast(list[torch.fx.Node], node.args[0])
+                transposed_inputs = [
+                    self._transpose_dims(graph, t, dim, 0) for t in cat_inputs
+                ]
+
+                # Create cat operation with dim=0
+                catted = graph.call_function(
+                    exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs
+                )
+                catted.meta = node.meta
+
+                # Transpose back
+                result = self._transpose_dims(graph, catted, 0, dim)
+
+        # Replace all uses with the final result
+        node.replace_all_uses_with(result)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=2))
diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index 91b44925455..cc93a06e93f 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -1863,8 +1863,11 @@ def test_quantized_convolution_default_channel_last(self) -> None:
         # Apply replacement pass.
         p = ReplaceConvWithChannelLastConvPass()
         original = copy.deepcopy(gm)
-        gm_after_replacement = p.call(gm).graph_module
-        # Check that no replacement was made.
+        result = p.call(gm)
+        self.assertTrue(result.modified)
+        gm_after_replacement = result.graph_module
+
+        # Check that replacement was made.
         self.assertEqual(
             count_node(
                 gm_after_replacement,
@@ -1877,9 +1880,11 @@ def test_quantized_convolution_default_channel_last(self) -> None:
             count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
             3,
         )
+
+        # Validate numerical accuracy
         validate(
-            gm_after_replacement,
             original,
+            gm_after_replacement,
             placeholders,
             "ReplaceConvWithChannelLastConvPass",
         )
@@ -1933,14 +1938,23 @@ def create_slice_graph(
 
     def test_slice_no_transpose_if_already_outermost(self) -> None:
         # Create a graph with a single slice node.
+        x = torch.randn(3, 224, 224)
         gm = self.create_slice_graph((3, 224, 224), 0, 1, 2)
         # Check if graph module is valid by running exportpass on it.
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)
 
+        # Deepcopy before the pass
+        original = copy.deepcopy(gm)
+
         # Apply replacement pass.
         p = MakeSliceAndCatDimOutermostPass()
-        gm_after_pass = cast(PassResult, p(gm)).graph_module
+        result = cast(PassResult, p(gm))
+        self.assertFalse(result.modified)
+        gm_after_pass = result.graph_module
+
+        # Validate numerical accuracy
+        validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass")
 
         # Assert that no transpose ops were added.
         self.assertEqual(
@@ -1950,14 +1964,23 @@ def test_slice_no_transpose_if_already_outermost(self) -> None:
 
     def test_slice_no_transpose_if_outermost_dimensions_are_one(self) -> None:
         # Create a graph with a single slice node on second outermost dimension.
+        x = torch.randn(1, 3, 4, 6)
         gm = self.create_slice_graph((1, 3, 4, 6), 1, 1, 2)
         # Check if graph module is valid by running exportpass on it.
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)
 
+        # Deepcopy before the pass
+        original = copy.deepcopy(gm)
+
         # Apply replacement pass.
         p = MakeSliceAndCatDimOutermostPass()
-        gm_after_pass = cast(PassResult, p(gm)).graph_module
+        result = cast(PassResult, p(gm))
+        self.assertFalse(result.modified)
+        gm_after_pass = result.graph_module
+
+        # Validate numerical accuracy
+        validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass")
 
         # Assert that no transpose ops were added. The slice is on the second
         # outermost dimension, but the outermost dimension is already 1.
@@ -1968,14 +1991,23 @@ def test_slice_no_transpose_if_outermost_dimensions_are_one(self) -> None:
 
     def test_slice_insert_transpose(self) -> None:
         # Create a graph with a single slice node.
+        x = torch.randn(1, 3, 4, 6)
         gm = self.create_slice_graph((1, 3, 4, 6), 2, 1, 2)
         # Check if graph module is valid by running exportpass on it.
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)
 
+        # Deepcopy before the pass
+        original = copy.deepcopy(gm)
+
         # Apply replacement pass.
         p = MakeSliceAndCatDimOutermostPass()
-        gm_after_pass = cast(PassResult, p(gm)).graph_module
+        result = cast(PassResult, p(gm))
+        self.assertTrue(result.modified)
+        gm_after_pass = result.graph_module
+
+        # Validate numerical accuracy
+        validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass")
 
         # Assert that there are two transpose ops added.
         self.assertEqual(
@@ -1997,14 +2029,26 @@ def create_cat_graph(
 
     def test_cat_no_transpose_if_already_outermost(self) -> None:
         # Create a graph with a single slice node on second outermost dimension.
+        input1 = torch.randn(1, 3, 5)
+        input2 = torch.randn(2, 3, 5)
         gm = self.create_cat_graph(input_shapes=((1, 3, 5), (2, 3, 5)), cat_dim=0)
         # Check if graph module is valid by running exportpass on it.
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)
 
+        # Deepcopy before the pass
+        original = copy.deepcopy(gm)
+
         # Apply replacement pass.
         p = MakeSliceAndCatDimOutermostPass()
-        gm_after_pass = cast(PassResult, p(gm)).graph_module
+        result = cast(PassResult, p(gm))
+        self.assertFalse(result.modified)
+        gm_after_pass = result.graph_module
+
+        # Validate numerical accuracy
+        validate(
+            original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass"
+        )
 
         # Assert that no transpose ops were added. The slice is on the second
         # outermost dimension, but the outermost dimension is already 1.
@@ -2015,14 +2059,26 @@ def test_cat_no_transpose_if_already_outermost(self) -> None:
 
     def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None:
         # Create a graph with a single slice node on second outermost dimension.
+        input1 = torch.randn(1, 1, 3, 5)
+        input2 = torch.randn(1, 2, 3, 5)
         gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1)
         # Check if graph module is valid by running exportpass on it.
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)
 
+        # Deepcopy before the pass
+        original = copy.deepcopy(gm)
+
         # Apply replacement pass.
         p = MakeSliceAndCatDimOutermostPass()
-        gm_after_pass = cast(PassResult, p(gm)).graph_module
+        result = cast(PassResult, p(gm))
+        self.assertFalse(result.modified)
+        gm_after_pass = result.graph_module
+
+        # Validate numerical accuracy
+        validate(
+            original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass"
+        )
 
         # Assert that no transpose ops were added. The slice is on the second
         # outermost dimension, but the outermost dimension is already 1.
@@ -2033,6 +2089,8 @@ def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None:
 
     def test_cat_insert_transpose(self) -> None:
         # Create a graph with a single slice node on second outermost dimension.
+        input1 = torch.randn(1, 1, 3, 5)
+        input2 = torch.randn(1, 1, 3, 3)
         gm = self.create_cat_graph(
             input_shapes=((1, 1, 3, 5), (1, 1, 3, 3)), cat_dim=-1
         )
@@ -2040,9 +2098,19 @@ def test_cat_insert_transpose(self) -> None:
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)
 
+        # Deepcopy before the pass
+        original = copy.deepcopy(gm)
+
         # Apply replacement pass.
         p = MakeSliceAndCatDimOutermostPass()
-        gm_after_pass = cast(PassResult, p(gm)).graph_module
+        result = cast(PassResult, p(gm))
+        self.assertTrue(result.modified)
+        gm_after_pass = result.graph_module
+
+        # Validate numerical accuracy
+        validate(
+            original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass"
+        )
 
         # Assert that transpose ops were added to make cat on outermost dimension.
         self.assertEqual(
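
Usage sketch (not part of the commit): the updated tests read the modified bit
off the PassResult that each pass returns, and a caller chaining the two passes
can gate work on the same bit. This is a minimal sketch, assuming the standard
executorch module layout for backends/cadence/aot/replace_ops.py; the helper
name run_layout_passes is hypothetical.

    from executorch.backends.cadence.aot.replace_ops import (
        MakeSliceAndCatDimOutermostPass,
        ReplaceConvWithChannelLastConvPass,
    )

    def run_layout_passes(gm):
        # Each pass returns a PassResult whose `modified` flag is now set
        # correctly by this change, so downstream work can be gated on it.
        for pass_cls in (
            ReplaceConvWithChannelLastConvPass,
            MakeSliceAndCatDimOutermostPass,
        ):
            result = pass_cls().call(gm)
            if result.modified:
                # Pick up the rewritten module only when the pass reports a change.
                gm = result.graph_module
        return gm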