
Commit 5774006

Fix qat convert mobilebert (#968)
* Move ReLU folding to after the matmul conversions (so it can affect ReLU layers in FFN blocks)
* Changes to QAT export to support mobileBERT
* Style and quality fixes
1 parent 56b27f6 commit 5774006

1 file changed: +123 −4 lines changed

src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py (123 additions, 4 deletions)
```diff
@@ -1380,9 +1380,9 @@ def _quantize_qat_embedding(model: ModelProto):
     |            |
     |          Gather
     |            |
-    |      QuantizeLinear
+    |      QuantizeLinear (Optional)
     |            |
-    |     DequantizeLinear
+    |     DequantizeLinear (Optional)
     |            |
     |          OUTPUT
```
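For orientation, here is a hand-built instance of the quantized-embedding pattern this docstring now describes, with the QuantizeLinear/DequantizeLinear pair after the Gather already reduced to a single DequantizeLinear. This is a sketch only; all names, shapes, and values are illustrative and not taken from the commit.

```python
import numpy
import onnx
from onnx import TensorProto, helper, numpy_helper

# uint8 embedding table plus per-tensor scale/zero-point (illustrative values)
embeddings = numpy_helper.from_array(
    numpy.zeros((30522, 128), dtype=numpy.uint8), name="embed_quant"
)
scale = numpy_helper.from_array(numpy.array(0.05, dtype=numpy.float32), name="scale")
zero_point = numpy_helper.from_array(numpy.array(128, dtype=numpy.uint8), name="zp")

# Gather over the uint8 table, then a single DequantizeLinear back to float
graph = helper.make_graph(
    nodes=[
        helper.make_node("Gather", ["embed_quant", "ids"], ["gathered"]),
        helper.make_node("DequantizeLinear", ["gathered", "scale", "zp"], ["output"]),
    ],
    name="quantized_embedding",
    inputs=[helper.make_tensor_value_info("ids", TensorProto.INT64, [None])],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 128])],
    initializer=[embeddings, scale, zero_point],
)
model = helper.make_model(graph)
onnx.checker.check_model(model)
```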
```diff
@@ -1571,19 +1571,20 @@ def quantize_torch_qat_export(
     model = deepcopy(model)
 
     _fold_qat_conv_bns(model)
-    _fold_relu_quants(model)
     _convert_single_constants_to_initializers(model)
     _delete_repeated_qat_blocks(model)
+    _quantize_qat_embedding(model)
+    _propagate_mobilebert_embedding_quantization(model)
     _convert_quantizable_matmul(model)
     _convert_quantizable_matmul_and_add(model)
+    _fold_relu_quants(model)
 
     # only convert to either ConvInteger or QLinearConv (legacy)
     if not use_qlinearconv:
         _convert_quantizable_conv_integer(model)
     _convert_quantizable_ops(model, convert_qlinearconv=use_qlinearconv)
 
     _convert_quantizable_gemm_no_activations(model)
-    _quantize_qat_embedding(model)
     quantize_resnet_identity_add_inputs(model)
     _remove_duplicate_quantize_ops(model)
     _cleanup_unused_quants(model)
```
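Taken together, the reordering means embedding quantization and the new MobileBERT propagation pass now run before any matmul conversion, and ReLU folding runs after the matmul conversions so it can also fold ReLUs in FFN blocks. A minimal driver sketch, assuming the public entry point keeps `model` and `output_file_path` arguments (the full signature is truncated in this hunk, so treat it as an assumption; file names are illustrative):

```python
import onnx

from sparseml.pytorch.sparsification.quantization.quantize_qat_export import (
    quantize_torch_qat_export,
)

# Load a QAT-trained model previously exported to ONNX
model = onnx.load("mobilebert_qat.onnx")

# Run the pass pipeline shown above and write the converted model back out
quantize_torch_qat_export(model, output_file_path="mobilebert_quantized.onnx")
```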
@@ -1719,3 +1720,121 @@ def skip_onnx_input_quantize(
17191720

17201721
if output_file_path:
17211722
onnx.save(model, output_file_path)
1723+
1724+
1725+
def _propagate_mobilebert_embedding_quantization(model: ModelProto):
1726+
"""
1727+
A pass for propagating embedding quantizations through concat
1728+
1729+
Starting with:
1730+
| GATHER (UINT8 data initializer)
1731+
| |
1732+
| DequantizeLinear
1733+
| | | |
1734+
| | Slice Slice
1735+
| | | |
1736+
| | Pad Pad
1737+
| | | |
1738+
| Concat
1739+
| |
1740+
| OUTPUT
1741+
1742+
Converts to:
1743+
| GATHER (UINT8 data initializer)
1744+
| | | |
1745+
| | Slice Slice
1746+
| | | |
1747+
| | Pad Pad
1748+
| | | |
1749+
| Concat
1750+
| |
1751+
| DequantizeLinear
1752+
| |
1753+
| OUTPUT
1754+
"""
1755+
converted_nodes = 0
1756+
gather_nodes = [n for n in model.graph.node if n.op_type in ["Gather"]]
1757+
graph = ONNXGraph(model)
1758+
for gather_node in gather_nodes:
1759+
# find quantized weight
1760+
embedding_initializer = graph.get_init_by_name(gather_node.input[0])
1761+
if not embedding_initializer:
1762+
continue
1763+
1764+
embedding_array = numpy_helper.to_array(embedding_initializer)
1765+
if embedding_array.dtype != numpy.uint8:
1766+
continue
1767+
1768+
dequant_node = graph.get_node_single_child(gather_node)
1769+
if not dequant_node or dequant_node.op_type != "DequantizeLinear":
1770+
continue
1771+
1772+
# loop through the children of the dequantize node and check if they
1773+
# are composed of slice + pad nodes and converge at the same concat node
1774+
valid = True
1775+
concat_node = None
1776+
for branch_node in graph.get_node_children(dequant_node):
1777+
if branch_node.op_type == "Slice":
1778+
pad_node = graph.get_node_single_child(branch_node)
1779+
if not pad_node or pad_node.op_type != "Pad":
1780+
valid = False
1781+
break
1782+
1783+
concat_node_ = graph.get_node_single_child(pad_node)
1784+
if not concat_node_ or concat_node_.op_type != "Concat":
1785+
valid = False
1786+
break
1787+
1788+
if concat_node is None:
1789+
concat_node = concat_node_
1790+
elif concat_node != concat_node_:
1791+
valid = False
1792+
break
1793+
elif branch_node.op_type == "Concat":
1794+
if concat_node is None:
1795+
concat_node = branch_node
1796+
elif branch_node != concat_node:
1797+
valid = False
1798+
break
1799+
else:
1800+
valid = False
1801+
break
1802+
1803+
if not valid or not concat_node:
1804+
continue
1805+
1806+
# switch position of dequantize node
1807+
for branch_node in graph.get_node_children(dequant_node):
1808+
if branch_node.op_type == "Slice":
1809+
branch_node.input[0] = gather_node.output[0]
1810+
pad_node = graph.get_node_single_child(branch_node)
1811+
pad_value = graph.get_init_by_name(pad_node.input[2])
1812+
pad_value_array = numpy_helper.to_array(pad_value)
1813+
pad_value_array = pad_value_array + 128
1814+
pad_value_array = pad_value_array.astype(numpy.uint8)
1815+
model.graph.initializer.remove(pad_value)
1816+
pad_value = numpy_helper.from_array(
1817+
pad_value_array, name=pad_value.name
1818+
)
1819+
model.graph.initializer.append(pad_value)
1820+
1821+
for id, input_name in enumerate(concat_node.input):
1822+
if input_name == dequant_node.output[0]:
1823+
break
1824+
1825+
concat_node.input[id] = gather_node.output[0]
1826+
temp = concat_node.output[0]
1827+
concat_node.output[0] = dequant_node.output[0]
1828+
dequant_node.output[0] = temp
1829+
dequant_node.input[0] = concat_node.output[0]
1830+
1831+
graph.update()
1832+
1833+
converted_nodes += 1
1834+
1835+
graph.delete_unused_initializers()
1836+
1837+
if converted_nodes > 0:
1838+
_LOGGER.info(
1839+
f"Propagated {converted_nodes} DequantizeLinear node(s) through Concat"
1840+
)
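One subtlety in the rewiring above: once Pad runs on the raw uint8 tensor instead of the dequantized floats, its constant pad value must be re-expressed in the quantized domain. The pass does this by adding 128, which is exact for the common case of a 0.0 pad value with a zero point of 128 (the standard offset when int8 QAT values are re-encoded as uint8). A small numpy check of that identity, with illustrative quantization parameters:

```python
import numpy

# Illustrative per-tensor quantization parameters: scale is arbitrary,
# zero point 128 matches the int8 -> uint8 re-encoding assumed here
scale, zero_point = 0.05, 128

# The float pad value used before the rewrite (Pad ran after DequantizeLinear)
pad_value = numpy.array(0.0)

# The pass's shift: move the pad constant into the uint8 quantized domain
quant_pad_value = (pad_value + 128).astype(numpy.uint8)  # -> 128

# Dequantizing the shifted constant recovers the original padding exactly
recovered = (quant_pad_value.astype(numpy.float32) - zero_point) * scale
assert recovered == pad_value
```

The output-name swap at the end of the pass serves a similar purpose: by exchanging the Concat and DequantizeLinear output names, every downstream consumer keeps reading the same tensor name while the data now flows through DequantizeLinear after the Concat.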
