@@ -253,14 +253,17 @@ def _cast_init_int8_to_uint8(int8_init):
253253 arr_uint8 = (arr_int8 .astype (numpy .int32 ) + 128 ).astype (numpy .uint8 )
254254 return numpy_helper .from_array (arr_uint8 , name = int8_init .name )
255255
256- def _replace_initializer (init_old , init_new ):
257- model .graph .initializer .remove (init_old )
258- model .graph .initializer .append (init_new )
259-
256+ to_append = []
257+ to_remove = []
260258 for init in model .graph .initializer :
261259 if init .data_type == 3 : # int8 dtype
262260 init_uint8 = _cast_init_int8_to_uint8 (init )
263- _replace_initializer (init , init_uint8 )
261+ to_append .append (init_uint8 )
262+ to_remove .append (init )
263+
264+ for init in to_remove :
265+ model .graph .initializer .remove (init )
266+ model .graph .initializer .extend (to_append )
264267
265268
266269def _delete_repeated_qat_blocks (model : ModelProto ):
@@ -1566,7 +1569,6 @@ def quantize_torch_qat_export(
15661569 model = deepcopy (model )
15671570
15681571 _convert_single_constants_to_initializers (model )
1569- _convert_signed_to_unsigned (model )
15701572 _fold_qat_conv_bns (model )
15711573 _delete_repeated_qat_blocks (model )
15721574 _quantize_qat_embedding (model )
@@ -1583,6 +1585,7 @@ def quantize_torch_qat_export(
15831585 _convert_quantizable_gemm_no_activations (model )
15841586 quantize_resnet_identity_add_inputs (model )
15851587 _remove_duplicate_quantize_ops (model )
1588+ _convert_signed_to_unsigned (model )
15861589
15871590 graph = ONNXGraph (model )
15881591 graph .sort_nodes_topologically ()
0 commit comments