
Commit 8cfb375

quality

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>

1 parent c6e2d4b · commit 8cfb375

2 files changed: +14 −6 lines changed


src/compressed_tensors/compressors/base.py

Lines changed: 13 additions & 3 deletions
@@ -190,6 +190,10 @@ def decompress_module(self, module: Module):
         for name, parameter in module.named_parameters():
             compressed_data[name] = parameter
 
+        # Save references to original parameters before decompression
+        original_scale = compressed_data.get("weight_scale")
+        original_zp = compressed_data.get("weight_zero_point")
+
         # NOTE: decompress_weight may modify compressed_data dict in-place
         # This is subtle but allows us to update the module's qparams with
         # the unpacked values.
@@ -198,9 +202,15 @@ def decompress_module(self, module: Module):
             compressed_data=compressed_data, quantization_args=quantization_args
         ).to(device)
 
-        # Update module's parameters if they were unpacked/upcast during decompression
-        for param_name in ["weight_zero_point", "weight_scale"]:
-            if param_name in compressed_data and hasattr(module, param_name):
+        # Update module's parameters only if they were actually modified during decompression
+        for param_name, original_param in [
+            ("weight_scale", original_scale),
+            ("weight_zero_point", original_zp),
+        ]:
+            if (
+                param_name in compressed_data
+                and compressed_data[param_name] is not original_param
+            ):
                 # Delete the old parameter and register the updated one
                 delete_offload_parameter(module, param_name)
                 offload_device = get_offloaded_device(module)
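
To make the new `is not` check concrete, here is a minimal, self-contained sketch of the pattern; it is not the repository's actual code. `fake_unpack` is a hypothetical stand-in for `decompress_weight`, and the tensor shapes are arbitrary. A dict entry counts as "modified" only when decompression replaced the object it points to, so an untouched scale no longer triggers a needless delete-and-reregister of the module parameter.

import torch

def fake_unpack(compressed_data: dict) -> None:
    # Hypothetical stand-in for decompress_weight: pretend the packed zero
    # point needed unpacking, so its dict entry is replaced with a new tensor.
    # The scale is left untouched.
    compressed_data["weight_zero_point"] = compressed_data["weight_zero_point"].to(torch.int8)

compressed_data = {
    "weight_scale": torch.ones(4),
    "weight_zero_point": torch.zeros(4, dtype=torch.int32),
}

# Save references before decompression, mirroring the commit
original_scale = compressed_data.get("weight_scale")
original_zp = compressed_data.get("weight_zero_point")

fake_unpack(compressed_data)

for param_name, original_param in [
    ("weight_scale", original_scale),
    ("weight_zero_point", original_zp),
]:
    # `is not` compares object identity: only a replaced tensor is treated as
    # modified, so the untouched scale is skipped.
    if param_name in compressed_data and compressed_data[param_name] is not original_param:
        print(f"{param_name}: replaced during decompression, re-register parameter")
    else:
        print(f"{param_name}: unchanged, leave the module parameter alone")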

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 1 addition & 3 deletions
@@ -215,9 +215,7 @@ def initialize_qparams(
         if len(observed_shape) < 1:
             raise ValueError("Group quant requires at least 1 observed dimension")
 
-        # Use shared calculation to avoid floor division bugs
-        # Note: observed_shape may contain None for dynamic dimensions (e.g., sequence length)
-        # but calculate_qparam_shape only accesses specific indices that are concrete
+        # Use unified helper to calculate expected shape
         _, expected_shape = calculate_qparam_shape(observed_shape, quantization_args)
 
         # initialize activation ordering if applicable
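
The diff does not show `calculate_qparam_shape` itself, so the following is a rough, hypothetical illustration only: for the group strategy, the expected scale/zero-point shape is typically the observed shape with the last (grouped) dimension replaced by the number of groups. The `group_size` argument and the ceil-division choice are assumptions, not the library's implementation.

import math
from typing import Optional, Tuple

def group_qparam_shape(
    observed_shape: Tuple[Optional[int], ...], group_size: int
) -> Tuple[int, ...]:
    # Assumption: grouping happens along the last observed dimension, and
    # leading dimensions are copied through unchanged.
    *leading, last = observed_shape
    assert last is not None, "the grouped dimension must be concrete"
    num_groups = math.ceil(last / group_size)  # assumed ceil division
    return (*leading, num_groups)

# e.g. a (4096, 11008) weight with group_size=128 -> (4096, 86) scale entries
print(group_qparam_shape((4096, 11008), group_size=128))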
