
Commit 01c86ed

option to convert PyTorch QAT graphs to fully quantized on ONNX export (#107)
1 parent f366825
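In short: `ModuleExporter.export_onnx` gains a `convert_qat` flag that, when quantization aware training is detected in the exported module, rewrites the fake-quantized ONNX graph into fully quantized operators via `quantize_torch_qat_export`. A minimal sketch of the new call pattern; the stand-in model, shapes, and paths below are illustrative, not part of this commit:

import torch
from sparseml.pytorch.utils import ModuleExporter

# stand-in module; in practice this would be a network prepared for
# quantization aware training (e.g. via SparseML's QuantizationModifier).
# The QAT -> quantized conversion only runs when QAT is detected.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

exporter = ModuleExporter(model, output_dir="export")
exporter.export_pytorch(name="model.pth")
exporter.export_onnx(
    torch.randn(1, 3, 224, 224),  # sample batch used to trace the graph
    name="model.onnx",
    convert_qat=True,  # rewrite fake-quantized QAT ops to fully quantized ones
)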

File tree: 12 files changed, +30 -15 lines

examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb
Lines changed: 2 additions & 7 deletions

@@ -319,7 +319,7 @@
     "\n",
     "Once the model is saved as an ONNX file, it is ready to be used for inference with the DeepSparse Engine.\n",
     "\n",
-    "If exporting the model only to PyTorch for inference, the graph can be converted to fully quantized in PyTorch only using `torch.quantization.convert`, however the resulting model will not be compatible with ONNX conversion."
+    "Normally, exporting a QAT model from PyTorch to ONNX will create a graph with \"fake quantized\" operations that represent the QAT graph. By setting `convert_qat=True` in our exporter, a function will automatically be called to convert this exported model to a fully quantized graph with the desired quantized structure."
     ]
    },
    {
@@ -330,7 +330,6 @@
     "source": [
     "import os\n",
     "from sparseml.pytorch.utils import ModuleExporter\n",
-    "from sparseml.pytorch.optim.quantization import quantize_torch_qat_export\n",
     "\n",
     "save_dir = \"pytorch_sparse_quantized_transfer_learning\"\n",
     "qat_onnx_graph_name = \"resnet50_imagenette_pruned_qat.onnx\"\n",
@@ -339,13 +338,9 @@
     "exporter = ModuleExporter(model, output_dir=save_dir)\n",
     "exporter.export_pytorch(name=\"resnet50_imagenette_pruned_qat.pth\")\n",
     "exporter.export_onnx(\n",
-    "    torch.randn(1, 3, 224, 224), name=qat_onnx_graph_name\n",
+    "    torch.randn(1, 3, 224, 224), name=qat_onnx_graph_name, convert_qat=True\n",
     ")\n",
     "\n",
-    "\n",
-    "# convert QAT graph to fully quantized operators\n",
-    "quantize_torch_qat_export(os.path.join(save_dir, qat_onnx_graph_name), output_file_path=quantized_onnx_path)\n",
-    "\n",
     "print(f\"Sparse-Quantized ONNX model saved to {quantized_onnx_path}\")"
     ]
    },
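To verify what the conversion did, the exported file can be inspected with the `onnx` package: a plain QAT export is full of QuantizeLinear/DequantizeLinear ("fake quantize") pairs, while the converted graph should carry quantized operators instead. A small inspection sketch; the file path is illustrative:

from collections import Counter

import onnx

# count operator types in the exported graph; before conversion expect many
# QuantizeLinear/DequantizeLinear pairs, after convert_qat=True expect
# quantized operators (e.g. QLinearConv / QLinearMatMul / integer ops) instead
model = onnx.load("pytorch_sparse_quantized_transfer_learning/resnet50_imagenette_pruned_qat.onnx")
print(Counter(node.op_type for node in model.graph.node).most_common())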

integrations/pytorch-torchvision/main.py
Lines changed: 1 addition & 1 deletion

@@ -434,7 +434,7 @@ def main(args):
     ########################
     exporter = ModuleExporter(model, save_dir)
     sample_input = torch.randn(image_shape).unsqueeze(0)  # sample batch for ONNX export
-    exporter.export_onnx(sample_input)
+    exporter.export_onnx(sample_input, convert_qat=True)
     exporter.export_pytorch()
     print("Model ONNX export and PyTorch weights saved to {}".format(save_dir))

integrations/timm/train.py
Lines changed: 4 additions & 1 deletion

@@ -696,7 +696,10 @@ def main():
         f"training complete, exporting ONNX to {output_dir}/model.onnx"
     )
     exporter = ModuleExporter(model, output_dir)
-    exporter.export_onnx(torch.randn((1, *data_config["input_size"])))
+    exporter.export_onnx(
+        torch.randn((1, *data_config["input_size"])),
+        convert_qat=True
+    )
     #################################################################################
     # End SparseML ONNX Export
     #################################################################################

integrations/ultralytics/train.py
Lines changed: 1 addition & 1 deletion

@@ -527,7 +527,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     )
     model.model[-1].export = True  # do not export grid post-processing
     exporter = ModuleExporter(model, save_dir)
-    exporter.export_onnx(torch.randn((1, 3, *imgsz)))
+    exporter.export_onnx(torch.randn((1, 3, *imgsz)), convert_qat=True)
     #################################################################################
     # End SparseML ONNX Export
     #################################################################################

scripts/pytorch_vision.py
Lines changed: 10 additions & 3 deletions

@@ -1042,6 +1042,7 @@ def _save_model_training(
     save_dir: str,
     epoch: int,
     val_res: Union[ModuleRunResults, None],
+    convert_qat: bool = False,
 ):
     LOGGER.info(
         "Saving model for epoch {} and val_loss {} to {} for {}".format(
@@ -1050,7 +1051,11 @@ def _save_model_training(
     )
     exporter = ModuleExporter(model, save_dir)
     exporter.export_pytorch(optim, epoch, "{}.pth".format(save_name))
-    exporter.export_onnx(torch.randn(1, *input_shape), "{}.onnx".format(save_name))
+    exporter.export_onnx(
+        torch.randn(1, *input_shape),
+        "{}.onnx".format(save_name),
+        convert_qat=convert_qat,
+    )

     info_path = os.path.join(save_dir, "{}.txt".format(save_name))

@@ -1185,8 +1190,10 @@ def train(args, model, train_loader, val_loader, input_shape, save_dir, loggers)
     # export the final model
     LOGGER.info("completed...")
     if args.is_main_process:
+        # only convert qat -> quantized ONNX graph for finalized model
+        # TODO: change this to all checkpoints when conversion times improve
         _save_model_training(
-            model, optim, input_shape, "model", save_dir, epoch, val_res
+            model, optim, input_shape, "model", save_dir, epoch, val_res, True
         )

     LOGGER.info("layer sparsities:")
@@ -1222,7 +1229,7 @@ def export(args, model, val_loader, save_dir):
     if not onnx_exported:
         # export onnx file using first sample for graph freezing
         LOGGER.info("exporting onnx in {}".format(save_dir))
-        exporter.export_onnx(data[0], opset=args.onnx_opset)
+        exporter.export_onnx(data[0], opset=args.onnx_opset, convert_qat=True)
         onnx_exported = True

     if args.num_samples > 0:

src/sparseml/pytorch/optim/modifier_quantization.py
Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@

 from sparseml.optim import ModifierProp
 from sparseml.pytorch.optim.modifier import PyTorchModifierYAML, ScheduledModifier
-from sparseml.pytorch.optim.quantization import (
+from sparseml.pytorch.utils.quantization import (
     add_quant_dequant,
     fuse_module_conv_bn_relus,
     get_qat_qconfig,
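Downstream code that imported from the old location needs the same one-line update as this modifier; the public names are unchanged, only the package moved. A sketch of the migration:

# before this commit:
# from sparseml.pytorch.optim.quantization import quantize_torch_qat_export

# after this commit, the quantization utilities live under utils:
from sparseml.pytorch.utils.quantization import quantize_torch_qat_export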

src/sparseml/pytorch/utils/exporter.py
Lines changed: 10 additions & 0 deletions

@@ -41,6 +41,7 @@
     script_model,
     trace_model,
 )
+from sparseml.pytorch.utils.quantization import quantize_torch_qat_export
 from sparseml.utils import clean_path, create_parent_dirs

@@ -156,6 +157,7 @@ def export_onnx(
         name: str = "model.onnx",
         opset: int = DEFAULT_ONNX_OPSET,
         disable_bn_fusing: bool = True,
+        convert_qat: bool = False,
     ):
         """
         Export an onnx file for the current module and for a sample batch.
@@ -173,6 +175,10 @@
             sensitivity analyses of the exported graph. Additionally, the DeepSparse
             inference engine, and other engines, perform batch norm fusing at model
             compilation.
+        :param convert_qat: if True and quantization aware training is detected in
+            the module being exported, the resulting QAT ONNX model will be converted
+            to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
+            is False.
         """
         sample_batch = tensors_to_device(sample_batch, "cpu")
         onnx_path = os.path.join(self._output_dir, name)
@@ -241,6 +247,10 @@
         _delete_trivial_onnx_adds(onnx_model)
         onnx.save(onnx_model, onnx_path)

+        if convert_qat and is_quant_module:
+            # overwrite exported model with fully quantized version
+            quantize_torch_qat_export(model=onnx_path, output_file_path=onnx_path)
+
     def export_torchscript(
         self,
         name: str = "model.pts",
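For ONNX files that were exported before this change (or with `convert_qat=False`), the same conversion can still be applied after the fact; this mirrors the call the exporter now makes internally. The paths here are illustrative:

from sparseml.pytorch.utils.quantization import quantize_torch_qat_export

# convert a previously exported fake-quantized QAT ONNX graph into a fully
# quantized one; writing to a separate path keeps the original QAT file
quantize_torch_qat_export(
    model="export/model_qat.onnx",
    output_file_path="export/model_quantized.onnx",
)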
Two additional files renamed without changes.

src/sparseml/pytorch/optim/quantization/quantize_qat_export.py renamed to src/sparseml/pytorch/utils/quantization/quantize_qat_export.py

File renamed without changes.
