Skip to content

Commit 2f48964

Browse files
Convenience function get_const_tensor (#45)
Get the constant tensor from a value, if it exists. A constant tensor can be obtained if the value has a ``const_value`` set (as in the case of an initializer) or if the value is produced by a Constant node. This function will not alter the ``const_value`` of the value, but it will propagate the shape and type of the constant tensor to the value if `propagate_shape_type` is set to True. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: codecov-ai[bot] <156709835+codecov-ai[bot]@users.noreply.github.com>
1 parent 26ca89a commit 2f48964

File tree

4 files changed

+215
-2
lines changed

4 files changed

+215
-2
lines changed

docs/api/ir_convenience.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
```{eval-rst}
1010
.. autofunction:: convert_attribute
1111
.. autofunction:: convert_attributes
12+
.. autofunction:: create_value_mapping
13+
.. autofunction:: get_const_tensor
1214
.. autofunction:: replace_all_uses_with
1315
.. autofunction:: replace_nodes_and_values
14-
.. autofunction:: create_value_mapping
1516
```

src/onnx_ir/_convenience/__init__.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
"replace_all_uses_with",
1515
"create_value_mapping",
1616
"replace_nodes_and_values",
17+
"get_const_tensor",
1718
]
1819

20+
import logging
1921
from collections.abc import Mapping, Sequence
2022
from typing import Union
2123

24+
import numpy as np
2225
import onnx # noqa: TID251
2326

2427
from onnx_ir import _core, _enums, _protocols, serde, traversal
@@ -42,6 +45,9 @@
4245
]
4346

4447

