This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 842b1e1
quantized inputs optimization for qat exports (#110)
* quantized inputs optimization for qat exports
* raise exception if no optim made
* logging and unit test
* negative test
1 parent d471c19 commit 842b1e1
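
For orientation, a minimal sketch (not part of the commit itself) of how the new helper introduced here might be invoked after a QAT training run; the model path is hypothetical:

# hypothetical usage; "model.onnx" is an assumed path to a QAT-exported model
from sparseml.pytorch.utils.quantization import skip_onnx_input_quantize

# rewrites the model so its FP32 graph input becomes uint8 and the leading
# QuantizeLinear node(s) are removed; raises RuntimeError if no change is made
skip_onnx_input_quantize("model.onnx", "model.onnx")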

File tree

4 files changed: +259 −3 lines changed


integrations/ultralytics/train.py

Lines changed: 5 additions & 1 deletion
@@ -521,13 +521,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     # Start SparseML ONNX Export
     #################################################################################
     from sparseml.pytorch.utils import ModuleExporter
+    from sparseml.pytorch.utils.quantization import skip_onnx_input_quantize
 
+    onnx_path = f"{save_dir}/model.onnx"
     logger.info(
-        f"training complete, exporting ONNX to {save_dir}/model.onnx"
+        f"training complete, exporting ONNX to {onnx_path}"
     )
     model.model[-1].export = True  # do not export grid post-processing
     exporter = ModuleExporter(model, save_dir)
     exporter.export_onnx(torch.randn((1, 3, *imgsz)), convert_qat=True)
+    if qat:
+        skip_onnx_input_quantize(onnx_path, onnx_path)
     #################################################################################
     # End SparseML ONNX Export
     #################################################################################
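
Net effect: the export path is computed once up front, and when the run used quantization-aware training (the qat flag), the saved ONNX file is rewritten in place so it accepts quantized inputs directly.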

src/sparseml/onnx/utils/graph_editor.py

Lines changed: 66 additions & 0 deletions
@@ -136,12 +136,78 @@ def update_node_input(
         node.input.append(input_id)
         self._input_id_to_nodes[input_id].append(node)
 
+    def delete_node(self, node: NodeProto):
+        """
+        deletes the given node from the graph
+
+        :param node: node to delete
+        """
+        self._model.graph.node.remove(node)
+        self._delete_node_edges(node)
+
+    def delete_nodes(self, nodes: List[NodeProto]):
+        """
+        deletes the given nodes from the graph
+        :param nodes: list of nodes to delete
+        """
+        node_output_ids_to_delete = {node.output[0] for node in nodes}
+        nodes_to_keep = []
+        for node in self._model.graph.node:
+            if node.output[0] in node_output_ids_to_delete:
+                self._delete_node_edges(node)
+            else:
+                nodes_to_keep.append(node)
+        self._model.graph.ClearField("node")
+        self._model.graph.node.extend(nodes_to_keep)
+
+    def delete_initializers(self, initializers: List[Union[str, TensorProto]]):
+        """
+        deletes the given initializers from the model
+
+        :param initializers: list of initializers or initializer names to delete
+        """
+        inits_to_delete = {
+            init if isinstance(init, str) else init.name for init in initializers
+        }
+        inits_to_keep = []
+        for init in self._model.graph.initializer:
+            if init.name in inits_to_delete:
+                # keep the edge reference if nodes in the graph still point to
+                # the initializer name
+                if not self._input_id_to_nodes[init.name]:
+                    del self._input_id_to_nodes[init.name]
+                del self._name_to_initializer[init.name]
+            else:
+                inits_to_keep.append(init)
+        self._model.graph.ClearField("initializer")
+        self._model.graph.initializer.extend(inits_to_keep)
+
+    def delete_unused_initializers(self):
+        """
+        deletes tensors in the initializer list that are not listed as inputs
+        to any node in the current graph state
+        """
+        self.delete_initializers(
+            [
+                init
+                for init in self._model.graph.initializer
+                if not self._input_id_to_nodes[init.name]
+            ]
+        )  # delete inits that have no edge
+
     def _store_node_edges(self, node: NodeProto):
         for output_id in node.output:
             self._output_id_to_node[output_id] = node
         for input_id in node.input:
             self._input_id_to_nodes[input_id].append(node)
 
+    def _delete_node_edges(self, node: NodeProto):
+        # remove node edges from cache
+        for output_id in node.output:
+            del self._output_id_to_node[output_id]
+        for input_id in node.input:
+            self._input_id_to_nodes[input_id].remove(node)
+
 
 def update_model_param(
     model: ModelProto,
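
To see the new deletion helpers working together, here is a sketch (not from the commit) of a single dead-node-elimination pass; it assumes ONNXGraph is importable from sparseml.onnx.utils, as the changes to quantize_qat_export.py below do, and the model path is made up:

# sketch: one pass of dead-node elimination using the new helpers
import onnx
from sparseml.onnx.utils import ONNXGraph

model = onnx.load("model.onnx")  # assumed path
graph = ONNXGraph(model)

# a node is dead if none of its outputs feed another node or a graph output
consumed_ids = {inp for node in model.graph.node for inp in node.input}
graph_output_ids = {out.name for out in model.graph.output}
dead_nodes = [
    node
    for node in model.graph.node
    if not any(
        out in consumed_ids or out in graph_output_ids for out in node.output
    )
]

graph.delete_nodes(dead_nodes)  # removes the nodes and their cached edges
graph.delete_unused_initializers()  # then prunes any now-dangling weights
onnx.save(model, "model.onnx")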

src/sparseml/pytorch/utils/quantization/quantize_qat_export.py

Lines changed: 90 additions & 2 deletions
@@ -18,15 +18,17 @@
 """
 
 
+import logging
 from collections import defaultdict
 from copy import deepcopy
-from typing import Any, NamedTuple, Union
+from typing import Any, NamedTuple, Optional, Union
 
 import numpy
 import onnx
 from onnx import ModelProto, NodeProto, numpy_helper
 
 from sparseml.onnx.utils import (
+    ONNXGraph,
     get_batch_norm_params,
     get_init_by_name,
     get_node_attributes,
@@ -40,7 +42,15 @@
 )
 
 
-__all__ = ["get_quantization_params", "QuantizationParams", "quantize_torch_qat_export"]
+__all__ = [
+    "get_quantization_params",
+    "QuantizationParams",
+    "quantize_torch_qat_export",
+    "skip_onnx_input_quantize",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
 
 
 """
@@ -593,3 +603,81 @@ def quantize_torch_qat_export(
     onnx.save(model, output_file_path)
 
     return model
+
+
+def _skip_input_quantize(model: ModelProto) -> Optional[str]:
+    if (
+        len(model.graph.input) != 1
+        or model.graph.input[0].type.tensor_type.elem_type != 1
+    ):
+        # more than 1 input or input is not FP32
+        return (
+            "Not modifying ONNX graph inputs - either graph has more than one "
+            "input or input type is not FP32"
+        )
+
+    input_node = model.graph.input[0]
+    input_children = [
+        node for node in model.graph.node if input_node.name in node.input
+    ]
+    if not all(node.op_type == "QuantizeLinear" for node in input_children):
+        return (
+            "Not modifying ONNX graph inputs - only QuantizeLinear nodes may "
+            "follow the FP32 input tensor in the original graph, prior to "
+            "converting to uint8"
+        )
+
+    graph = ONNXGraph(model)
+    for quantize_node in input_children:
+        quantize_children = graph.get_node_children(quantize_node)
+        quantize_node_id = quantize_node.output[0]
+        for child_node in quantize_children:
+            input_idx = [
+                idx
+                for idx, inp in enumerate(child_node.input)
+                if inp == quantize_node_id
+            ]
+            if not input_idx:
+                continue
+            input_idx = input_idx[0]
+            graph.update_node_input(child_node, input_node.name, input_idx)
+            _LOGGER.debug(
+                f"set node with output id {child_node.output[0]} as initial node "
+                "in graph"
+            )
+
+    _LOGGER.debug(
+        "deleting QuantizeLinear node(s) with output id(s): "
+        f"{[n.output for n in input_children]}"
+    )
+    graph.delete_nodes(input_children)  # only contains references to the Quantize nodes
+    graph.delete_unused_initializers()  # cleanup
+    input_node.type.tensor_type.elem_type = 2  # fp32 -> uint8
+    _LOGGER.info("Model initial QuantizeLinear node(s) deleted and inputs set to uint8")
+
+    return None
+
+
+def skip_onnx_input_quantize(
+    model: Union[ModelProto, str],
+    output_file_path: Union[str, None] = None,
+):
+    """
+    If the given model has a single FP32 input that feeds into a QuantizeLinear
+    node, then the input will be changed to uint8 and the QuantizeLinear node
+    will be deleted. This enables quantized graphs to take quantized inputs
+    instead of floats.
+
+    If no optimization is made, a RuntimeError will be raised.
+
+    :param model: The model to convert, or a file path to it
+    :param output_file_path: File path to save the converted model to
+    """
+    if isinstance(model, str):
+        model = onnx.load(model)
+
+    optim_error_message = _skip_input_quantize(model)
+
+    if optim_error_message:
+        raise RuntimeError(optim_error_message)
+
+    if output_file_path:
+        onnx.save(model, output_file_path)
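
A side note on the bare integer codes above: elem_type 1 and 2 are the FLOAT and UINT8 members of ONNX's TensorProto.DataType enum, so the input check and the final assignment could equivalently use named constants:

# the magic numbers in _skip_input_quantize map to TensorProto enum members
from onnx import TensorProto

assert TensorProto.FLOAT == 1  # the elem_type required before the rewrite
assert TensorProto.UINT8 == 2  # the elem_type assigned after the rewrite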
Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import onnx
+import pytest
+from onnx import TensorProto
+
+from sparseml.pytorch.utils.quantization import skip_onnx_input_quantize
+
+
+def test_skip_onnx_input_quantize():
+    # make sample graph of fp32 input -> QuantizeLinear -> QLinearConv
+    # verify that it is transformed to uint8 input -> QLinearConv
+
+    float_input = onnx.helper.make_tensor_value_info(
+        "input", TensorProto.FLOAT, [1, 3, None, None]
+    )
+    quant_node = onnx.helper.make_node(
+        "QuantizeLinear",
+        ["input", "scale", "zp"],
+        ["quant_output"],
+    )
+    qconv_node = onnx.helper.make_node(
+        "QLinearConv",
+        ["quant_output", "scale", "zp", "w", "w_scale", "w_zp", "y_scale", "y_zp"],
+        ["qconv_output"],
+    )
+
+    qconv_output = onnx.helper.make_tensor_value_info(
+        "qconv_output", TensorProto.UINT8, [1, 1, None, None]
+    )
+
+    graph = onnx.helper.make_graph(
+        [quant_node, qconv_node],
+        "test_graph",
+        [float_input],
+        [qconv_output],
+        [],
+    )
+    model = onnx.helper.make_model(graph)
+
+    # initial model checks
+    assert model.graph.input[0].type.tensor_type.elem_type == TensorProto.FLOAT
+    assert len(model.graph.node) == 2
+    assert model.graph.node[0].op_type == "QuantizeLinear"
+    assert model.graph.node[1].op_type == "QLinearConv"
+
+    assert model.graph.node[0].input[0] == model.graph.input[0].name
+    assert model.graph.node[1].input[0] == model.graph.node[0].output[0]
+
+    # run optimization
+    skip_onnx_input_quantize(model)
+
+    # check model has uint8 inputs and no qlinear input node
+    assert model.graph.input[0].type.tensor_type.elem_type == TensorProto.UINT8
+    assert len(model.graph.node) == 1
+    assert model.graph.node[0].op_type == "QLinearConv"
+
+    assert model.graph.node[0].input[0] == model.graph.input[0].name
+
+
+def test_skip_onnx_input_quantize_expected_exception():
+    # test that a graph with already quantized inputs fails for this optimization
+
+    int_input = onnx.helper.make_tensor_value_info(
+        "input", TensorProto.UINT8, [1, 3, None, None]
+    )
+    qconv_node = onnx.helper.make_node(
+        "QLinearConv",
+        ["input", "scale", "zp", "w", "w_scale", "w_zp", "y_scale", "y_zp"],
+        ["qconv_output"],
+    )
+
+    qconv_output = onnx.helper.make_tensor_value_info(
+        "qconv_output", TensorProto.UINT8, [1, 1, None, None]
+    )
+
+    graph = onnx.helper.make_graph(
+        [qconv_node],
+        "test_graph",
+        [int_input],
+        [qconv_output],
+        [],
+    )
+    model = onnx.helper.make_model(graph)
+    with pytest.raises(RuntimeError):
+        skip_onnx_input_quantize(model)
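
For completeness, a sketch of consuming a converted model at inference time (onnxruntime is not used by this commit, and the path and input shape are assumptions); the point is that the session now takes uint8 arrays directly:

# illustrative only: run the optimized model on quantized (uint8) inputs
import numpy
import onnxruntime

session = onnxruntime.InferenceSession(
    "model.onnx", providers=["CPUExecutionProvider"]  # assumed path
)
input_name = session.get_inputs()[0].name
uint8_batch = numpy.zeros((1, 3, 640, 640), dtype=numpy.uint8)  # assumed shape
outputs = session.run(None, {input_name: uint8_batch})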
