Skip to content

Commit 3f6eecf

Browse files
authored
Fix metadata props handling for values (#198)
1. Add metadata_props to value initializers 2. Fix serde to persist the field correctly Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 8862e6e commit 3f6eecf

File tree

4 files changed

+94
-10
lines changed

4 files changed

+94
-10
lines changed

src/onnx_ir/_convenience/_constructors.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def val(
224224
*,
225225
type: ir.TypeProtocol | None = None,
226226
const_value: ir.TensorProtocol | None = None,
227+
metadata_props: dict[str, str] | None = None,
227228
) -> ir.Value:
228229
"""Create a :class:`~onnx_ir.Value` with the given name and type.
229230
@@ -253,6 +254,7 @@ def val(
253254
type: The type of the value. Only one of dtype and type can be specified.
254255
const_value: The constant tensor that initializes the value. Supply this argument
255256
when you want to create an initializer. The type and shape can be obtained from the tensor.
257+
metadata_props: The metadata properties that will be serialized to the ONNX proto.
256258
257259
Returns:
258260
A Value object.
@@ -279,10 +281,11 @@ def val(
279281
type=const_tensor_type,
280282
shape=_core.Shape(const_value.shape), # type: ignore
281283
const_value=const_value,
284+
metadata_props=metadata_props,
282285
)
283286

284287
if type is None and dtype is not None:
285288
type = _core.TensorType(dtype)
286289
if shape is not None and not isinstance(shape, _core.Shape):
287290
shape = _core.Shape(shape)
288-
return _core.Value(name=name, type=type, shape=shape)
291+
return _core.Value(name=name, type=type, shape=shape, metadata_props=metadata_props)

src/onnx_ir/_core.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ def nbytes(self) -> int:
165165

166166
@property
167167
def metadata_props(self) -> dict[str, str]:
168+
"""The metadata properties of the tensor.
169+
170+
The metadata properties are used to store additional information about the tensor.
171+
Unlike ``meta``, this property is serialized to the ONNX proto.
172+
"""
168173
if self._metadata_props is None:
169174
self._metadata_props = {}
170175
return self._metadata_props
@@ -2022,6 +2027,7 @@ def __init__(
20222027
type: _protocols.TypeProtocol | None = None,
20232028
doc_string: str | None = None,
20242029
const_value: _protocols.TensorProtocol | None = None,
2030+
metadata_props: dict[str, str] | None = None,
20252031
) -> None:
20262032
"""Initialize a value.
20272033
@@ -2034,11 +2040,12 @@ def __init__(
20342040
type: The type of the value.
20352041
doc_string: The documentation string.
20362042
const_value: The constant tensor if the value is constant.
2043+
metadata_props: Metadata that will be serialized to the ONNX file.
20372044
"""
20382045
self._producer: Node | None = producer
20392046
self._index: int | None = index
20402047
self._metadata: _metadata.MetadataStore | None = None
2041-
self._metadata_props: dict[str, str] | None = None
2048+
self._metadata_props: dict[str, str] | None = metadata_props
20422049

20432050
self._name: str | None = name
20442051
self._shape: Shape | None = shape
@@ -2226,9 +2233,16 @@ def shape(self, value: Shape | None) -> None:
22262233
def const_value(
22272234
self,
22282235
) -> _protocols.TensorProtocol | None:
2229-
"""A concrete value.
2236+
"""The backing constant tensor for the value.
22302237
2231-
The value can be backed by different raw data types, such as numpy arrays.
2238+
If the ``Value`` has a ``const_value`` and is part of a graph initializers
2239+
dictionary, the value is an initialized value. Its ``const_value``
2240+
will appear as an ``initializer`` in the GraphProto when serialized.
2241+
2242+
If the ``Value`` is not part of a graph initializers dictionary, the ``const_value``
2243+
field will be ignored during serialization.
2244+
2245+
``const_value`` can be backed by different raw data types, such as numpy arrays.
22322246
The only guarantee is that it conforms TensorProtocol.
22332247
"""
22342248
return self._const_value
@@ -2258,6 +2272,11 @@ def meta(self) -> _metadata.MetadataStore:
22582272

22592273
@property
22602274
def metadata_props(self) -> dict[str, str]:
2275+
"""The metadata properties of the value.
2276+
2277+
The metadata properties are used to store additional information about the value.
2278+
Unlike ``meta``, this property is serialized to the ONNX proto.
2279+
"""
22612280
if self._metadata_props is None:
22622281
self._metadata_props = {}
22632282
return self._metadata_props
@@ -2805,6 +2824,11 @@ def meta(self) -> _metadata.MetadataStore:
28052824

28062825
@property
28072826
def metadata_props(self) -> dict[str, str]:
2827+
"""The metadata properties of the graph.
2828+
2829+
The metadata properties are used to store additional information about the graph.
2830+
Unlike ``meta``, this property is serialized to the ONNX proto.
2831+
"""
28082832
if self._metadata_props is None:
28092833
self._metadata_props = {}
28102834
return self._metadata_props
@@ -3057,6 +3081,11 @@ def meta(self) -> _metadata.MetadataStore:
30573081

30583082
@property
30593083
def metadata_props(self) -> dict[str, str]:
3084+
"""The metadata properties of the model.
3085+
3086+
The metadata properties are used to store additional information about the model.
3087+
Unlike ``meta``, this property is serialized to the ONNX proto.
3088+
"""
30603089
if self._metadata_props is None:
30613090
self._metadata_props = {}
30623091
return self._metadata_props
@@ -3250,6 +3279,11 @@ def meta(self) -> _metadata.MetadataStore:
32503279

32513280
@property
32523281
def metadata_props(self) -> dict[str, str]:
3282+
"""The metadata properties of the function.
3283+
3284+
The metadata properties are used to store additional information about the function.
3285+
Unlike ``meta``, this property is serialized to the ONNX proto.
3286+
"""
32533287
return self._graph.metadata_props
32543288

32553289
def all_nodes(self) -> Iterator[Node]:

src/onnx_ir/serde.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -709,8 +709,7 @@ def _deserialize_graph(
709709
annotation.tensor_name: annotation for annotation in proto.quantization_annotation
710710
}
711711

712-
# Create values for initializers and inputs
713-
initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
712+
# Create values for inputs
714713
inputs = [_core.Value(name=info.name) for info in proto.input]
715714
for info, value in zip(proto.input, inputs):
716715
deserialize_value_info_proto(info, value)
@@ -725,6 +724,11 @@ def _deserialize_graph(
725724
# Enter the graph scope by pushing the values for this scope to the stack
726725
scoped_values.append(values)
727726

727+
# Build the value info dictionary to allow for quick lookup for this graph scope
728+
value_info = {info.name: info for info in proto.value_info}
729+
730+
# Create values for initializers
731+
initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
728732
initializer_values = []
729733
for i, tensor in enumerate(initializer_tensors):
730734
initializer_name = tensor.name
@@ -750,16 +754,15 @@ def _deserialize_graph(
750754
shape=tensor.shape, # type: ignore[arg-type]
751755
const_value=tensor,
752756
)
757+
if initializer_name in value_info:
758+
deserialize_value_info_proto(value_info[initializer_name], initializer_value)
753759
if initializer_value.name in quantization_annotations:
754760
_deserialize_quantization_annotation(
755761
quantization_annotations[initializer_value.name], initializer_value
756762
)
757763
values[initializer_name] = initializer_value
758764
initializer_values.append(initializer_value)
759765

760-
# Build the value info dictionary to allow for quick lookup for this graph scope
761-
value_info = {info.name: info for info in proto.value_info}
762-
763766
# Declare values for all node outputs from this graph scope. This is necessary
764767
# to handle the case where a node in a subgraph uses a value that is declared out
765768
# of order in the outer graph. Declaring the values first allows us to find the
@@ -1390,7 +1393,12 @@ def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool
13901393
True if value info should be created for the value.
13911394
"""
13921395
# No need to serialize value info if it is not set
1393-
if value.shape is None and value.type is None:
1396+
if (
1397+
value.shape is None
1398+
and value.type is None
1399+
and not value.metadata_props
1400+
and not value.doc_string
1401+
):
13941402
return False
13951403
if not value.name:
13961404
logger.debug("Did not serialize '%s' because its name is empty", value)

src/onnx_ir/serde_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,45 @@ def test_deserialize_builds_correct_value_connections_for_subgraphs_that_referen
531531
[n.name for n in deserialized_model.graph], ["b_producer", "node_with_subgraph"]
532532
)
533533

534+
def test_value_metadata_props_are_preserved(self):
535+
value = ir.val(
536+
"test_initializer",
537+
dtype=ir.DataType.FLOAT,
538+
shape=(2,),
539+
const_value=ir.tensor([1.0, 2.0], name="test_initializer"),
540+
metadata_props={"key": "value"},
541+
)
542+
input = ir.val(
543+
"test_input", dtype=ir.DataType.FLOAT, shape=(2,), metadata_props={"key": "input"}
544+
)
545+
node = ir.node("Identity", inputs=[input])
546+
node.outputs[0].metadata_props["key"] = "intermediate"
547+
output = ir.val(
548+
"test_output",
549+
dtype=ir.DataType.FLOAT,
550+
shape=(2,),
551+
metadata_props={"key": "output"},
552+
)
553+
node2 = ir.node("Identity", inputs=node.outputs, outputs=[output])
554+
graph = ir.Graph(
555+
inputs=[input],
556+
outputs=[output],
557+
nodes=[node, node2],
558+
initializers=[value],
559+
name="test_graph",
560+
)
561+
graph_proto = serde.serialize_graph(graph)
562+
deserialized_graph = serde.deserialize_graph(graph_proto)
563+
564+
self.assertEqual(deserialized_graph.inputs[0].metadata_props, {"key": "input"})
565+
self.assertEqual(deserialized_graph.outputs[0].metadata_props, {"key": "output"})
566+
intermediate_value = deserialized_graph.node(0).outputs[0]
567+
self.assertEqual(intermediate_value.metadata_props, {"key": "intermediate"})
568+
569+
self.assertIn("test_initializer", deserialized_graph.initializers)
570+
deserialized_value = deserialized_graph.initializers["test_initializer"]
571+
self.assertEqual(deserialized_value.metadata_props, {"key": "value"})
572+
534573

535574
class SerializationTest(unittest.TestCase):
536575
@parameterized.parameterized.expand(

0 commit comments

Comments
 (0)