48+
logger = logging.getLogger(__name__)
49+
50+
4551
def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
4652
"""Infer the attribute type based on the type of the Python object."""
4753
if isinstance(attr, int):
@@ -389,3 +395,104 @@ def replace_nodes_and_values(
389395
# insert new nodes after the index node
390396
graph_or_function.insert_after(insertion_point, new_nodes)
391397
graph_or_function.remove(old_nodes, safe=True)
398+
399+
400+
def get_const_tensor(
401+
value: _core.Value, propagate_shape_type: bool = False
402+
) -> _protocols.TensorProtocol | None:
403+
"""Get the constant tensor from a value, if it exists.
404+
405+
A constant tensor can be obtained if the value has a ``const_value`` set
406+
(as in the case of an initializer) or if the value is produced by a
407+
Constant node.
408+
409+
This function will not alter the ``const_value`` of the value, but
410+
it will propagate the shape and type of the constant tensor to the value
411+
if `propagate_shape_type` is set to True.
412+
413+
Args:
414+
value: The value to get the constant tensor from.
415+
propagate_shape_type: If True, the shape and type of the value will be
416+
propagated to the Value.
417+
418+
Returns:
419+
The constant tensor if it exists, otherwise None.
420+
421+
Raises:
422+
ValueError: If the Constant node does not have exactly one output or
423+
one attribute.
424+
"""
425+
tensor = None
426+
if value.const_value is not None:
427+
tensor = value.const_value
428+
else:
429+
node = value.producer()
430+
if node is None:
431+
# Potentially a graph input
432+
return None
433+
if node.op_type != "Constant" or node.domain != "":
434+
# Not a Constant node or not in the ONNX domain
435+
return None
436+
if len(node.outputs) != 1:
437+
raise ValueError(
438+
f"Constant node '{node.name}' must have exactly one output, "
439+
f"but has {len(node.outputs)} outputs."
440+
)
441+
if len(node.attributes) != 1:
442+
raise ValueError(
443+
f"Constant node '{node.name}' must have exactly one attribute, "
444+
f"but has {len(node.attributes)} attributes."
445+
)
446+
447+
attr_name, attr_value = next(iter(node.attributes.items()))
448+
449+
if attr_value.is_ref():
450+
# TODO: Make it easier to resolve a reference attribute.
451+
# For now we just return None
452+
return None
453+
454+
ir_value = node.outputs[0]
455+
if attr_name in {"value_float", "value_floats"}:
456+
tensor = _core.Tensor(
457+
np.array(attr_value.value, dtype=np.float32), name=ir_value.name
458+
)
459+
elif attr_name in {"value_int", "value_ints"}:
460+
tensor = _core.Tensor(
461+
np.array(attr_value.value, dtype=np.int64), name=ir_value.name
462+
)
463+
elif attr_name in {"value_string", "value_strings"}:
464+
tensor = _core.StringTensor(
465+
np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name
466+
)
467+
elif attr_name == "value":
468+
tensor = attr_value.as_tensor()
469+
else:
470+
raise ValueError(
471+
f"Unsupported attribute '{attr_name}' in Constant node '{node.name}'. "
472+
"Expected one of 'value_float', 'value_floats', 'value_int', "
473+
"'value_ints', 'value_string', 'value_strings', or 'value'."
474+
)
475+
# Assign the name of the constant value to the tensor
476+
tensor.name = value.name
477+
if tensor is not None and propagate_shape_type:
478+
# Propagate the shape and type of the tensor to the value
479+
if value.shape is not None and value.shape != tensor.shape:
480+
logger.warning(
481+
"Value '%s' has a shape %s that differs from "
482+
"the constant tensor's shape %s. The value's shape will be updated.",
483+
value,
484+
value.shape,
485+
tensor.shape,
486+
)
487+
value.shape = tensor.shape # type: ignore[assignment]
488+
new_value_type = _core.TensorType(tensor.dtype)
489+
if value.type is not None and value.type != new_value_type:
490+
logger.warning(
491+
"Value '%s' has a type '%s' that differs from "
492+
"the constant tensor's type '%s'. The value's type will be updated.",
493+
value,
494+
value.type,
495+
new_value_type,
496+
)
497+
value.type = new_value_type
498+
return tensor
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) ONNX Project Contributors
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import unittest
6+
7+
import numpy as np
8+
import parameterized
9+
10+
import onnx_ir as ir
11+
from onnx_ir import _convenience
12+
13+
14+
class GetConstantTensorTest(unittest.TestCase):
15+
def test_direct_const_value(self):
16+
# Test when value has a direct const_value
17+
tensor = ir.Tensor(np.array([1, 2, 3], dtype=np.int64), name="test_tensor")
18+
value = ir.Value(name="test_value", type=ir.TensorType(ir.DataType.INT64))
19+
value.const_value = tensor
20+
self.assertIs(_convenience.get_const_tensor(value), tensor)
21+
22+
def test_no_const_value(self):
23+
value = ir.Value(name="test_value", type=ir.TensorType(ir.DataType.FLOAT))
24+
25+
self.assertIsNone(_convenience.get_const_tensor(value))
26+
27+
def test_non_constant_producer_node(self):
28+
# Test when producer node is not a Constant
29+
node = ir.Node(
30+
name="test_node",
31+
domain="",
32+
op_type="Add",
33+
inputs=[],
34+
)
35+
36+
output_value = node.outputs[0]
37+
self.assertIsNone(_convenience.get_const_tensor(output_value))
38+
39+
@parameterized.parameterized.expand(
40+
[
41+
(
42+
"value_float",
43+
ir.AttrFloat32("value_float", 3.14),
44+
np.array(3.14, dtype=np.float32),
45+
),
46+
("value_int", ir.AttrInt64("value_int", 42), np.array(42, dtype=np.int64)),
47+
(
48+
"value_string",
49+
ir.AttrString("value_string", "test"),
50+
np.array(b"test", dtype=object),
51+
),
52+
(
53+
"value_floats",
54+
ir.AttrFloat32s("value_floats", [1.0, 2.0, 3.0]),
55+
np.array([1.0, 2.0, 3.0], dtype=np.float32),
56+
),
57+
(
58+
"value_ints",
59+
ir.AttrInt64s("value_ints", [1, 2, 3]),
60+
np.array([1, 2, 3], dtype=np.int64),
61+
),
62+
(
63+
"value_strings",
64+
ir.AttrStrings("value_strings", ["a", "b", "c"]),
65+
np.array([b"a", b"b", b"c"], dtype=object),
66+
),
67+
(
68+
"value",
69+
ir.AttrTensor("value", ir.tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))),
70+
np.array([1.0, 2.0, 3.0], dtype=np.float32),
71+
),
72+
]
73+
)
74+
def test_constant_value(self, _: str, attr: ir.Attr, expected: np.ndarray):
75+
# Test with Constant node with float value
76+
node = ir.Node(
77+
name="constant_node",
78+
domain="",
79+
op_type="Constant",
80+
inputs=[],
81+
attributes=(attr,),
82+
)
83+
node.outputs[0].name = "output"
84+
85+
result = _convenience.get_const_tensor(node.outputs[0])
86+
87+
self.assertIsNotNone(result)
88+
self.assertEqual(result.name, "output")
89+
np.testing.assert_array_equal(result.numpy(), expected)
90+
91+
self.assertIsNone(node.outputs[0].shape)
92+
self.assertIsNone(node.outputs[0].type)
93+
94+
result_2 = _convenience.get_const_tensor(node.outputs[0], propagate_shape_type=True)
95+
self.assertIsNotNone(result_2)
96+
self.assertEqual(result_2.name, "output")
97+
np.testing.assert_array_equal(result_2.numpy(), expected)
98+
self.assertEqual(node.outputs[0].shape, expected.shape)
99+
self.assertEqual(node.outputs[0].type, ir.TensorType(result_2.dtype))
100+
101+
102+
if __name__ == "__main__":
103+
unittest.main()

src/onnx_ir/convenience.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77
__all__ = [
88
"convert_attribute",
99
"convert_attributes",
10+
"create_value_mapping",
11+
"get_const_tensor",
1012
"replace_all_uses_with",
1113
"replace_nodes_and_values",
12-
"create_value_mapping",
1314
]
1415

1516
from onnx_ir._convenience import (
1617
convert_attribute,
1718
convert_attributes,
1819
create_value_mapping,
20+
get_const_tensor,
1921
replace_all_uses_with,
2022
replace_nodes_and_values,
2123
)

0 commit comments

Comments
 (0)