@@ -667,9 +667,9 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
667667 | |
668668 | Add (with constant bias)
669669 | |
670- | QuantizeLinear
670+ | QuantizeLinear (Optional)
671671 | |
672- | DequantizeLinear
672+ | DequantizeLinear (Optional)
673673 | |
674674 | OUTPUT
675675 | We end up converting to:
@@ -718,19 +718,26 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
718718 bias_add_node = graph .get_node_single_child (matmul_node )
719719 if not bias_add_node or bias_add_node .op_type != "Add" :
720720 continue
721+
722+ # Optionally find output QDQ block which will be deleted
721723 output_quantize_node = graph .get_node_single_child (bias_add_node )
722724 if (
723725 not output_quantize_node
724726 or output_quantize_node .op_type not in _QUANTIZE_OP_NAMES
725727 ):
726- continue
728+ output_quantize_node = None
727729
728- output_dequantize_node = graph .get_node_single_child (output_quantize_node )
730+ output_dequantize_node = (
731+ graph .get_node_single_child (output_quantize_node )
732+ if output_quantize_node
733+ else None
734+ )
729735 if (
730736 not output_dequantize_node
731737 or output_dequantize_node .op_type not in _QUANTIZE_OP_NAMES
732738 ):
733- continue
739+ output_quantize_node = None
740+ output_dequantize_node = None
734741
735742 input_quantize_params = get_quantization_params (
736743 model , input_quantize_node , include_target = False
@@ -743,7 +750,7 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
743750 continue
744751 if input_quantize_node .op_type != "DequantizeLinear" :
745752 continue
746- if output_quantize_node .op_type != "QuantizeLinear" :
753+ if output_quantize_node and output_quantize_node .op_type != "QuantizeLinear" :
747754 continue
748755 bias_initializer = get_init_by_name (model , bias_add_node .input [1 ]) or (
749756 get_init_by_name (model , bias_add_node .input [0 ])
@@ -822,8 +829,13 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
822829 matmul_integer_output , # MatMul integer outputs (INT32)
823830 quantized_bias_name , # Quantized bias (INT32)
824831 ]
825- quant_add_output = output_quantize_node . output [ 0 ]
832+
826833 quant_add_name = "{}_quant" .format (bias_add_node .name )
834+ quant_add_output = (
835+ output_quantize_node .output [0 ]
836+ if output_quantize_node
837+ else f"{ quant_add_name } _output"
838+ )
827839
828840 # create Add node and add it to graph
829841 qadd_node = onnx .helper .make_node (
@@ -852,10 +864,15 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
852864 quantized_bias_scale_name , # b -> rescale factor
853865 ]
854866 mul_node_name = "{}_rescale_mul" .format (bias_add_node .name )
867+ mul_node_output = (
868+ output_dequantize_node .output [0 ]
869+ if output_dequantize_node
870+ else bias_add_node .output [0 ]
871+ )
855872 mul_node = onnx .helper .make_node (
856873 "Mul" ,
857874 mul_node_inputs ,
858- [output_dequantize_node . output [ 0 ] ],
875+ [mul_node_output ],
859876 mul_node_name ,
860877 )
861878 model .graph .node .append (mul_node )
@@ -865,9 +882,15 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
865882 delete_quant_node (model , weight_dequantize_node , keep_params = False )
866883 delete_quant_node (model , weight_quantize_node , keep_params = True )
867884 remove_node_and_params_from_graph (model , weight_transpose_node )
868- delete_quant_node (model , input_quantize_node , keep_params = True )
869- delete_quant_node (model , output_quantize_node , keep_params = True )
870- delete_quant_node (model , output_dequantize_node , keep_params = True )
885+
886+ # only delete input node if the matmul is the only child
887+ current_graph = ONNXGraph (model )
888+ if len (current_graph .get_node_children (input_quantize_node )) == 1 :
889+ delete_quant_node (model , input_quantize_node , keep_params = True )
890+ if output_quantize_node :
891+ delete_quant_node (model , output_quantize_node , keep_params = True )
892+ if output_dequantize_node :
893+ delete_quant_node (model , output_dequantize_node , keep_params = True )
871894
872895 # delete original MatMul node
873896 remove_node_and_params_from_graph (model , matmul_node , keep_params = None )
0 commit comments