Skip to content

Commit f9f3105

Browse files
committed
fix compressed params tracking
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent 63c08ac commit f9f3105

File tree

1 file changed

+6
-15
lines changed
  • src/compressed_tensors/compressors/quantized_compressors

1 file changed

+6
-15
lines changed

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def compress(
8585
"""
8686
uncompressed_names = list(model_state.keys())
8787
compressed_dict = {}
88-
compressed_prefixes = set()
88+
compressed_param_names = set()
8989

9090
# compress values
9191
desc = "Compressing with quantization"
@@ -120,24 +120,15 @@ def compress(
120120
device=compression_device,
121121
)
122122

123-
compressed_prefixes.add(prefix)
124-
125-
# update state dict
123+
# update state dict and track which params were added
126124
for key, value in compressed_values.items():
127-
compressed_dict[prefix + key] = value.to(compression_device)
125+
full_name = prefix + key
126+
compressed_dict[full_name] = value.to(compression_device)
127+
compressed_param_names.add(full_name)
128128

129129
else:
130130
# Skip qparams already added by compress_weight
131-
is_duplicate = any(
132-
name.endswith(s) and name.removesuffix(s) in compressed_prefixes
133-
for s in [
134-
"weight_scale",
135-
"weight_zero_point",
136-
"weight_global_scale",
137-
"weight_g_idx",
138-
]
139-
)
140-
if is_duplicate:
131+
if name in compressed_param_names:
141132
continue
142133

143134
# omit saving zero points for symmetric quantization

0 commit comments

Comments
 (0)