
Commit 108250e

MatMulInteger + Add for quantized attention blocks (#357)
* MatMulInteger + Add for quantized attention blocks
* bias correction term
* remove bias correction term
* add Mul op for rescaling bias cast output in FP32
1 parent e233367 commit 108250e


src/sparseml/pytorch/utils/quantization/quantize_qat_export.py

Lines changed: 73 additions & 45 deletions
@@ -662,11 +662,13 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
     |         |
     |     QuantizeLinear
     |         |
-    |     QLinearMatMul (with constant kernel)
+    |     MatMulInteger (with constant uint8 kernel)
     |         |
-    |     QLinearAdd (with constant bias)
+    |     Add (constant bias + zero point correction)
     |         |
-    |     DequantizeLinear
+    |     Cast (INT32 -> FP32)
+    |         |
+    |     Mul (Rescale from bias scale)
     |         |
     |       OUTPUT
     """
@@ -708,6 +710,13 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
         ):
             continue
 
+        output_dequantize_node = graph.get_node_single_child(output_quantize_node)
+        if (
+            not output_dequantize_node
+            or output_dequantize_node.op_type not in _QUANTIZE_OP_NAMES
+        ):
+            continue
+
         input_quantize_params = get_quantization_params(
             model, input_quantize_node, include_target=False
         )
@@ -743,37 +752,35 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
         )
         model.graph.initializer.append(quantized_weight_initializer)
 
-        # QLinearMatMul
-        # get qmatmul inputs and outputs
-        qmatmul_input = input_quantize_node.input[0]
-        qmatmul_inputs = [
-            qmatmul_input,  # x
-            input_quantize_node.input[1],  # x_scale
-            input_quantize_node.input[2],  # x_zero_point
-            quantized_weight_name,  # w
-            weight_quantize_node.input[1],  # w_scale
-            weight_quantize_node.input[2],  # w_zero_point
-            output_quantize_node.input[1],  # y_scale
-            output_quantize_node.input[2],  # y_zero_point
+        # MatMulInteger
+        # get matmulinteger inputs and outputs
+        matmul_integer_inputs = [
+            input_quantize_node.input[0],  # A matrix (replaces previous dequant node)
+            quantized_weight_name,  # B matrix (quantized weight)
+            input_quantize_node.input[2],  # a_zero_point
+            weight_quantize_node.input[2],  # b_zero_point
         ]
-        qmatmul_output = matmul_node.output[0]
-        qmatmul_name = "{}_quant".format(matmul_node.name)
+        matmul_integer_output = matmul_node.output[0]
+        matmul_integer_name = "{}_quant".format(matmul_node.name)
 
         # create qmatmul node and add it to graph
-        qmatmul_node = onnx.helper.make_node(
-            "QLinearMatMul",
-            qmatmul_inputs,
-            [qmatmul_output],
-            qmatmul_name,
+        matmul_integer_node = onnx.helper.make_node(
+            "MatMulInteger",
+            matmul_integer_inputs,
+            [matmul_integer_output],
+            matmul_integer_name,
         )
-        model.graph.node.append(qmatmul_node)
+        model.graph.node.append(matmul_integer_node)
 
-        # QLinearAdd
+        # Add bias + zero point correction
         # quantize bias
         bias_initializer = numpy_helper.to_array(bias_initializer)
         bias_scale = input_quantize_params.scale * weight_quantize_params.scale
         bias_zero_point = 0
-        quantized_bias = _quantize_array(bias_initializer, bias_scale, bias_zero_point)
+        quantized_bias = _quantize_array(
+            bias_initializer, bias_scale, bias_zero_point, dtype=numpy.int32
+        )
+
         quantized_bias_name = "{}.bias_quantized".format(bias_add_node.name)
         quantized_bias_initializer = numpy_helper.from_array(
             quantized_bias, name=quantized_bias_name
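
The bias is now quantized into the INT32 accumulator domain of MatMulInteger, whose effective output scale is input_scale * weight_scale. A rough sketch of what the dtype=numpy.int32 path could amount to; only the call signature appears in the diff, so this body is an assumption:

import numpy

def quantize_array_int32(array, scale, zero_point):
    # Assumed behaviour: divide the FP32 bias by (input_scale * weight_scale)
    # and round to INT32 so it can be added directly to the MatMulInteger output.
    return (numpy.round(array / scale) + zero_point).astype(numpy.int32)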
@@ -793,38 +800,57 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
             )
         )
 
-        # get qadd inputs and outputs
-        qadd_input = qmatmul_output
-        qadd_inputs = [
-            qadd_input,  # x
-            output_quantize_node.input[1],  # x_scale
-            output_quantize_node.input[2],  # x_zero_point
-            quantized_bias_name,  # b
-            quantized_bias_scale_name,  # b_scale
-            quantized_bias_zero_point_name,  # b_zero_point
-            output_quantize_node.input[1],  # y_scale
-            output_quantize_node.input[2],  # y_zero_point
+        # get INT32 Add inputs and outputs
+        quant_add_inputs = [
+            matmul_integer_output,  # MatMul integer outputs (INT32)
+            quantized_bias_name,  # Quantized bias (INT32)
         ]
-        qadd_output = output_quantize_node.output[0]
-        qadd_name = "{}_quant".format(bias_add_node.name)
-        kwargs = {"domain": "com.microsoft"}
-        # create qlinearadd node and add it to graph
+        quant_add_output = output_quantize_node.output[0]
+        quant_add_name = "{}_quant".format(bias_add_node.name)
+
+        # create Add node and add it to graph
         qadd_node = onnx.helper.make_node(
-            "QLinearAdd",
-            qadd_inputs,
-            [qadd_output],
-            qadd_name,
-            **kwargs,
+            "Add",
+            quant_add_inputs,
+            [quant_add_output],
+            quant_add_name,
         )
         model.graph.node.append(qadd_node)
 
+        # create Cast node and add it to graph
+        cast_node_name = "{}_cast".format(bias_add_node.name)
+        cast_node_output = "{}_cast".format(quant_add_output)
+        cast_node = onnx.helper.make_node(
+            "Cast",
+            [quant_add_output],
+            [cast_node_output],
+            cast_node_name,
+            to=getattr(onnx.TensorProto, "FLOAT"),  # get Float32 enum id
+        )
+        model.graph.node.append(cast_node)
+
+        # create Mul node for rescale
+        mul_node_inputs = [
+            cast_node_output,  # a
+            quantized_bias_scale_name,  # b -> rescale factor
+        ]
+        mul_node_name = "{}_rescale_mul".format(bias_add_node.name)
+        mul_node = onnx.helper.make_node(
+            "Mul",
+            mul_node_inputs,
+            [output_dequantize_node.output[0]],
+            mul_node_name,
+        )
+        model.graph.node.append(mul_node)
+
         # Cleanup
         # delete folded quantization ops
         delete_quant_node(model, weight_dequantize_node, keep_params=False)
         delete_quant_node(model, weight_quantize_node, keep_params=True)
         remove_node_and_params_from_graph(model, weight_transpose_node)
         delete_quant_node(model, input_quantize_node, keep_params=True)
         delete_quant_node(model, output_quantize_node, keep_params=True)
+        delete_quant_node(model, output_dequantize_node, keep_params=True)
 
         # delete original Gemm node
         remove_node_and_params_from_graph(model, matmul_node, keep_params=None)
@@ -838,6 +864,8 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
         f"Converted {conversion_count} quantizable MatMul ops with weight and bias "
         "to QLinearMatMul and QLinearAdd"
     )
+    graph = ONNXGraph(model)
+    graph.delete_unused_initializers()
 
 
 def _convert_quantizable_ops(model: ModelProto):
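
Not part of the commit, but a quick way to sanity-check a converted model is to load and run it with onnxruntime; the file name and input feed below are placeholders:

import numpy
import onnx
import onnxruntime

model = onnx.load("converted_model.onnx")  # placeholder path to a converted model
onnx.checker.check_model(model)
session = onnxruntime.InferenceSession(model.SerializeToString())
input_name = session.get_inputs()[0].name
feed = {input_name: numpy.random.rand(1, 128, 768).astype(numpy.float32)}  # example shape
outputs = session.run(None, feed)
print(outputs[0].dtype)  # expected float32 after the Cast + Mul rescale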
