Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 4b655d5

Browse files
authored
Fix editing a list inside a loop (#1339) (#1361)
* Fix editing a list inside a loop
* Simplified adding new initializers
1 parent 5788f12 commit 4b655d5

File tree

1 file changed

+9
-6
lines changed

src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,14 +253,17 @@ def _cast_init_int8_to_uint8(int8_init):
253253
arr_uint8 = (arr_int8.astype(numpy.int32) + 128).astype(numpy.uint8)
254254
return numpy_helper.from_array(arr_uint8, name=int8_init.name)
255255

256-
def _replace_initializer(init_old, init_new):
257-
model.graph.initializer.remove(init_old)
258-
model.graph.initializer.append(init_new)
259-
256+
to_append = []
257+
to_remove = []
260258
for init in model.graph.initializer:
261259
if init.data_type == 3: # int8 dtype
262260
init_uint8 = _cast_init_int8_to_uint8(init)
263-
_replace_initializer(init, init_uint8)
261+
to_append.append(init_uint8)
262+
to_remove.append(init)
263+
264+
for init in to_remove:
265+
model.graph.initializer.remove(init)
266+
model.graph.initializer.extend(to_append)
264267

265268

266269
def _delete_repeated_qat_blocks(model: ModelProto):
@@ -1566,7 +1569,6 @@ def quantize_torch_qat_export(
15661569
model = deepcopy(model)
15671570

15681571
_convert_single_constants_to_initializers(model)
1569-
_convert_signed_to_unsigned(model)
15701572
_fold_qat_conv_bns(model)
15711573
_delete_repeated_qat_blocks(model)
15721574
_quantize_qat_embedding(model)
@@ -1583,6 +1585,7 @@ def quantize_torch_qat_export(
15831585
_convert_quantizable_gemm_no_activations(model)
15841586
quantize_resnet_identity_add_inputs(model)
15851587
_remove_duplicate_quantize_ops(model)
1588+
_convert_signed_to_unsigned(model)
15861589

15871590
graph = ONNXGraph(model)
15881591
graph.sort_nodes_topologically()

0 commit comments

Comments (0)