@@ -1380,9 +1380,9 @@ def _quantize_qat_embedding(model: ModelProto):
     |         |
     |       Gather
     |         |
-    |   QuantizeLinear
+    |   QuantizeLinear (Optional)
     |         |
-    |  DequantizeLinear
+    |  DequantizeLinear (Optional)
     |         |
     |       OUTPUT
 
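The docstring change above marks the trailing QuantizeLinear/DequantizeLinear pair after the Gather as optional in the matched pattern. Below is a minimal sketch of how such an optional pair could be detected with the `ONNXGraph.get_node_single_child` helper used elsewhere in this diff; the helper name `_optional_qdq_after_gather` is illustrative only and is not the library's actual matching code.

```python
def _optional_qdq_after_gather(graph, gather_node):
    """Return the (QuantizeLinear, DequantizeLinear) pair that follows a
    Gather node, or None if the optional pair is absent."""
    quant_node = graph.get_node_single_child(gather_node)
    if quant_node is None or quant_node.op_type != "QuantizeLinear":
        return None
    dequant_node = graph.get_node_single_child(quant_node)
    if dequant_node is None or dequant_node.op_type != "DequantizeLinear":
        return None
    return quant_node, dequant_node
```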
@@ -1571,19 +1571,20 @@ def quantize_torch_qat_export(
     model = deepcopy(model)
 
     _fold_qat_conv_bns(model)
-    _fold_relu_quants(model)
     _convert_single_constants_to_initializers(model)
     _delete_repeated_qat_blocks(model)
+    _quantize_qat_embedding(model)
+    _propagate_mobilebert_embedding_quantization(model)
     _convert_quantizable_matmul(model)
     _convert_quantizable_matmul_and_add(model)
+    _fold_relu_quants(model)
 
     # only convert to either ConvInteger or QLinearConv (legacy)
     if not use_qlinearconv:
         _convert_quantizable_conv_integer(model)
     _convert_quantizable_ops(model, convert_qlinearconv=use_qlinearconv)
 
     _convert_quantizable_gemm_no_activations(model)
-    _quantize_qat_embedding(model)
     quantize_resnet_identity_add_inputs(model)
     _remove_duplicate_quantize_ops(model)
     _cleanup_unused_quants(model)
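The hunk above moves the embedding passes ahead of the MatMul conversions and runs `_fold_relu_quants` after them, so the new `_propagate_mobilebert_embedding_quantization` pass sees the graph before those rewrites. A hedged usage sketch of the entry point follows; the file paths are placeholders, and the assumption that the converted `ModelProto` is returned (rather than only modified in place) is not confirmed by this diff.

```python
import onnx

# placeholder path to an ONNX file produced by a torch QAT export
qat_model = onnx.load("model.onnx")
converted = quantize_torch_qat_export(qat_model, use_qlinearconv=False)
onnx.save(converted, "model-quantized.onnx")
```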
@@ -1719,3 +1720,121 @@ def skip_onnx_input_quantize(
 
     if output_file_path:
         onnx.save(model, output_file_path)
+
+
+def _propagate_mobilebert_embedding_quantization(model: ModelProto):
+    """
+    A pass for propagating embedding quantizations through concat
+
+    Starting with:
+    |       GATHER (UINT8 data initializer)
+    |               |
+    |        DequantizeLinear
+    |       |       |       |
+    |       |     Slice   Slice
+    |       |       |       |
+    |       |      Pad     Pad
+    |       |       |       |
+    |             Concat
+    |               |
+    |             OUTPUT
+
+    Converts to:
+    |       GATHER (UINT8 data initializer)
+    |       |       |       |
+    |       |     Slice   Slice
+    |       |       |       |
+    |       |      Pad     Pad
+    |       |       |       |
+    |             Concat
+    |               |
+    |        DequantizeLinear
+    |               |
+    |             OUTPUT
+    """
+    converted_nodes = 0
+    gather_nodes = [n for n in model.graph.node if n.op_type in ["Gather"]]
+    graph = ONNXGraph(model)
+    for gather_node in gather_nodes:
+        # find quantized weight
+        embedding_initializer = graph.get_init_by_name(gather_node.input[0])
+        if not embedding_initializer:
+            continue
+
+        embedding_array = numpy_helper.to_array(embedding_initializer)
+        if embedding_array.dtype != numpy.uint8:
+            continue
+
+        dequant_node = graph.get_node_single_child(gather_node)
+        if not dequant_node or dequant_node.op_type != "DequantizeLinear":
+            continue
+
+        # loop through the children of the dequantize node and check if they
+        # are composed of slice + pad nodes and converge at the same concat node
+        valid = True
+        concat_node = None
+        for branch_node in graph.get_node_children(dequant_node):
+            if branch_node.op_type == "Slice":
+                pad_node = graph.get_node_single_child(branch_node)
+                if not pad_node or pad_node.op_type != "Pad":
+                    valid = False
+                    break
+
+                concat_node_ = graph.get_node_single_child(pad_node)
+                if not concat_node_ or concat_node_.op_type != "Concat":
+                    valid = False
+                    break
+
+                if concat_node is None:
+                    concat_node = concat_node_
+                elif concat_node != concat_node_:
+                    valid = False
+                    break
+            elif branch_node.op_type == "Concat":
+                if concat_node is None:
+                    concat_node = branch_node
+                elif branch_node != concat_node:
+                    valid = False
+                    break
+            else:
+                valid = False
+                break
+
+        if not valid or not concat_node:
+            continue
+
+        # switch position of dequantize node
+        for branch_node in graph.get_node_children(dequant_node):
+            if branch_node.op_type == "Slice":
+                branch_node.input[0] = gather_node.output[0]
+                pad_node = graph.get_node_single_child(branch_node)
+                pad_value = graph.get_init_by_name(pad_node.input[2])
+                pad_value_array = numpy_helper.to_array(pad_value)
+                pad_value_array = pad_value_array + 128
+                pad_value_array = pad_value_array.astype(numpy.uint8)
+                model.graph.initializer.remove(pad_value)
+                pad_value = numpy_helper.from_array(
+                    pad_value_array, name=pad_value.name
+                )
+                model.graph.initializer.append(pad_value)
+
+        for id, input_name in enumerate(concat_node.input):
+            if input_name == dequant_node.output[0]:
+                break
+
+        concat_node.input[id] = gather_node.output[0]
+        temp = concat_node.output[0]
+        concat_node.output[0] = dequant_node.output[0]
+        dequant_node.output[0] = temp
+        dequant_node.input[0] = concat_node.output[0]
+
+        graph.update()
+
+        converted_nodes += 1
+
+    graph.delete_unused_initializers()
+
+    if converted_nodes > 0:
+        _LOGGER.info(
+            f"Propagated {converted_nodes} DequantizeLinear node(s) through Concat"
+        )
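The non-obvious step in the pass above is the `+ 128` shift of the Pad constant: once Pad runs before DequantizeLinear it operates on uint8 data, so the float pad constant has to be mapped into the quantized domain. Below is a toy numpy check of the rewrite, assuming a zero point of 128 (which is what the fixed `+ 128` shift implies) and a pad constant of 0; the scale is arbitrary and the Slice branches are elided since slicing commutes with dequantization trivially.

```python
import numpy as np

scale, zero_point = 0.05, 128
q = np.array([[3, 200, 130]], dtype=np.uint8)  # stands in for the Gather output


def dequant(x):
    # DequantizeLinear: (q - zero_point) * scale
    return (x.astype(np.int32) - zero_point) * scale


# original order: DequantizeLinear -> Pad -> Concat, float pad constant
float_pad = 0.0
before = np.concatenate(
    [dequant(q), np.pad(dequant(q), ((0, 0), (1, 0)), constant_values=float_pad)],
    axis=1,
)

# rewritten order: Pad -> Concat -> DequantizeLinear, pad constant shifted by +128
uint8_pad = np.uint8(float_pad + zero_point)
after = dequant(
    np.concatenate(
        [q, np.pad(q, ((0, 0), (1, 0)), constant_values=uint8_pad)], axis=1
    )
)

assert np.allclose(before, after)
```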