Commit f9e7426: fix qparams decompression (#514)
Squashed commits (each signed off by shanjiaz <zsjwpianpian@gmail.com>):

* fix qparams decompression
* quality
* quality
* Add zero-point compression for asymmetric quantization
* Add scale decompression support
* fix tests
* cleanup
* minimal diff
* quality
* remove script
* quality
* minimum diff
* added TODO
* address reviews
* fix compressed params tracking
* use helper in initialize
* quality
* addressed reviews
* minimum diff
* Address some comments

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent b019b89 commit f9e7426

File tree (4 files changed, +106 / -50 lines):
src/compressed_tensors/compressors/base.py

Lines changed: 33 additions & 1 deletion
@@ -20,6 +20,11 @@
 from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
 from compressed_tensors.registry import RegistryMixin
 from compressed_tensors.utils import has_offloaded_params
+from compressed_tensors.utils.offload import (
+    delete_offload_parameter,
+    get_offloaded_device,
+    register_offload_parameter,
+)
 from torch import Tensor
 from torch.nn import Module
 
@@ -185,10 +190,37 @@ def decompress_module(self, module: Module):
         for name, parameter in module.named_parameters():
             compressed_data[name] = parameter
 
-        return self.decompress_weight(
+        # Save references to original parameters before decompression
+        original_scale = compressed_data.get("weight_scale")
+        original_zp = compressed_data.get("weight_zero_point")
+
+        # NOTE: decompress_weight may modify compressed_data dict in-place
+        # This is subtle but allows us to update the module's qparams with
+        # the unpacked values.
+        # TODO: Consider refactoring to return modified qparams explicitly
+        result = self.decompress_weight(
             compressed_data=compressed_data, quantization_args=quantization_args
         ).to(device)
 
+        # Update module's parameters only if they were modified
+        for param_name, original_param in [
+            ("weight_scale", original_scale),
+            ("weight_zero_point", original_zp),
+        ]:
+            if (
+                param_name in compressed_data
+                and compressed_data[param_name] is not original_param
+            ):
+                # Delete the old parameter and register the updated one
+                delete_offload_parameter(module, param_name)
+                offload_device = get_offloaded_device(module)
+                param = torch.nn.Parameter(
+                    compressed_data[param_name], requires_grad=False
+                )
+                register_offload_parameter(module, param_name, param, offload_device)
+
+        return result
+
     def decompress_weight(
         self, compressed_data: Dict[str, Tensor], **kwargs
     ) -> torch.Tensor:
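
The re-registration above hinges on object identity, not value equality: decompress_weight may swap entries in the shared compressed_data dict (for example, replacing a packed zero point with its unpacked form), and only swapped entries are re-registered on the module. Below is a minimal standalone sketch of that pattern; fake_unpack and the dict contents are illustrative stand-ins, not library code:

```python
import torch

def fake_unpack(qparams: dict) -> None:
    # Hypothetical stand-in for decompress_weight: it may replace entries in
    # the shared dict (e.g. after unpacking a packed zero point).
    qparams["weight_zero_point"] = qparams["weight_zero_point"].to(torch.int8)

qparams = {
    "weight_scale": torch.ones(4, 1),
    "weight_zero_point": torch.zeros(4, 1, dtype=torch.int32),
}
originals = dict(qparams)

fake_unpack(qparams)

# Only entries whose object identity changed need to be re-registered on the
# module; unchanged parameters (weight_scale here) are left alone.
for name, original in originals.items():
    if qparams[name] is not original:
        print(f"{name} was replaced and would be re-registered")
```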

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 24 additions & 39 deletions
@@ -18,7 +18,7 @@
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
+from compressed_tensors.quantization import QuantizationScheme
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
@@ -85,6 +85,7 @@ def compress(
         """
         uncompressed_names = list(model_state.keys())
         compressed_dict = {}
+        compressed_param_names = set()
 
         # compress values
         desc = "Compressing with quantization"
@@ -119,54 +120,38 @@
                     device=compression_device,
                 )
 
-                # update state dict
+                # update state dict and track which params were added
                 for key, value in compressed_values.items():
-                    compressed_dict[prefix + key] = value.to(compression_device)
+                    full_name = prefix + key
+                    compressed_dict[full_name] = value.to(compression_device)
+                    compressed_param_names.add(full_name)
 
             else:
-                # omit saving zero points for symmetric or packed quantization
-                if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
+                # Skip qparams already added by compress_weight
+                if name in compressed_param_names:
                     continue
 
-                if name.endswith("weight_scale") and self._skip_scale():
-                    continue
+                # for symmetric quantization, omit zero_point
+                # manually because it wasn't handled in compress_weight
+                if name.endswith("weight_zero_point"):
+                    module_path = name.rsplit(".", 1)[0]
+                    if (
+                        module_path in names_to_scheme
+                        and names_to_scheme[module_path].weights.symmetric
+                    ):
+                        continue
+                    # Call compress_zp if available (for PackedQuantizationCompressor)
+                    if module_path in names_to_scheme and hasattr(self, "compress_zp"):
+                        value = self.compress_zp(
                            value, names_to_scheme[module_path].weights
+                        )
+                        if value is None:
+                            continue
 
                 compressed_dict[name] = value.to(compression_device)
 
         return compressed_dict
 
-    def _skip_scale(self):
-        from compressed_tensors.compressors import NVFP4PackedCompressor
-
-        return isinstance(self, NVFP4PackedCompressor)
-
-    def _skip_zp(
-        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
-    ) -> bool:
-        from compressed_tensors.compressors import PackedQuantizationCompressor
-
-        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
-        scheme = names_to_scheme[module_name]
-
-        if zp_name == "weight_zero_point":
-            args = scheme.weights
-        if zp_name == "input_zero_point":
-            args = scheme.input_activations
-        if zp_name == "output_zero_point":
-            args = scheme.output_activations
-
-        symmetric = args.symmetric
-        packable_strategies = [
-            QuantizationStrategy.GROUP.value,
-            QuantizationStrategy.CHANNEL.value,
-        ]
-        packed = (
-            isinstance(self, PackedQuantizationCompressor)
-            and args.strategy in packable_strategies
-        )
-
-        return symmetric or packed
-
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
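
The compressed_param_names set is what keeps a qparam from being written twice: anything compress_weight already emitted under the module prefix is skipped when the loop later reaches the original state-dict entry with the same name. A standalone sketch of that bookkeeping, with made-up tensor values and a stand-in for compress_weight:

```python
import torch

model_state = {
    "layer.weight": torch.randn(8, 8),
    "layer.weight_scale": torch.ones(8, 1),
    "layer.weight_zero_point": torch.zeros(8, 1, dtype=torch.int32),
}

compressed_dict = {}
compressed_param_names = set()

for name, value in model_state.items():
    if name.endswith("weight"):
        prefix = name[: -len("weight")]
        # Stand-in for compress_weight: it may emit qparams alongside the
        # packed weight (here it re-emits the scale).
        compressed_values = {
            "weight_packed": value.to(torch.int8),
            "weight_scale": model_state[prefix + "weight_scale"],
        }
        for key, val in compressed_values.items():
            full_name = prefix + key
            compressed_dict[full_name] = val
            compressed_param_names.add(full_name)
    else:
        # Skip qparams already added by the compress_weight stand-in above
        if name in compressed_param_names:
            continue
        compressed_dict[name] = value

print(sorted(compressed_dict))  # 'layer.weight_scale' appears exactly once
```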

src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py

Lines changed: 14 additions & 3 deletions
@@ -56,7 +56,6 @@ def compression_param_names(self) -> Tuple[str]:
         return (
             "weight_packed",
             "weight_scale",
-            "weight_zero_point",
             "weight_global_scale",
         )
 
@@ -73,13 +72,12 @@ def compression_param_info(
         :param quantization_args: quantization parameters for the weight
         :return: dictionary mapping compressed parameter names to shape and dtype
         """
-        output = {
+        return {
             "weight_packed": (
                 torch.Size((weight_shape[0], weight_shape[1] // 2)),
                 torch.uint8,
             ),
         }
-        return output
 
     def compress_scale(
         self,
@@ -114,6 +112,13 @@ def compress_weight(
         compressed_dict["weight_scale"] = self.compress_scale(
             scale=scale, quantization_args=quantization_args
         )
+
+        if global_scale is None:
+            raise ValueError(
+                "NVFP4 quantization requires global_scale (TENSOR_GROUP strategy). "
+                "Use TENSOR_GROUP strategy instead of GROUP for FP4 quantization."
+            )
+
         return compressed_dict
 
     def decompress_weight(
@@ -127,6 +132,12 @@ def decompress_weight(
         m, n = weight.shape
         # TODO: use a user provided dequant dtype
         unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
+
+        # cast scale dtype to match unpacked dtype for dequantization
+        if scale.dtype != unpacked.dtype:
+            scale = scale.to(unpacked.dtype)
+            compressed_data["weight_scale"] = scale
+
         decompressed_weight = dequantize(
             x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
         )
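
Two details are implicit in the FP4 path above: weight_packed stores two 4-bit values per uint8, so its second dimension is half the unpacked width, and the scale is cast to the unpacked dtype before dequantize so the multiply does not mix dtypes. A small sketch of both, using an arbitrary example shape and example dtypes:

```python
import torch

weight_shape = torch.Size((4096, 4096))  # example shape only

# Two FP4 values fit in one uint8 byte, so the packed width is halved.
packed_shape = torch.Size((weight_shape[0], weight_shape[1] // 2))
assert packed_shape == torch.Size((4096, 2048))

# Before dequantization, cast the scale to the unpacked dtype if they differ.
unpacked = torch.zeros(weight_shape, dtype=torch.bfloat16)
scale = torch.ones(4096, 256, dtype=torch.float32)  # example per-group scale
if scale.dtype != unpacked.dtype:
    scale = scale.to(unpacked.dtype)
assert scale.dtype == unpacked.dtype
```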

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 35 additions & 7 deletions
@@ -64,25 +64,34 @@ def compression_param_info(
         """
         pack_factor = 32 // quantization_args.num_bits
         packed_size = math.ceil(weight_shape[1] / pack_factor)
-        packed_size_zp = math.ceil(weight_shape[0] / pack_factor)
         output = {
             "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
             "weight_shape": (torch.Size((2,)), torch.int32),
         }
-        if not quantization_args.symmetric and quantization_args.strategy in [
+
+        # Add weight_scale - always needed for quantization
+        if quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            zp_factor = (
+            shape_factor = (
                 quantization_args.group_size
                 if quantization_args.strategy == QuantizationStrategy.GROUP.value
                 else weight_shape[-1]
             )
-
-            output["weight_zero_point"] = (
-                torch.Size((packed_size_zp, weight_shape[-1] // zp_factor)),
-                torch.int32,
+            scale_cols = math.ceil(weight_shape[-1] / shape_factor)
+            output["weight_scale"] = (
+                torch.Size((weight_shape[0], scale_cols)),
+                quantization_args.scale_dtype,
             )
+
+            # Add weight_zero_point for asymmetric quantization
+            if not quantization_args.symmetric:
+                output["weight_zero_point"] = (
+                    torch.Size((math.ceil(weight_shape[0] / pack_factor), scale_cols)),
+                    torch.int32,
+                )
+
         return output
 
     def compress_weight(
@@ -175,13 +184,29 @@ def decompress_weight(
             zero_point = unpack_from_int32(
                 zero_point, num_bits, original_zp_shape, packed_dim=0
             )
+            # Update the compressed_data dict with the unpacked zero_point
+            compressed_data["weight_zero_point"] = zero_point
 
         decompressed_weight = dequantize(
             x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
         )
 
         return decompressed_weight
 
+    def compress_zp(
+        self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None
+    ) -> Optional[Tensor]:
+        if zero_point is None or quantization_args.symmetric:
+            return None
+        if zero_point.dtype == torch.int32:
+            return zero_point
+        if quantization_args.strategy in [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]:
+            return pack_to_int32(zero_point, quantization_args.num_bits, packed_dim=0)
+        return zero_point
+
 
 def pack_to_int32(
     value: torch.Tensor,
@@ -226,6 +251,9 @@ def pack_to_int32(
     if packed_dim == 0:
         value = value.transpose(0, 1)
 
+    # Ensure contiguous memory for .view() operation
+    value = value.contiguous()
+
     rows, cols = value.shape
     padded_cols = math.ceil(cols / pack_factor) * pack_factor
     pad_len = padded_cols - cols
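
For concreteness, here is how the new compression_param_info shapes work out for one asymmetric 4-bit, group-size-128 scheme, and why pack_to_int32 needs the added .contiguous() call before .view(). The weight shape, bit width, and group size below are arbitrary example values:

```python
import math
import torch

num_bits, group_size = 4, 128
rows, cols = 4096, 4096                      # example weight shape

pack_factor = 32 // num_bits                 # 8 quantized values per int32
packed_size = math.ceil(cols / pack_factor)  # 512 -> weight_packed: (4096, 512) int32
scale_cols = math.ceil(cols / group_size)    # 32  -> weight_scale:  (4096, 32)
zp_rows = math.ceil(rows / pack_factor)      # 512 -> weight_zero_point: (512, 32) int32
print(packed_size, scale_cols, zp_rows)

# Why .contiguous() matters: transposing yields a non-contiguous view, and
# .view() raises a RuntimeError on it.
zero_point = torch.zeros(rows, scale_cols, dtype=torch.int32).transpose(0, 1)
try:
    zero_point.view(-1)
except RuntimeError:
    zero_point = zero_point.contiguous()     # now .view() is legal
zero_point.view(-1)
```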

0 commit comments
