Skip to content

Commit c877068

Browse files
committed
Refactor ONNX backend dispatch and improve test coverage
- Clean up dispatch implementations for shape, subtensor, and tensor_basic ops
- Improve property-based testing strategies
- Fix type annotations and code style issues
- Update test fixtures and assertions
1 parent 34b0239 commit c877068

28 files changed

+950
-940
lines changed

pytensor/link/onnx/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,16 @@
88
from pytensor.link.onnx.export import compile_onnx, export_function_onnx, export_onnx
99
from pytensor.link.onnx.linker import ONNXLinker
1010

11+
1112
# ONNX opset version used by default
1213
ONNX_OPSET_VERSION = 18
1314

1415
__all__ = [
16+
"ONNX_OPSET_VERSION",
1517
"ONNXLinker",
16-
"onnx_funcify",
17-
"onnx_typify",
18-
"export_onnx",
1918
"compile_onnx",
2019
"export_function_onnx",
21-
"ONNX_OPSET_VERSION",
20+
"export_onnx",
21+
"onnx_funcify",
22+
"onnx_typify",
2223
]

pytensor/link/onnx/dispatch/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from pytensor.link.onnx.dispatch.basic import onnx_funcify, onnx_typify
55

66
# Load dispatch specializations
7-
import pytensor.link.onnx.dispatch.elemwise # noqa: F401
8-
import pytensor.link.onnx.dispatch.shape # noqa: F401
9-
import pytensor.link.onnx.dispatch.math # noqa: F401
10-
import pytensor.link.onnx.dispatch.tensor_basic # noqa: F401
11-
import pytensor.link.onnx.dispatch.subtensor # noqa: F401
12-
import pytensor.link.onnx.dispatch.nlinalg # noqa: F401
13-
import pytensor.link.onnx.dispatch.nnet # noqa: F401
7+
import pytensor.link.onnx.dispatch.elemwise
8+
import pytensor.link.onnx.dispatch.shape
9+
import pytensor.link.onnx.dispatch.math
10+
import pytensor.link.onnx.dispatch.tensor_basic
11+
import pytensor.link.onnx.dispatch.subtensor
12+
import pytensor.link.onnx.dispatch.nlinalg
13+
import pytensor.link.onnx.dispatch.nnet
1414

1515
# isort: on

pytensor/link/onnx/dispatch/basic.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ def get_var_name(var):
198198
# Collect all nodes in topological order
199199
nodes = []
200200
initializers = []
201-
value_infos = []
202201

203202
# Process constants first
204203
for var in fgraph.variables:
@@ -213,7 +212,7 @@ def get_var_name(var):
213212
# For now, we'll upcast all scalar integer constants to float32
214213
# This is a simplification but handles the common case of: x * 2
215214
# where x is float and 2 is an int scalar
216-
data = data.astype('float32')
215+
data = data.astype("float32")
217216

218217
tensor_proto = onnx_typify(data, name=name)
219218
initializers.append(tensor_proto)
@@ -235,9 +234,7 @@ def get_var_name(var):
235234
# Multiple nodes - add all to graph
236235
# Used for operations that compile to multiple ONNX ops
237236
# Example: Shape_i returns [Constant, Shape, Gather]
238-
for item in result:
239-
if item is not None:
240-
nodes.append(item)
237+
nodes.extend(item for item in result if item is not None)
241238
elif isinstance(result, tuple):
242239
# Returned (node, additional_initializers)
243240
# Used for operations with constant initializers

pytensor/link/onnx/dispatch/elemwise.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pytensor.scalar import math as scalar_math
88
from pytensor.tensor.elemwise import Elemwise
99

10+
1011
# ⭐ THE MAGIC MAPPING - Tier 1 + Tier 4-5 operations
1112
SCALAR_OP_TO_ONNX = {
1213
# Arithmetic (Tier 1)

pytensor/link/onnx/dispatch/math.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
"""ONNX conversion for math operations (reductions)."""
22

33
from pytensor.link.onnx.dispatch.basic import onnx_funcify
4-
from pytensor.tensor.math import CAReduce, Argmax
5-
from pytensor.scalar.basic import Add, Mul, Maximum, Minimum, AND, OR
4+
from pytensor.scalar.basic import AND, OR, Add, Maximum, Minimum, Mul
5+
from pytensor.tensor.math import Argmax, CAReduce
6+
67

78
try:
89
from onnx import helper
9-
import numpy as np
1010
except ImportError as e:
1111
raise ImportError("ONNX package required for export") from e
1212

1313

1414
# Mapping from PyTensor scalar ops to ONNX reduction ops
1515
REDUCE_OP_MAP = {
16-
Add: 'ReduceSum',
17-
Mul: 'ReduceProd',
18-
Maximum: 'ReduceMax',
19-
Minimum: 'ReduceMin',
20-
AND: 'ReduceMin', # For boolean AND
21-
OR: 'ReduceMax', # For boolean OR
16+
Add: "ReduceSum",
17+
Mul: "ReduceProd",
18+
Maximum: "ReduceMax",
19+
Minimum: "ReduceMin",
20+
AND: "ReduceMin", # For boolean AND
21+
OR: "ReduceMax", # For boolean OR
2222
}
2323

2424

@@ -57,7 +57,7 @@ def onnx_funcify_CAReduce(op, node, get_var_name, **kwargs):
5757
# For opset 18+, axes must be an input tensor
5858
axes_name = f"{output_name}_axes"
5959
axes_constant = helper.make_node(
60-
'Constant',
60+
"Constant",
6161
inputs=[],
6262
outputs=[axes_name],
6363
name=f"Constant_{axes_name}",
@@ -66,7 +66,7 @@ def onnx_funcify_CAReduce(op, node, get_var_name, **kwargs):
6666
data_type=helper.TensorProto.INT64,
6767
dims=[len(axes_list)],
6868
vals=axes_list,
69-
)
69+
),
7070
)
7171
nodes.append(axes_constant)
7272

@@ -102,15 +102,15 @@ def onnx_funcify_Argmax(op, node, get_var_name, **kwargs):
102102
# Argmax over all axes - need to flatten first
103103
flatten_name = f"{output_name}_flat"
104104
flatten_node = helper.make_node(
105-
'Flatten',
105+
"Flatten",
106106
inputs=[input_name],
107107
outputs=[flatten_name],
108108
name=f"Flatten_{flatten_name}",
109109
axis=0,
110110
)
111111

112112
argmax_node = helper.make_node(
113-
'ArgMax',
113+
"ArgMax",
114114
inputs=[flatten_name],
115115
outputs=[output_name],
116116
name=f"ArgMax_{output_name}",
@@ -130,7 +130,7 @@ def onnx_funcify_Argmax(op, node, get_var_name, **kwargs):
130130
axis = axis[0]
131131

132132
onnx_node = helper.make_node(
133-
'ArgMax',
133+
"ArgMax",
134134
inputs=[input_name],
135135
outputs=[output_name],
136136
name=f"ArgMax_{output_name}",
@@ -139,5 +139,3 @@ def onnx_funcify_Argmax(op, node, get_var_name, **kwargs):
139139
)
140140

141141
return onnx_node
142-
143-

pytensor/link/onnx/dispatch/shape.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
import numpy as np
44
from onnx import helper, numpy_helper
55

6-
from pytensor.link.onnx.dispatch.basic import onnx_funcify
7-
from pytensor.tensor.shape import Shape, Shape_i, SpecifyShape, Reshape
8-
from pytensor.tensor.basic import Join, Split
96
from pytensor.graph.basic import Constant
7+
from pytensor.link.onnx.dispatch.basic import onnx_funcify
8+
from pytensor.tensor.basic import Join, Split, get_scalar_constant_value
109
from pytensor.tensor.exceptions import NotScalarConstantError
11-
from pytensor.tensor.basic import get_scalar_constant_value
10+
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
1211

1312

1413
@onnx_funcify.register(type(None))
@@ -27,7 +26,7 @@ def onnx_funcify_Shape(op, node, get_var_name, **kwargs):
2726
output_name = get_var_name(node.outputs[0])
2827

2928
onnx_node = helper.make_node(
30-
'Shape',
29+
"Shape",
3130
inputs=[input_name],
3231
outputs=[output_name],
3332
name=f"Shape_{output_name}",
@@ -69,7 +68,7 @@ def onnx_funcify_Shape_i(op, node, get_var_name, **kwargs):
6968

7069
# Node 1: Create constant for index
7170
idx_constant = helper.make_node(
72-
'Constant',
71+
"Constant",
7372
inputs=[],
7473
outputs=[idx_name],
7574
name=f"Constant_{idx_name}",
@@ -78,20 +77,20 @@ def onnx_funcify_Shape_i(op, node, get_var_name, **kwargs):
7877
data_type=helper.TensorProto.INT64,
7978
dims=[],
8079
vals=[axis_idx],
81-
)
80+
),
8281
)
8382

8483
# Node 2: Get full shape
8584
shape_node = helper.make_node(
86-
'Shape',
85+
"Shape",
8786
inputs=[input_name],
8887
outputs=[shape_name],
8988
name=f"Shape_{shape_name}",
9089
)
9190

9291
# Node 3: Gather specific dimension
9392
gather_node = helper.make_node(
94-
'Gather',
93+
"Gather",
9594
inputs=[shape_name, idx_name],
9695
outputs=[output_name],
9796
name=f"Gather_{output_name}",
@@ -237,7 +236,7 @@ def onnx_funcify_Reshape(op, node, get_var_name, **kwargs):
237236
shape_name = f"{output_name}_shape"
238237

239238
shape_constant = helper.make_node(
240-
'Constant',
239+
"Constant",
241240
inputs=[],
242241
outputs=[shape_name],
243242
name=f"Constant_{shape_name}",
@@ -246,11 +245,11 @@ def onnx_funcify_Reshape(op, node, get_var_name, **kwargs):
246245
data_type=helper.TensorProto.INT64,
247246
dims=[len(shape_data)],
248247
vals=shape_data.tolist(),
249-
)
248+
),
250249
)
251250

252251
reshape_node = helper.make_node(
253-
'Reshape',
252+
"Reshape",
254253
inputs=[data_name, shape_name],
255254
outputs=[output_name],
256255
name=f"Reshape_{output_name}",
@@ -262,7 +261,7 @@ def onnx_funcify_Reshape(op, node, get_var_name, **kwargs):
262261
shape_name = get_var_name(shape_input)
263262

264263
reshape_node = helper.make_node(
265-
'Reshape',
264+
"Reshape",
266265
inputs=[data_name, shape_name],
267266
outputs=[output_name],
268267
name=f"Reshape_{output_name}",
@@ -301,7 +300,7 @@ def onnx_funcify_Join(op, node, get_var_name, **kwargs):
301300

302301
# Create ONNX Concat node
303302
concat_node = helper.make_node(
304-
'Concat',
303+
"Concat",
305304
inputs=input_names,
306305
outputs=[output_name],
307306
name=f"Concat_{output_name}",
@@ -359,7 +358,7 @@ def onnx_funcify_Split(op, node, get_var_name, **kwargs):
359358
# Create constant node for split sizes (required in opset 13+)
360359
split_name = f"{output_names[0]}_split"
361360
split_constant = helper.make_node(
362-
'Constant',
361+
"Constant",
363362
inputs=[],
364363
outputs=[split_name],
365364
name=f"Constant_{split_name}",
@@ -368,12 +367,12 @@ def onnx_funcify_Split(op, node, get_var_name, **kwargs):
368367
data_type=helper.TensorProto.INT64,
369368
dims=[len(splits)],
370369
vals=splits.tolist(),
371-
)
370+
),
372371
)
373372

374373
# Create ONNX Split node with split as an input
375374
split_node = helper.make_node(
376-
'Split',
375+
"Split",
377376
inputs=[input_name, split_name],
378377
outputs=output_names,
379378
name=f"Split_{output_names[0]}",

0 commit comments

Comments (0)