@@ -662,11 +662,13 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
     |       |
     |   QuantizeLinear
     |       |
-    |   QLinearMatMul (with constant kernel)
+    |   MatMulInteger (with constant uint8 kernel)
     |       |
-    |   QLinearAdd (with constant bias)
+    |   Add (constant bias + zero point correction)
     |       |
-    |   DequantizeLinear
+    |   Cast (INT32 -> FP32)
+    |       |
+    |   Mul (rescale from bias scale)
     |       |
     |   OUTPUT
     """
@@ -708,6 +710,13 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
         ):
             continue
 
+        output_dequantize_node = graph.get_node_single_child(output_quantize_node)
+        if (
+            not output_dequantize_node
+            or output_dequantize_node.op_type not in _QUANTIZE_OP_NAMES
+        ):
+            continue
+
         input_quantize_params = get_quantization_params(
             model, input_quantize_node, include_target=False
         )
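The new guard requires that the pattern's trailing QuantizeLinear feed exactly one DequantizeLinear (both op types live in _QUANTIZE_OP_NAMES). If ONNXGraph.get_node_single_child is unfamiliar, a rough stand-in shows the intent (a hypothetical helper, not the real API):

    from typing import Optional
    from onnx import ModelProto, NodeProto

    def single_child(model: ModelProto, node: NodeProto) -> Optional[NodeProto]:
        # the unique node consuming any of `node`'s outputs, else None
        produced = set(node.output)
        consumers = [n for n in model.graph.node if produced & set(n.input)]
        return consumers[0] if len(consumers) == 1 else None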
@@ -743,37 +752,35 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
         )
         model.graph.initializer.append(quantized_weight_initializer)
 
-        # QLinearMatMul
-        # get qmatmul inputs and outputs
-        qmatmul_input = input_quantize_node.input[0]
-        qmatmul_inputs = [
-            qmatmul_input,  # x
-            input_quantize_node.input[1],  # x_scale
-            input_quantize_node.input[2],  # x_zero_point
-            quantized_weight_name,  # w
-            weight_quantize_node.input[1],  # w_scale
-            weight_quantize_node.input[2],  # w_zero_point
-            output_quantize_node.input[1],  # y_scale
-            output_quantize_node.input[2],  # y_zero_point
+        # MatMulInteger
+        # get MatMulInteger inputs and outputs
+        matmul_integer_inputs = [
+            input_quantize_node.input[0],  # A matrix (replaces previous dequant node)
+            quantized_weight_name,  # B matrix (quantized weight)
+            input_quantize_node.input[2],  # a_zero_point
+            weight_quantize_node.input[2],  # b_zero_point
         ]
-        qmatmul_output = matmul_node.output[0]
-        qmatmul_name = "{}_quant".format(matmul_node.name)
+        matmul_integer_output = matmul_node.output[0]
+        matmul_integer_name = "{}_quant".format(matmul_node.name)
 
-        # create qmatmul node and add it to graph
-        qmatmul_node = onnx.helper.make_node(
-            "QLinearMatMul",
-            qmatmul_inputs,
-            [qmatmul_output],
-            qmatmul_name,
+        # create MatMulInteger node and add it to graph
+        matmul_integer_node = onnx.helper.make_node(
+            "MatMulInteger",
+            matmul_integer_inputs,
+            [matmul_integer_output],
+            matmul_integer_name,
         )
-        model.graph.node.append(qmatmul_node)
+        model.graph.node.append(matmul_integer_node)
 
-        # QLinearAdd
+        # Add bias + zero point correction
         # quantize bias
         bias_initializer = numpy_helper.to_array(bias_initializer)
         bias_scale = input_quantize_params.scale * weight_quantize_params.scale
         bias_zero_point = 0
-        quantized_bias = _quantize_array(bias_initializer, bias_scale, bias_zero_point)
+        quantized_bias = _quantize_array(
+            bias_initializer, bias_scale, bias_zero_point, dtype=numpy.int32
+        )
+
         quantized_bias_name = "{}.bias_quantized".format(bias_add_node.name)
         quantized_bias_initializer = numpy_helper.from_array(
             quantized_bias, name=quantized_bias_name
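_quantize_array is defined elsewhere in this module; the new dtype=numpy.int32 argument matters because an 8-bit bias would saturate long before the INT32 accumulator does. A sketch of the usual behavior (assumed, not the verbatim helper):

    import numpy

    def quantize_array_sketch(array, scale, zero_point, dtype=numpy.uint8):
        # q = round(x / scale) + zero_point, saturated to the dtype's range
        info = numpy.iinfo(dtype)
        quantized = numpy.round(array / scale) + zero_point
        return numpy.clip(quantized, info.min, info.max).astype(dtype)

With zero_point = 0 and dtype=numpy.int32 this reduces to round(bias / (input_scale * weight_scale)), which is exactly the representation the INT32 Add expects.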
@@ -793,38 +800,57 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
             )
         )
 
-        # get qadd inputs and outputs
-        qadd_input = qmatmul_output
-        qadd_inputs = [
-            qadd_input,  # x
-            output_quantize_node.input[1],  # x_scale
-            output_quantize_node.input[2],  # x_zero_point
-            quantized_bias_name,  # b
-            quantized_bias_scale_name,  # b_scale
-            quantized_bias_zero_point_name,  # b_zero_point
-            output_quantize_node.input[1],  # y_scale
-            output_quantize_node.input[2],  # y_zero_point
+        # get INT32 Add inputs and outputs
+        quant_add_inputs = [
+            matmul_integer_output,  # MatMulInteger output (INT32)
+            quantized_bias_name,  # quantized bias (INT32)
         ]
-        qadd_output = output_quantize_node.output[0]
-        qadd_name = "{}_quant".format(bias_add_node.name)
-        kwargs = {"domain": "com.microsoft"}
-        # create qlinearadd node and add it to graph
+        quant_add_output = output_quantize_node.output[0]
+        quant_add_name = "{}_quant".format(bias_add_node.name)
+
+        # create Add node and add it to graph
         qadd_node = onnx.helper.make_node(
-            "QLinearAdd",
-            qadd_inputs,
-            [qadd_output],
-            qadd_name,
-            **kwargs,
+            "Add",
+            quant_add_inputs,
+            [quant_add_output],
+            quant_add_name,
         )
         model.graph.node.append(qadd_node)
 
+        # create Cast node and add it to graph
+        cast_node_name = "{}_cast".format(bias_add_node.name)
+        cast_node_output = "{}_cast".format(quant_add_output)
+        cast_node = onnx.helper.make_node(
+            "Cast",
+            [quant_add_output],
+            [cast_node_output],
+            cast_node_name,
+            to=getattr(onnx.TensorProto, "FLOAT"),  # FLOAT (FP32) enum id
+        )
+        model.graph.node.append(cast_node)
+
+        # create Mul node for rescale
+        mul_node_inputs = [
+            cast_node_output,  # a
+            quantized_bias_scale_name,  # b -> rescale factor
+        ]
+        mul_node_name = "{}_rescale_mul".format(bias_add_node.name)
+        mul_node = onnx.helper.make_node(
+            "Mul",
+            mul_node_inputs,
+            [output_dequantize_node.output[0]],
+            mul_node_name,
+        )
+        model.graph.node.append(mul_node)
+
         # Cleanup
         # delete folded quantization ops
         delete_quant_node(model, weight_dequantize_node, keep_params=False)
         delete_quant_node(model, weight_quantize_node, keep_params=True)
         remove_node_and_params_from_graph(model, weight_transpose_node)
         delete_quant_node(model, input_quantize_node, keep_params=True)
         delete_quant_node(model, output_quantize_node, keep_params=True)
+        delete_quant_node(model, output_dequantize_node, keep_params=True)
 
         # delete original Gemm node
         remove_node_and_params_from_graph(model, matmul_node, keep_params=None)
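The resulting four-op chain can be reproduced standalone to confirm it forms a valid ONNX graph; the shapes, tensor names, and opset below are illustrative only:

    import numpy as np
    import onnx
    from onnx import TensorProto, helper, numpy_helper

    a_q = helper.make_tensor_value_info("a_q", TensorProto.UINT8, [4, 8])
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [4, 16])
    inits = [
        numpy_helper.from_array(np.zeros((8, 16), dtype=np.uint8), "w_q"),
        numpy_helper.from_array(np.array(128, dtype=np.uint8), "a_zp"),
        numpy_helper.from_array(np.array(128, dtype=np.uint8), "w_zp"),
        numpy_helper.from_array(np.zeros(16, dtype=np.int32), "bias_q"),
        numpy_helper.from_array(np.array(1e-3, dtype=np.float32), "rescale"),
    ]
    nodes = [
        helper.make_node("MatMulInteger", ["a_q", "w_q", "a_zp", "w_zp"], ["acc"]),
        helper.make_node("Add", ["acc", "bias_q"], ["acc_biased"]),
        helper.make_node("Cast", ["acc_biased"], ["acc_f"], to=TensorProto.FLOAT),
        helper.make_node("Mul", ["acc_f", "rescale"], ["out"]),
    ]
    graph = helper.make_graph(nodes, "demo", [a_q], [out], initializer=inits)
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    onnx.checker.check_model(model)

Unlike QLinearAdd, every op here lives in the default ONNX domain, which is why the com.microsoft domain kwargs are no longer needed.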
@@ -838,6 +864,8 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
838864 f"Converted { conversion_count } quantizable MatMul ops with weight and bias "
839865 "to QLinearMatMul and QLinearAdd"
840866 )
867+ graph = ONNXGraph (model )
868+ graph .delete_unused_initializers ()
841869
842870
843871def _convert_quantizable_ops (model : ModelProto ):
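The final delete_unused_initializers pass drops tensors orphaned by the deleted QuantizeLinear/DequantizeLinear nodes. Roughly (a sketch of the intent; the real ONNXGraph method may handle more cases):

    from onnx import ModelProto

    def delete_unused_initializers_sketch(model: ModelProto):
        # drop initializers that no remaining node reads as an input
        used = {name for node in model.graph.node for name in node.input}
        for init in [i for i in model.graph.initializer if i.name not in used]:
            model.graph.initializer.remove(init)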