Skip to content

Commit 9116df7

Browse files
DrJessopAndrew Grebenisan
andauthored
Fix ReplaceConvolutionOptionalArgsWithConcreteArgsPass (#16143)
Summary: Was missing default variant for transposed_convolution, which resulted in us skipping them as targets. Differential Revision: D88705767 Co-authored-by: Andrew Grebenisan <agrebenisan@meta.com>
1 parent 6cca6e6 commit 9116df7

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -498,14 +498,15 @@ def targets(self) -> list[EdgeOpOverload]:
498498
exir_ops.edge.cadence.conv1d.default,
499499
exir_ops.edge.cadence.conv2d.default,
500500
exir_ops.edge.cadence.conv3d.default,
501-
exir_ops.edge.cadence.transposed_convolution,
501+
exir_ops.edge.cadence.transposed_convolution.default,
502502
]
503503

504504
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
505505
# Check if this is a transposed convolution
506506
assert isinstance(node.target, EdgeOpOverload)
507-
op_packet = get_edge_overload_packet(node.target)
508-
is_transposed = op_packet == exir_ops.edge.cadence.transposed_convolution
507+
is_transposed = (
508+
node.target == exir_ops.edge.cadence.transposed_convolution.default
509+
)
509510
num_expected_args = 9 if is_transposed else 7
510511
assert len(node.args) == num_expected_args
511512
# Check if the bias is concrete
@@ -515,13 +516,15 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
515516
# The bias length is the number of out channels.
516517
out_shape = node.meta["val"].shape
517518
bias_size = out_shape[1]
518-
# Create a zero bias tensor (bias is not a constant tensor,
519+
520+
# Create a zero bias tensor
519521
with node.graph.inserting_before(node):
520522
zero_bias = node.graph.call_function(
521523
exir_ops.edge.aten.full.default,
522524
args=([bias_size], 0.0),
523525
kwargs={"dtype": torch.float32},
524526
)
527+
# Create proper metadata for the zero_bias node
525528
zero_bias.meta = node.meta
526529
new_args = list(node.args)
527530
new_args[2] = zero_bias

0 commit comments

Comments
 (0)