Skip to content

Commit 59086bc

Browse files
justinchubyCopilot
andauthored
Implement shape merging in identity elimination pass (#206)
This PR implements shape merging functionality in the identity elimination pass to preserve shape information when eliminating redundant Identity nodes. Following microsoft/onnxscript#2588. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 8b9ad72 commit 59086bc

File tree

2 files changed

+380
-0
lines changed

2 files changed

+380
-0
lines changed

src/onnx_ir/passes/common/identity_elimination.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,29 @@
1515
logger = logging.getLogger(__name__)
1616

1717

18+
def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
19+
def merge_dims(dim1, dim2):
20+
if dim1 == dim2:
21+
return dim1
22+
if not isinstance(dim1, ir.SymbolicDim):
23+
return dim1 # Prefer int value over symbolic dim
24+
if not isinstance(dim2, ir.SymbolicDim):
25+
return dim2
26+
if dim1.value is None:
27+
return dim2
28+
return dim1
29+
30+
if shape1 is None:
31+
return shape2
32+
if shape2 is None:
33+
return shape1
34+
if len(shape1) != len(shape2):
35+
raise ValueError(
36+
f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}."
37+
)
38+
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
39+
40+
1841
class IdentityEliminationPass(ir.passes.InPlacePass):
1942
"""Pass for eliminating redundant Identity nodes.
2043
@@ -75,6 +98,11 @@ def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
7598
if output_is_graph_output and input_is_graph_input:
7699
return False
77100

101+
# Copy over shape/type if the output has more complete information
102+
input_value.shape = _merge_shapes(input_value.shape, output_value.shape)
103+
if input_value.type is None:
104+
input_value.type = output_value.type
105+
78106
# Case 1 & 2 (merged): Eliminate the identity node
79107
# Replace all uses of output with input
80108
ir.convenience.replace_all_uses_with(output_value, input_value)

0 commit comments

Comments
 (0)