
Commit 63c08ac

address reviews

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>

1 parent: d492543

3 files changed: +23 −26 lines

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 16 additions & 0 deletions
@@ -85,6 +85,7 @@ def compress(
         """
         uncompressed_names = list(model_state.keys())
         compressed_dict = {}
+        compressed_prefixes = set()
 
         # compress values
         desc = "Compressing with quantization"
@@ -119,11 +120,26 @@ def compress(
                     device=compression_device,
                 )
 
+                compressed_prefixes.add(prefix)
+
                 # update state dict
                 for key, value in compressed_values.items():
                     compressed_dict[prefix + key] = value.to(compression_device)
 
             else:
+                # Skip qparams already added by compress_weight
+                is_duplicate = any(
+                    name.endswith(s) and name.removesuffix(s) in compressed_prefixes
+                    for s in [
+                        "weight_scale",
+                        "weight_zero_point",
+                        "weight_global_scale",
+                        "weight_g_idx",
+                    ]
+                )
+                if is_duplicate:
+                    continue
+
                 # omit saving zero points for symmetric quantization
                 if name.endswith("weight_zero_point"):
                     module_path = name.rsplit(".", 1)[0]
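The new `else` branch needs to recognize qparam tensors whose owning weight was already handled by `compress_weight`. A minimal standalone sketch of that suffix check, with a made-up module name (note `str.removesuffix` requires Python 3.9+):

```python
# Sketch of the duplicate-qparam check above; the prefix set would be
# populated as weights are compressed. Names here are illustrative only.
compressed_prefixes = {"model.layers.0.self_attn.q_proj."}

QPARAM_SUFFIXES = [
    "weight_scale",
    "weight_zero_point",
    "weight_global_scale",
    "weight_g_idx",
]

def is_duplicate_qparam(name: str) -> bool:
    """True if `name` is a qparam whose weight was already compressed."""
    return any(
        name.endswith(s) and name.removesuffix(s) in compressed_prefixes
        for s in QPARAM_SUFFIXES
    )

assert is_duplicate_qparam("model.layers.0.self_attn.q_proj.weight_scale")
assert not is_duplicate_qparam("model.layers.1.mlp.down_proj.weight_scale")
```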

src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py

Lines changed: 2 additions & 22 deletions
@@ -21,9 +21,8 @@
     BaseQuantizationCompressor,
 )
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization import QuantizationArgs
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
-from compressed_tensors.quantization.utils import calculate_qparam_shape
 from torch import Tensor
 
 
@@ -73,32 +72,13 @@ def compression_param_info(
         :param quantization_args: quantization parameters for the weight
         :return: dictionary mapping compressed parameter names to shape and dtype
         """
-        output = {
+        return {
            "weight_packed": (
                 torch.Size((weight_shape[0], weight_shape[1] // 2)),
                 torch.uint8,
             ),
         }
 
-        # Add weight_scale and weight_global_scale for NVFP4/MXFP4
-        if quantization_args is not None and quantization_args.strategy in [
-            QuantizationStrategy.GROUP.value,
-            QuantizationStrategy.TENSOR_GROUP.value,
-        ]:
-            # Use centralized calculation for consistency and correctness
-            num_groups, scale_shape = calculate_qparam_shape(
-                weight_shape, quantization_args
-            )
-            output["weight_scale"] = (scale_shape, quantization_args.scale_dtype)
-
-            if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP.value:
-                output["weight_global_scale"] = (
-                    torch.Size((1,)),
-                    torch.float32,
-                )
-
-        return output
-
     def compress_scale(
         self,
         scale: Tensor,
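With the qparam entries removed, `compression_param_info` declares only the packed weight. FP4 stores two 4-bit values per `uint8`, so the packed tensor halves the last dimension; a quick sanity check of that arithmetic (the weight shape below is arbitrary):

```python
import torch

# FP4 packs two 4-bit values into each uint8, so the packed weight
# keeps the row count and halves the column count. Shape is arbitrary.
weight_shape = torch.Size((4096, 11008))

packed_shape = torch.Size((weight_shape[0], weight_shape[1] // 2))
assert packed_shape == torch.Size((4096, 5504))

# Each packed byte carries two quantized values:
assert packed_shape[1] * 2 == weight_shape[1]
```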

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 5 additions & 4 deletions
@@ -35,7 +35,7 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import strategy_cdiv
+from compressed_tensors.quantization.utils import calculate_qparam_shape, strategy_cdiv
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -215,9 +215,10 @@ def initialize_qparams(
        if len(observed_shape) < 1:
            raise ValueError("Group quant requires at least 1 observed dimension")

-       group_size = quantization_args.group_size
-       num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
-       expected_shape = (*observed_shape[:-1], num_groups)
+       # Use shared calculation to avoid floor division bugs
+       _, expected_shape = calculate_qparam_shape(
+           torch.Size(observed_shape), quantization_args
+       )

        # initialize activation ordering if applicable
        if actorder == ActivationOrdering.GROUP:
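Here the inline group-count computation gives way to `calculate_qparam_shape`, which returns both the group count and the full scale shape. Per the added comment, the point is ceiling rather than floor division, so a trailing partial group still gets its own scale. A hedged sketch of what such a helper plausibly computes — the real function in `compressed_tensors.quantization.utils` takes `quantization_args` rather than a bare group size and may differ in detail:

```python
import math
import torch

def calculate_qparam_shape_sketch(weight_shape: torch.Size, group_size: int):
    """Hypothetical stand-in for calculate_qparam_shape: ceil-divide the
    last dimension by group_size so a partial trailing group still gets
    a scale (plain floor division would silently drop it)."""
    num_groups = math.ceil(weight_shape[-1] / group_size)
    scale_shape = torch.Size((*weight_shape[:-1], num_groups))
    return num_groups, scale_shape

# 4096 columns split evenly into 256 groups of 16...
assert calculate_qparam_shape_sketch(torch.Size((64, 4096)), 16)[0] == 256
# ...but 4100 columns need 257 groups; floor division would give 256.
assert calculate_qparam_shape_sketch(torch.Size((64, 4100)), 16)[0] == 257
```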
