diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 0b88e687224..fa8e8ca202a 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -370,15 +370,23 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: graph = node.graph + fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0 + fit_mat2 = is_node_with_op(mat2, "get_attr") + # Handle transpose: if mat2 is a transpose op, extract the original tensor transposed_mat2 = False if ( - mat2.op == "call_function" + not fit_mat2 + and mat2.op == "call_function" and mat2.target == exir_ops.edge.aten.transpose_copy.int ): # mat2 is already transposed, so we use the input to the transpose mat2 = cast(torch.fx.Node, mat2.args[0]) transposed_mat2 = True + fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0 + + if not (fit_bias and fit_mat2): + return False # Multiply bias by beta if needed if beta != 1.0: