Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit d9385b0

Browse files
authored
Add kwargs support to torch.onnx.export api for dynamic_axes and other args (#336)
* Add dynamic axes support for pytorch onnx exports * Change to make more generic to accept kwargs for the torch.onnx.export api
1 parent ead40a3 commit d9385b0

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

src/sparseml/pytorch/utils/exporter.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def export_onnx(
176176
opset: int = DEFAULT_ONNX_OPSET,
177177
disable_bn_fusing: bool = True,
178178
convert_qat: bool = False,
179+
**export_kwargs,
179180
):
180181
"""
181182
Export an onnx file for the current module and for a sample batch.
@@ -197,7 +198,14 @@ def export_onnx(
197198
the module being exported, the resulting QAT ONNX model will be converted
198199
to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
199200
is False.
201+
:param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
202+
call. Useful to pass in dynamic_axes, input_names, output_names, etc.
203+
See more on the torch.onnx.export api spec in the PyTorch docs:
204+
https://pytorch.org/docs/stable/onnx.html
200205
"""
206+
if not export_kwargs:
207+
export_kwargs = {}
208+
201209
if isinstance(sample_batch, Dict) and not isinstance(
202210
sample_batch, collections.OrderedDict
203211
):
@@ -217,20 +225,26 @@ def export_onnx(
217225
sample_batch, self._module, check_feat_lab_inp=False
218226
)
219227

220-
input_names = None
221-
if isinstance(sample_batch, Tensor):
222-
input_names = ["input"]
223-
elif isinstance(sample_batch, Dict):
224-
input_names = list(sample_batch.keys())
225-
sample_batch = tuple([sample_batch[f] for f in input_names])
226-
elif isinstance(sample_batch, Iterable):
227-
input_names = [
228-
"input_{}".format(index) for index, _ in enumerate(iter(sample_batch))
229-
]
230-
if isinstance(sample_batch, List):
231-
sample_batch = tuple(sample_batch) # torch.onnx.export requires tuple
228+
if "input_names" not in export_kwargs:
229+
if isinstance(sample_batch, Tensor):
230+
export_kwargs["input_names"] = ["input"]
231+
elif isinstance(sample_batch, Dict):
232+
export_kwargs["input_names"] = list(sample_batch.keys())
233+
sample_batch = tuple(
234+
[sample_batch[f] for f in export_kwargs["input_names"]]
235+
)
236+
elif isinstance(sample_batch, Iterable):
237+
export_kwargs["input_names"] = [
238+
"input_{}".format(index)
239+
for index, _ in enumerate(iter(sample_batch))
240+
]
241+
if isinstance(sample_batch, List):
242+
sample_batch = tuple(
243+
sample_batch
244+
) # torch.onnx.export requires tuple
232245

233-
output_names = self.get_output_names(out)
246+
if "output_names" not in export_kwargs:
247+
export_kwargs["output_names"] = self.get_output_names(out)
234248

235249
# disable active quantization observers because they cannot be exported
236250
disabled_observers = []
@@ -259,11 +273,10 @@ def export_onnx(
259273
export_module,
260274
sample_batch,
261275
onnx_path,
262-
input_names=input_names,
263-
output_names=output_names,
264276
strip_doc_string=True,
265277
verbose=False,
266278
opset_version=opset,
279+
**export_kwargs,
267280
)
268281

269282
# re-enable disabled quantization observers

0 commit comments

Comments
 (0)