This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 07379a6

standalone onnx export function (#389)
1 parent 6a47673 commit 07379a6

1 file changed

src/sparseml/pytorch/utils/exporter.py

Lines changed: 172 additions & 106 deletions
@@ -20,7 +20,7 @@
 import os
 import warnings
 from copy import deepcopy
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy
 import onnx
@@ -42,7 +42,10 @@
     script_model,
     trace_model,
 )
-from sparseml.pytorch.utils.quantization import quantize_torch_qat_export
+from sparseml.pytorch.utils.quantization import (
+    quantize_torch_qat_export,
+    skip_onnx_input_quantize,
+)
 from sparseml.utils import clean_path, create_parent_dirs
 
 
@@ -152,23 +155,6 @@ def export_to_zoo(
             sample_originals=sample_originals,
         )
 
-    @classmethod
-    def get_output_names(cls, out: Any):
-        """
-        Get name of output tensors
-
-        :param out: outputs of the model
-        :return: list of names
-        """
-        output_names = None
-        if isinstance(out, Tensor):
-            output_names = ["output"]
-        elif isinstance(out, Iterable):
-            output_names = [
-                "output_{}".format(index) for index, _ in enumerate(iter(out))
-            ]
-        return output_names
-
     def export_onnx(
         self,
         sample_batch: Any,
@@ -203,96 +189,16 @@ def export_onnx(
         See more on the torch.onnx.export api spec in the PyTorch docs:
         https://pytorch.org/docs/stable/onnx.html
         """
-        if not export_kwargs:
-            export_kwargs = {}
-
-        if isinstance(sample_batch, Dict) and not isinstance(
-            sample_batch, collections.OrderedDict
-        ):
-            warnings.warn(
-                "Sample inputs passed into the ONNX exporter should be in "
-                "the same order defined in the model forward function. "
-                "Consider using OrderedDict for this purpose.",
-                UserWarning,
-            )
-
-        sample_batch = tensors_to_device(sample_batch, "cpu")
-        onnx_path = os.path.join(self._output_dir, name)
-        create_parent_dirs(onnx_path)
-
-        with torch.no_grad():
-            out = tensors_module_forward(
-                sample_batch, self._module, check_feat_lab_inp=False
-            )
-
-        if "input_names" not in export_kwargs:
-            if isinstance(sample_batch, Tensor):
-                export_kwargs["input_names"] = ["input"]
-            elif isinstance(sample_batch, Dict):
-                export_kwargs["input_names"] = list(sample_batch.keys())
-                sample_batch = tuple(
-                    [sample_batch[f] for f in export_kwargs["input_names"]]
-                )
-            elif isinstance(sample_batch, Iterable):
-                export_kwargs["input_names"] = [
-                    "input_{}".format(index)
-                    for index, _ in enumerate(iter(sample_batch))
-                ]
-                if isinstance(sample_batch, List):
-                    sample_batch = tuple(
-                        sample_batch
-                    )  # torch.onnx.export requires tuple
-
-        if "output_names" not in export_kwargs:
-            export_kwargs["output_names"] = self.get_output_names(out)
-
-        # disable active quantization observers because they cannot be exported
-        disabled_observers = []
-        for submodule in self._module.modules():
-            if (
-                hasattr(submodule, "observer_enabled")
-                and submodule.observer_enabled[0] == 1
-            ):
-                submodule.observer_enabled[0] = 0
-                disabled_observers.append(submodule)
-
-        is_quant_module = any(
-            hasattr(submodule, "qconfig") and submodule.qconfig
-            for submodule in self._module.modules()
-        )
-        batch_norms_wrapped = False
-        if torch.__version__ >= "1.7" and not is_quant_module and disable_bn_fusing:
-            # prevent batch norm fusing by adding a trivial operation before every
-            # batch norm layer
-            export_module = deepcopy(self._module)
-            batch_norms_wrapped = _wrap_batch_norms(export_module)
-        else:
-            export_module = self._module
-
-        torch.onnx.export(
-            export_module,
-            sample_batch,
-            onnx_path,
-            strip_doc_string=True,
-            verbose=False,
-            opset_version=opset,
+        export_onnx(
+            module=self._module,
+            sample_batch=sample_batch,
+            file_path=os.path.join(self._output_dir, name),
+            opset=opset,
+            disable_bn_fusing=disable_bn_fusing,
+            convert_qat=convert_qat,
             **export_kwargs,
         )
 
-        # re-enable disabled quantization observers
-        for submodule in disabled_observers:
-            submodule.observer_enabled[0] = 1
-
-        # clean up graph from any injected / wrapped operations
-        if batch_norms_wrapped:
-            onnx_model = onnx.load(onnx_path)
-            _delete_trivial_onnx_adds(onnx_model)
-            onnx.save(onnx_model, onnx_path)
-
-        if convert_qat and is_quant_module:
-            # overwrite exported model with fully quantized version
-            quantize_torch_qat_export(model=onnx_path, output_file_path=onnx_path)
-
     def export_torchscript(
         self,
         name: str = "model.pts",
@@ -425,6 +331,166 @@ def export_samples(
             exp_counter += len(exported_input)
 
 
+def export_onnx(
+    module: Module,
+    sample_batch: Any,
+    file_path: str,
+    opset: int = DEFAULT_ONNX_OPSET,
+    disable_bn_fusing: bool = True,
+    convert_qat: bool = False,
+    dynamic_axes: Union[str, Dict[str, List[int]]] = None,
+    skip_input_quantize: bool = False,
+    **export_kwargs,
+):
+    """
+    Export an onnx file for the current module and for a sample batch.
+    Sample batch used to feed through the model to freeze the graph for a
+    particular execution.
+
+    :param module: torch Module object to export
+    :param sample_batch: the batch to export an onnx for, handles creating the
+        static graph for onnx as well as setting dimensions
+    :param file_path: path to the onnx file to save
+    :param opset: onnx opset to use for exported model. Default is 11, if torch
+        version is 1.2 or below, default is 9
+    :param disable_bn_fusing: torch >= 1.7.0 only. Set True to disable batch norm
+        fusing during torch export. Default and suggested setting is True. Batch
+        norm fusing will change the exported parameter names as well as affect
+        sensitivity analyses of the exported graph. Additionally, the DeepSparse
+        inference engine, and other engines, perform batch norm fusing at model
+        compilation.
+    :param convert_qat: if True and quantization aware training is detected in
+        the module being exported, the resulting QAT ONNX model will be converted
+        to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
+        is False.
+    :param dynamic_axes: dictionary of input or output names to list of dimensions
+        of those tensors that should be exported as dynamic. May input 'batch'
+        to set the first dimension of all inputs and outputs to dynamic. Default
+        is an empty dict
+    :param skip_input_quantize: if True, the export flow will attempt to delete
+        the first QuantizeLinear node(s) immediately after model input and set
+        the model input type to UINT8. Default is False
+    :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
+        call. Useful to pass in dynamic_axes, input_names, output_names, etc.
+        See more on the torch.onnx.export api spec in the PyTorch docs:
+        https://pytorch.org/docs/stable/onnx.html
+    """
+    if not export_kwargs:
+        export_kwargs = {}
+
+    if isinstance(sample_batch, Dict) and not isinstance(
+        sample_batch, collections.OrderedDict
+    ):
+        warnings.warn(
+            "Sample inputs passed into the ONNX exporter should be in "
+            "the same order defined in the model forward function. "
+            "Consider using OrderedDict for this purpose.",
+            UserWarning,
+        )
+
+    sample_batch = tensors_to_device(sample_batch, "cpu")
+    create_parent_dirs(file_path)
+
+    module = deepcopy(module).cpu().eval()
+
+    with torch.no_grad():
+        out = tensors_module_forward(sample_batch, module, check_feat_lab_inp=False)
+
+    if "input_names" not in export_kwargs:
+        if isinstance(sample_batch, Tensor):
+            export_kwargs["input_names"] = ["input"]
+        elif isinstance(sample_batch, Dict):
+            export_kwargs["input_names"] = list(sample_batch.keys())
+            sample_batch = tuple(
+                [sample_batch[f] for f in export_kwargs["input_names"]]
+            )
+        elif isinstance(sample_batch, Iterable):
+            export_kwargs["input_names"] = [
+                "input_{}".format(index) for index, _ in enumerate(iter(sample_batch))
+            ]
+            if isinstance(sample_batch, List):
+                sample_batch = tuple(sample_batch)  # torch.onnx.export requires tuple
+
+    if "output_names" not in export_kwargs:
+        export_kwargs["output_names"] = _get_output_names(out)
+
+    if dynamic_axes == "batch":
+        dynamic_axes = {
+            tensor_name: {0: "batch"}
+            for tensor_name in (
+                export_kwargs["input_names"] + export_kwargs["output_names"]
+            )
+        }
+
+    # disable active quantization observers because they cannot be exported
+    disabled_observers = []
+    for submodule in module.modules():
+        if (
+            hasattr(submodule, "observer_enabled")
+            and submodule.observer_enabled[0] == 1
+        ):
+            submodule.observer_enabled[0] = 0
+            disabled_observers.append(submodule)
+
+    is_quant_module = any(
+        hasattr(submodule, "qconfig") and submodule.qconfig
+        for submodule in module.modules()
+    )
+    batch_norms_wrapped = False
+    if torch.__version__ >= "1.7" and not is_quant_module and disable_bn_fusing:
+        # prevent batch norm fusing by adding a trivial operation before every
+        # batch norm layer
+        batch_norms_wrapped = _wrap_batch_norms(module)
+
+    torch.onnx.export(
+        module,
+        sample_batch,
+        file_path,
+        strip_doc_string=True,
+        verbose=False,
+        opset_version=opset,
+        dynamic_axes=dynamic_axes,
+        **export_kwargs,
+    )
+
+    # re-enable disabled quantization observers
+    for submodule in disabled_observers:
+        submodule.observer_enabled[0] = 1
+
+    # clean up graph from any injected / wrapped operations
+    if batch_norms_wrapped:
+        onnx_model = onnx.load(file_path)
+        _delete_trivial_onnx_adds(onnx_model)
+        onnx.save(onnx_model, file_path)
+
+    if convert_qat and is_quant_module:
+        # overwrite exported model with fully quantized version
+        quantize_torch_qat_export(model=file_path, output_file_path=file_path)
+
+    if skip_input_quantize:
+        try:
+            skip_onnx_input_quantize(file_path, file_path)
+        except Exception as e:
+            _LOGGER.warning(
+                f"Unable to skip input QuantizeLinear op with exception {e}"
+            )
+
+
+def _get_output_names(out: Any):
+    """
+    Get name of output tensors
+
+    :param out: outputs of the model
+    :return: list of names
+    """
+    output_names = None
+    if isinstance(out, Tensor):
+        output_names = ["output"]
+    elif isinstance(out, Iterable):
+        output_names = ["output_{}".format(index) for index, _ in enumerate(iter(out))]
+    return output_names
+
+
 class _AddNoOpWrapper(Module):
     # trivial wrapper to break-up Conv-BN blocks
 

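For reference, a minimal usage sketch of the new standalone export_onnx helper introduced by this commit. This is illustrative only: it assumes the function is importable from sparseml.pytorch.utils.exporter, and the torchvision model, sample batch, and output paths below are placeholder choices, not part of the commit.

import torch
from torchvision.models import resnet18

from sparseml.pytorch.utils.exporter import export_onnx

# placeholder model and sample batch; the batch is only used to freeze the graph
model = resnet18(pretrained=False)
sample_batch = torch.randn(1, 3, 224, 224)

# basic export: input/output names are inferred from the sample batch,
# batch norm fusing stays disabled (the default)
export_onnx(
    module=model,
    sample_batch=sample_batch,
    file_path="exports/resnet18.onnx",
)

# export with a dynamic batch dimension on every input and output tensor
export_onnx(
    module=model,
    sample_batch=sample_batch,
    file_path="exports/resnet18_dynamic.onnx",
    dynamic_axes="batch",
)

Any keyword arguments accepted by torch.onnx.export (for example input_names or output_names) can still be passed through **export_kwargs, per the function's docstring.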