From 9a1aba1520d0ce01916deb563b7ef28f7d5d2f7a Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 10 Nov 2025 21:21:47 +0000 Subject: [PATCH 01/22] fix qparams decompression Signed-off-by: shanjiaz --- src/compressed_tensors/compressors/base.py | 18 +++++++++++++++++- .../compressors/quantized_compressors/base.py | 12 +----------- .../quantized_compressors/fp4_quantized.py | 6 ++++++ .../quantized_compressors/pack_quantized.py | 2 ++ 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index 3a8a97eb9..d565de7c2 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -20,6 +20,11 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig from compressed_tensors.registry import RegistryMixin from compressed_tensors.utils import has_offloaded_params +from compressed_tensors.utils.offload import ( + delete_offload_parameter, + get_offloaded_device, + register_offload_parameter, +) from torch import Tensor from torch.nn import Module @@ -185,10 +190,21 @@ def decompress_module(self, module: Module): for name, parameter in module.named_parameters(): compressed_data[name] = parameter - return self.decompress_weight( + result = self.decompress_weight( compressed_data=compressed_data, quantization_args=quantization_args ).to(device) + # Update module's parameters if they were unpacked/upcast during decompression + for param_name in ["weight_zero_point", "weight_scale"]: + if param_name in compressed_data and hasattr(module, param_name): + # Delete the old parameter and register the updated one + delete_offload_parameter(module, param_name) + offload_device = get_offloaded_device(module) + param = torch.nn.Parameter(compressed_data[param_name], requires_grad=False) + register_offload_parameter(module, param_name, param, offload_device) + + return result + def decompress_weight( self, compressed_data: Dict[str, Tensor], **kwargs ) -> torch.Tensor: diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 19f6c9c0c..ad69708d2 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -155,17 +155,7 @@ def _skip_zp( if zp_name == "output_zero_point": args = scheme.output_activations - symmetric = args.symmetric - packable_strategies = [ - QuantizationStrategy.GROUP.value, - QuantizationStrategy.CHANNEL.value, - ] - packed = ( - isinstance(self, PackedQuantizationCompressor) - and args.strategy in packable_strategies - ) - - return symmetric or packed + return args.symmetric def decompress( self, diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py index dd3c2a463..9572fc288 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py @@ -127,6 +127,12 @@ def decompress_weight( m, n = weight.shape # TODO: use a user provided dequant dtype unpacked = unpack_fp4_from_uint8(weight, m, n * 2) + + # cast scale dtype to match unpacked dtype for dequantization + if scale.dtype != unpacked.dtype: + scale = scale.to(unpacked.dtype) + compressed_data["weight_scale"] = scale + decompressed_weight = dequantize( x_q=unpacked, scale=scale, 
global_scale=global_scale, dtype=unpacked.dtype ) diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 07da50c7f..c41a38c83 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -175,6 +175,8 @@ def decompress_weight( zero_point = unpack_from_int32( zero_point, num_bits, original_zp_shape, packed_dim=0 ) + # Update the compressed_data dict with the unpacked zero_point + compressed_data["weight_zero_point"] = zero_point decompressed_weight = dequantize( x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx From b8e3716301d42b801cce987e0804a760ab1d9721 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 10 Nov 2025 21:29:28 +0000 Subject: [PATCH 02/22] quality Signed-off-by: shanjiaz --- src/compressed_tensors/compressors/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index d565de7c2..be607f4f0 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -200,7 +200,9 @@ def decompress_module(self, module: Module): # Delete the old parameter and register the updated one delete_offload_parameter(module, param_name) offload_device = get_offloaded_device(module) - param = torch.nn.Parameter(compressed_data[param_name], requires_grad=False) + param = torch.nn.Parameter( + compressed_data[param_name], requires_grad=False + ) register_offload_parameter(module, param_name, param, offload_device) return result From 2cec6a265cdb33230b4a90fa5d25a0feed1370d7 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 10 Nov 2025 21:35:42 +0000 Subject: [PATCH 03/22] quality Signed-off-by: shanjiaz --- .../compressors/quantized_compressors/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index ad69708d2..7a8f201c4 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -18,7 +18,7 @@ import torch from compressed_tensors.compressors.base import BaseCompressor -from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy +from compressed_tensors.quantization import QuantizationScheme from compressed_tensors.utils import ( get_nested_mappings_from_state_dict, get_nested_weight_mappings, @@ -143,8 +143,6 @@ def _skip_scale(self): def _skip_zp( self, name: str, names_to_scheme: Dict[str, QuantizationScheme] ) -> bool: - from compressed_tensors.compressors import PackedQuantizationCompressor - module_name, zp_name = name.rsplit(".", 1) if "." 
in name else ("", name) scheme = names_to_scheme[module_name] From 7473b17130fb0ff4b763fc9d38b99d4d979afcdc Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 18 Nov 2025 19:38:32 +0000 Subject: [PATCH 04/22] Add zero-point compression for asymmetric quantization Signed-off-by: shanjiaz --- .../compressors/quantized_compressors/base.py | 33 +++++++++---------- .../quantized_compressors/pack_quantized.py | 16 +++++++++ 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 7a8f201c4..71913e6f7 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -124,9 +124,21 @@ def compress( compressed_dict[prefix + key] = value.to(compression_device) else: - # omit saving zero points for symmetric or packed quantization - if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme): - continue + # omit saving zero points for symmetric quantization + if name.endswith("weight_zero_point"): + module_path = name.rsplit(".", 1)[0] + if ( + module_path in names_to_scheme + and names_to_scheme[module_path].weights.symmetric + ): + continue + # Call compress_zp if available (for PackedQuantizationCompressor) + if module_path in names_to_scheme and hasattr(self, "compress_zp"): + value = self.compress_zp( + value, names_to_scheme[module_path].weights + ) + if value is None: + continue if name.endswith("weight_scale") and self._skip_scale(): continue @@ -140,21 +152,6 @@ def _skip_scale(self): return isinstance(self, NVFP4PackedCompressor) - def _skip_zp( - self, name: str, names_to_scheme: Dict[str, QuantizationScheme] - ) -> bool: - module_name, zp_name = name.rsplit(".", 1) if "." 
in name else ("", name) - scheme = names_to_scheme[module_name] - - if zp_name == "weight_zero_point": - args = scheme.weights - if zp_name == "input_zero_point": - args = scheme.input_activations - if zp_name == "output_zero_point": - args = scheme.output_activations - - return args.symmetric - def decompress( self, path_to_model_or_tensors: Union[str, Path, Dict[str, Any]], diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index c41a38c83..d8560f3a7 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -184,6 +184,22 @@ def decompress_weight( return decompressed_weight + def compress_zp( + self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None + ) -> Optional[Tensor]: + if zero_point is None or quantization_args.symmetric: + return None + if zero_point.dtype == torch.int32: + return zero_point + if quantization_args.strategy in [ + QuantizationStrategy.GROUP.value, + QuantizationStrategy.CHANNEL.value, + ]: + return pack_to_int32( + zero_point, quantization_args.num_bits, packed_dim=0 + ).contiguous() + return zero_point + def pack_to_int32( value: torch.Tensor, From 90e4655a7a8792cb3f019c842441f77bca376f21 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 9 Dec 2025 18:12:46 +0000 Subject: [PATCH 05/22] Add scale decompression support Signed-off-by: shanjiaz --- .../compressors/quantized_compressors/base.py | 8 --- .../quantized_compressors/fp4_quantized.py | 27 ++++++- .../quantized_compressors/naive_quantized.py | 12 +++- .../quantized_compressors/pack_quantized.py | 70 ++++++++++++------- .../quantization/utils/helpers.py | 45 ++++++++++++ .../test_model_compressor.py | 10 ++- .../quantized_compressors/test_fp8_quant.py | 4 +- .../quantized_compressors/test_int_quant.py | 7 +- 8 files changed, 138 insertions(+), 45 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 71913e6f7..1700acc26 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -140,18 +140,10 @@ def compress( if value is None: continue - if name.endswith("weight_scale") and self._skip_scale(): - continue - compressed_dict[name] = value.to(compression_device) return compressed_dict - def _skip_scale(self): - from compressed_tensors.compressors import NVFP4PackedCompressor - - return isinstance(self, NVFP4PackedCompressor) - def decompress( self, path_to_model_or_tensors: Union[str, Path, Dict[str, Any]], diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py index 9572fc288..3db28be96 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py @@ -21,8 +21,9 @@ BaseQuantizationCompressor, ) from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import QuantizationArgs +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize +from compressed_tensors.quantization.utils import calculate_qparam_shape from 
torch import Tensor @@ -56,7 +57,6 @@ def compression_param_names(self) -> Tuple[str]: return ( "weight_packed", "weight_scale", - "weight_zero_point", "weight_global_scale", ) @@ -79,6 +79,24 @@ def compression_param_info( torch.uint8, ), } + + # Add weight_scale and weight_global_scale for NVFP4/MXFP4 + if quantization_args is not None and quantization_args.strategy in [ + QuantizationStrategy.GROUP.value, + QuantizationStrategy.TENSOR_GROUP.value, + ]: + # Use centralized calculation for consistency and correctness + num_groups, scale_shape = calculate_qparam_shape( + weight_shape, quantization_args + ) + output["weight_scale"] = (scale_shape, quantization_args.scale_dtype) + + if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP.value: + output["weight_global_scale"] = ( + torch.Size((1,)), + torch.float32, + ) + return output def compress_scale( @@ -114,6 +132,11 @@ def compress_weight( compressed_dict["weight_scale"] = self.compress_scale( scale=scale, quantization_args=quantization_args ) + + # Include global_scale if provided (for TENSOR_GROUP strategy) + if global_scale is not None: + compressed_dict["weight_global_scale"] = global_scale + return compressed_dict def decompress_weight( diff --git a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py index 4b93a93c2..c297c6ba3 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py @@ -111,7 +111,17 @@ def compress_weight( if device is not None: quantized_weight = quantized_weight.to(device) - return {"weight": quantized_weight} + compressed_dict = {"weight": quantized_weight} + + # Include scale, zero_point, and g_idx if they exist + if scale is not None: + compressed_dict["weight_scale"] = scale + if zero_point is not None: + compressed_dict["weight_zero_point"] = zero_point + if g_idx is not None: + compressed_dict["weight_g_idx"] = g_idx + + return compressed_dict def decompress_weight( self, diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index d8560f3a7..0b874daca 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -22,7 +22,7 @@ from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize -from compressed_tensors.quantization.utils import can_quantize +from compressed_tensors.quantization.utils import calculate_qparam_shape, can_quantize from torch import Tensor @@ -69,20 +69,26 @@ def compression_param_info( "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32), "weight_shape": (torch.Size((2,)), torch.int32), } - if not quantization_args.symmetric and quantization_args.strategy in [ + + # Add weight_scale - always needed for quantization + if quantization_args.strategy in [ QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value, ]: - zp_factor = ( - quantization_args.group_size - if quantization_args.strategy == QuantizationStrategy.GROUP.value - else weight_shape[-1] + # Use centralized calculation for consistency and correctness + num_groups, scale_shape = 
calculate_qparam_shape( + weight_shape, quantization_args ) + output["weight_scale"] = (scale_shape, quantization_args.scale_dtype) + + # Add weight_zero_point for asymmetric quantization + # Zero point has same num_groups as scale, but with packed rows + if not quantization_args.symmetric: + output["weight_zero_point"] = ( + torch.Size((packed_size_zp, num_groups)), + torch.int32, + ) - output["weight_zero_point"] = ( - torch.Size((packed_size_zp, weight_shape[-1] // zp_factor)), - torch.int32, - ) return output def compress_weight( @@ -126,7 +132,7 @@ def compress_weight( packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits) - weight_shape = torch.tensor(weight.shape) + weight_shape = torch.tensor(weight.shape, dtype=torch.int32) if device is not None: packed_weight = packed_weight.to(device) weight_shape = weight_shape.to(device) @@ -134,14 +140,28 @@ def compress_weight( compressed_dict["weight_shape"] = weight_shape compressed_dict["weight_packed"] = packed_weight - if not quantization_args.symmetric and quantization_args.strategy in [ - QuantizationStrategy.GROUP.value, - QuantizationStrategy.CHANNEL.value, - ]: - packed_zp = pack_to_int32( - zero_point, quantization_args.num_bits, packed_dim=0 - ) - compressed_dict["weight_zero_point"] = packed_zp.contiguous() + # Include scale if provided + if scale is not None: + compressed_dict["weight_scale"] = scale + + # Include zero_point if provided + if zero_point is not None: + if not quantization_args.symmetric and quantization_args.strategy in [ + QuantizationStrategy.GROUP.value, + QuantizationStrategy.CHANNEL.value, + ]: + packed_zp = pack_to_int32( + zero_point, quantization_args.num_bits, packed_dim=0 + ) + compressed_dict["weight_zero_point"] = packed_zp.contiguous() + else: + # For symmetric or other strategies, include unpacked zero_point + compressed_dict["weight_zero_point"] = zero_point + + # Include g_idx if provided + if g_idx is not None: + compressed_dict["weight_g_idx"] = g_idx + return compressed_dict def decompress_weight( @@ -172,11 +192,13 @@ def decompress_weight( zero_point is not None ), "Asymmetric quantization requires zero-point values" original_zp_shape = (original_shape[0], scale.shape[-1]) - zero_point = unpack_from_int32( - zero_point, num_bits, original_zp_shape, packed_dim=0 - ) - # Update the compressed_data dict with the unpacked zero_point - compressed_data["weight_zero_point"] = zero_point + # Only unpack if it's still packed (int32) + if zero_point.dtype == torch.int32: + zero_point = unpack_from_int32( + zero_point, num_bits, original_zp_shape, packed_dim=0 + ) + # Update the compressed_data dict with the unpacked zero_point + compressed_data["weight_zero_point"] = zero_point decompressed_weight = dequantize( x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 59c5f245a..3a8535d3f 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -53,6 +53,7 @@ "calculate_qparams", "generate_gparam", "strategy_cdiv", + "calculate_qparam_shape", ] # target the self_attn layer @@ -459,6 +460,50 @@ def strategy_cdiv( return dividend +def calculate_qparam_shape( + weight_shape: torch.Size, + quantization_args: QuantizationArgs, +) -> Tuple[int, torch.Size]: + """ + Calculate the number of groups and scale/zero_point shape for quantization. 
+ + This centralizes the logic for determining quantization parameter shapes, + ensuring consistency with initialize_qparams and avoiding floor division bugs. + + :param weight_shape: shape of the weight tensor to be quantized + :param quantization_args: quantization configuration + :return: tuple of (num_groups, expected_shape) where: + - num_groups: number of quantization groups + - expected_shape: shape for scale/zero_point tensors (weight_shape[0], num_groups) + """ + strategy = quantization_args.strategy + + if strategy == QuantizationStrategy.TENSOR: + num_groups = 1 + expected_shape = torch.Size((1,)) + + elif strategy == QuantizationStrategy.CHANNEL: + num_groups = 1 + expected_shape = torch.Size((weight_shape[0], 1)) + + elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + group_size = quantization_args.group_size + if group_size is None: + raise ValueError(f"{strategy} quantization requires group_size to be set") + + # Use strategy_cdiv for proper ceiling division and validation + num_groups = strategy_cdiv(weight_shape[-1], group_size, strategy) + expected_shape = torch.Size((weight_shape[0], num_groups)) + + else: + raise ValueError( + f"Unsupported quantization strategy: {strategy}. " + f"Supported strategies: TENSOR, CHANNEL, GROUP, TENSOR_GROUP" + ) + + return num_groups, expected_shape + + def _get_dtype_eps(dtype: torch.dtype) -> float: if dtype == FP8_E4M3_DATA.dtype: return 0.125 diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index 942fd5283..964194e82 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -214,6 +214,7 @@ def test_composability(tmp_path, sparsity_config, quantization_config): "linear.row_offsets", "linear.shape", "linear.weight_scale", + "linear.weight_zero_point", }, ) ], @@ -572,9 +573,12 @@ def test_decompress_model(model_stub, comp_stub): # equivalent to decompressing from disk assert decompressed.keys() == true_decompressed.keys() for key in decompressed.keys(): - assert ( - decompressed[key].dtype == true_decompressed[key].dtype - ), f"{key} dtypes not equal" + # Skip dtype check for weight_shape - int32/int64 are functionally equivalent + # torch.Size() works identically with both, old checkpoints use int64, new use int32 + if not key.endswith("weight_shape"): + assert ( + decompressed[key].dtype == true_decompressed[key].dtype + ), f"{key} dtypes not equal" assert torch.all( decompressed[key] == true_decompressed[key] ), f"{key} values not equal" diff --git a/tests/test_compressors/quantized_compressors/test_fp8_quant.py b/tests/test_compressors/quantized_compressors/test_fp8_quant.py index 2fb2d62d3..b33a4a234 100644 --- a/tests/test_compressors/quantized_compressors/test_fp8_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp8_quant.py @@ -89,8 +89,8 @@ def test_quant_format(strategy, group_size, sc, zp): dense_state_dict, names_to_scheme=module_name_to_scheme ) - # state_dict params should be the same, minus the zero_point if symmetric - assert len(dense_state_dict) == len(compressed_state_dict) + 1 + # state_dict params should be the same (zero_point included even for symmetric) + assert len(dense_state_dict) == len(compressed_state_dict) # check compressed to int8 assert compressed_state_dict["dummy.weight_scale"].dtype == torch.float32 diff --git 
a/tests/test_compressors/quantized_compressors/test_int_quant.py b/tests/test_compressors/quantized_compressors/test_int_quant.py index 627af5821..43d1efa53 100644 --- a/tests/test_compressors/quantized_compressors/test_int_quant.py +++ b/tests/test_compressors/quantized_compressors/test_int_quant.py @@ -81,11 +81,8 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp): dense_state_dict, names_to_scheme=quantized_modules_to_scheme ) - # state_dict params should be the same, minus the zero_point if symmetric - if symmetric: - assert len(dense_state_dict) == len(compressed_state_dict) + 1 - else: - assert len(dense_state_dict) == len(compressed_state_dict) + # state_dict params should be the same (zero_point included even for symmetric) + assert len(dense_state_dict) == len(compressed_state_dict) # check compressed to int8 assert compressed_state_dict["dummy.weight"].dtype == torch.int8 From 6b38373f662d9807e546a67917aa53ff9921a476 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 9 Dec 2025 18:13:05 +0000 Subject: [PATCH 06/22] fix tests Signed-off-by: shanjiaz --- .../quantized_compressors/test_pack_quant.py | 5 ++--- .../quantized_compressors/test_packed_asym_decompression.py | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index 05a8ea647..3c0057c06 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -88,9 +88,8 @@ def test_quant_format(shape): dense_state_dict, names_to_scheme=quantized_modules_to_scheme ) - # compressed state_dict adds one entry for shape - # but removes the zero points since we are symmetric - assert len(dense_state_dict) == len(compressed_state_dict) + # compressed state_dict adds one entry for shape and keeps zero_point + assert len(dense_state_dict) + 1 == len(compressed_state_dict) # check compressed and packed assert compressed_state_dict["dummy.weight_packed"].dtype == torch.int32 diff --git a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py index fb85bedbd..70bace936 100644 --- a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +++ b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py @@ -114,6 +114,7 @@ def test_end_to_end_asymmetric_quantization( # Verify compression created zero-point parameters assert hasattr(model.layer1, "weight_zero_point") assert hasattr(model.layer2, "weight_zero_point") + # For asymmetric GROUP/CHANNEL quantization, zero_point should be packed to int32 assert model.layer1.weight_zero_point.dtype == torch.int32 assert model.layer2.weight_zero_point.dtype == torch.int32 From 9435242e881be4545fcc633c226ebf89604d4d3a Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 9 Dec 2025 19:43:38 +0000 Subject: [PATCH 07/22] cleanup Signed-off-by: shanjiaz --- .../quantized_compressors/naive_quantized.py | 12 +---- .../quantized_compressors/pack_quantized.py | 44 ++++++------------- .../test_model_compressor.py | 1 - .../quantized_compressors/test_fp8_quant.py | 4 +- .../quantized_compressors/test_int_quant.py | 7 ++- .../quantized_compressors/test_pack_quant.py | 5 ++- 6 files changed, 25 insertions(+), 48 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py 
b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py index c297c6ba3..4b93a93c2 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py @@ -111,17 +111,7 @@ def compress_weight( if device is not None: quantized_weight = quantized_weight.to(device) - compressed_dict = {"weight": quantized_weight} - - # Include scale, zero_point, and g_idx if they exist - if scale is not None: - compressed_dict["weight_scale"] = scale - if zero_point is not None: - compressed_dict["weight_zero_point"] = zero_point - if g_idx is not None: - compressed_dict["weight_g_idx"] = g_idx - - return compressed_dict + return {"weight": quantized_weight} def decompress_weight( self, diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 0b874daca..873051d3a 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -132,7 +132,7 @@ def compress_weight( packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits) - weight_shape = torch.tensor(weight.shape, dtype=torch.int32) + weight_shape = torch.tensor(weight.shape) if device is not None: packed_weight = packed_weight.to(device) weight_shape = weight_shape.to(device) @@ -140,28 +140,14 @@ def compress_weight( compressed_dict["weight_shape"] = weight_shape compressed_dict["weight_packed"] = packed_weight - # Include scale if provided - if scale is not None: - compressed_dict["weight_scale"] = scale - - # Include zero_point if provided - if zero_point is not None: - if not quantization_args.symmetric and quantization_args.strategy in [ - QuantizationStrategy.GROUP.value, - QuantizationStrategy.CHANNEL.value, - ]: - packed_zp = pack_to_int32( - zero_point, quantization_args.num_bits, packed_dim=0 - ) - compressed_dict["weight_zero_point"] = packed_zp.contiguous() - else: - # For symmetric or other strategies, include unpacked zero_point - compressed_dict["weight_zero_point"] = zero_point - - # Include g_idx if provided - if g_idx is not None: - compressed_dict["weight_g_idx"] = g_idx - + if not quantization_args.symmetric and quantization_args.strategy in [ + QuantizationStrategy.GROUP.value, + QuantizationStrategy.CHANNEL.value, + ]: + packed_zp = pack_to_int32( + zero_point, quantization_args.num_bits, packed_dim=0 + ) + compressed_dict["weight_zero_point"] = packed_zp.contiguous() return compressed_dict def decompress_weight( @@ -192,13 +178,11 @@ def decompress_weight( zero_point is not None ), "Asymmetric quantization requires zero-point values" original_zp_shape = (original_shape[0], scale.shape[-1]) - # Only unpack if it's still packed (int32) - if zero_point.dtype == torch.int32: - zero_point = unpack_from_int32( - zero_point, num_bits, original_zp_shape, packed_dim=0 - ) - # Update the compressed_data dict with the unpacked zero_point - compressed_data["weight_zero_point"] = zero_point + zero_point = unpack_from_int32( + zero_point, num_bits, original_zp_shape, packed_dim=0 + ) + # Update the compressed_data dict with the unpacked zero_point + compressed_data["weight_zero_point"] = zero_point decompressed_weight = dequantize( x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py 
b/tests/test_compressors/model_compressors/test_model_compressor.py index 964194e82..be6b82175 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -214,7 +214,6 @@ def test_composability(tmp_path, sparsity_config, quantization_config): "linear.row_offsets", "linear.shape", "linear.weight_scale", - "linear.weight_zero_point", }, ) ], diff --git a/tests/test_compressors/quantized_compressors/test_fp8_quant.py b/tests/test_compressors/quantized_compressors/test_fp8_quant.py index b33a4a234..2fb2d62d3 100644 --- a/tests/test_compressors/quantized_compressors/test_fp8_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp8_quant.py @@ -89,8 +89,8 @@ def test_quant_format(strategy, group_size, sc, zp): dense_state_dict, names_to_scheme=module_name_to_scheme ) - # state_dict params should be the same (zero_point included even for symmetric) - assert len(dense_state_dict) == len(compressed_state_dict) + # state_dict params should be the same, minus the zero_point if symmetric + assert len(dense_state_dict) == len(compressed_state_dict) + 1 # check compressed to int8 assert compressed_state_dict["dummy.weight_scale"].dtype == torch.float32 diff --git a/tests/test_compressors/quantized_compressors/test_int_quant.py b/tests/test_compressors/quantized_compressors/test_int_quant.py index 43d1efa53..627af5821 100644 --- a/tests/test_compressors/quantized_compressors/test_int_quant.py +++ b/tests/test_compressors/quantized_compressors/test_int_quant.py @@ -81,8 +81,11 @@ def test_quant_format(strategy, symmetric, group_size, sc, zp): dense_state_dict, names_to_scheme=quantized_modules_to_scheme ) - # state_dict params should be the same (zero_point included even for symmetric) - assert len(dense_state_dict) == len(compressed_state_dict) + # state_dict params should be the same, minus the zero_point if symmetric + if symmetric: + assert len(dense_state_dict) == len(compressed_state_dict) + 1 + else: + assert len(dense_state_dict) == len(compressed_state_dict) # check compressed to int8 assert compressed_state_dict["dummy.weight"].dtype == torch.int8 diff --git a/tests/test_compressors/quantized_compressors/test_pack_quant.py b/tests/test_compressors/quantized_compressors/test_pack_quant.py index 3c0057c06..05a8ea647 100644 --- a/tests/test_compressors/quantized_compressors/test_pack_quant.py +++ b/tests/test_compressors/quantized_compressors/test_pack_quant.py @@ -88,8 +88,9 @@ def test_quant_format(shape): dense_state_dict, names_to_scheme=quantized_modules_to_scheme ) - # compressed state_dict adds one entry for shape and keeps zero_point - assert len(dense_state_dict) + 1 == len(compressed_state_dict) + # compressed state_dict adds one entry for shape + # but removes the zero points since we are symmetric + assert len(dense_state_dict) == len(compressed_state_dict) # check compressed and packed assert compressed_state_dict["dummy.weight_packed"].dtype == torch.int32 From 65dd3793c26a4d342feb4dde870328d3c8f6ab89 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 9 Dec 2025 20:10:08 +0000 Subject: [PATCH 08/22] minimal diff Signed-off-by: shanjiaz --- .../model_compressors/test_model_compressor.py | 9 +++------ .../test_packed_asym_decompression.py | 1 - 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index be6b82175..942fd5283 100644 --- 
a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -572,12 +572,9 @@ def test_decompress_model(model_stub, comp_stub): # equivalent to decompressing from disk assert decompressed.keys() == true_decompressed.keys() for key in decompressed.keys(): - # Skip dtype check for weight_shape - int32/int64 are functionally equivalent - # torch.Size() works identically with both, old checkpoints use int64, new use int32 - if not key.endswith("weight_shape"): - assert ( - decompressed[key].dtype == true_decompressed[key].dtype - ), f"{key} dtypes not equal" + assert ( + decompressed[key].dtype == true_decompressed[key].dtype + ), f"{key} dtypes not equal" assert torch.all( decompressed[key] == true_decompressed[key] ), f"{key} values not equal" diff --git a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py index 70bace936..fb85bedbd 100644 --- a/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py +++ b/tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py @@ -114,7 +114,6 @@ def test_end_to_end_asymmetric_quantization( # Verify compression created zero-point parameters assert hasattr(model.layer1, "weight_zero_point") assert hasattr(model.layer2, "weight_zero_point") - # For asymmetric GROUP/CHANNEL quantization, zero_point should be packed to int32 assert model.layer1.weight_zero_point.dtype == torch.int32 assert model.layer2.weight_zero_point.dtype == torch.int32 From effed038a0eca3c64ac90582720b55c785dd3033 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 10 Dec 2025 14:59:05 +0000 Subject: [PATCH 09/22] quality Signed-off-by: shanjiaz --- nvfp4_decompress.py | 15 +++++++++++++++ .../model_compressors/model_compressor.py | 8 ++++---- .../quantization/utils/helpers.py | 8 +++++--- 3 files changed, 24 insertions(+), 7 deletions(-) create mode 100644 nvfp4_decompress.py diff --git a/nvfp4_decompress.py b/nvfp4_decompress.py new file mode 100644 index 000000000..9d8c621e1 --- /dev/null +++ b/nvfp4_decompress.py @@ -0,0 +1,15 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +from llmcompressor.utils import dispatch_for_generation + +#MODEL_ID = "nm-testing/TinyLlama-1.1B-Chat-v1.0-w4a16-asym-awq-e2e" +MODEL_ID = "nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4" + +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device) +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0]))'' +print("==========================================\n\n") diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index fde3e1954..27cfec7fa 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -338,10 +338,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[ - format - ] = BaseCompressor.load_from_registry( - format, config=quantization_config + self.quantization_compressor[format] = ( + 
BaseCompressor.load_from_registry( + format, config=quantization_config + ) ) def get_missing_module_keys(self, model: Module) -> List[str]: diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 3a8535d3f..0800ea8e4 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -127,9 +127,11 @@ def calculate_qparams( # 5. Update any 0s with small values to # prevent div by 0 eps = _get_dtype_eps( - dtype=quantization_args.scale_dtype - if quantization_args.scale_dtype is not None - else scales.dtype + dtype=( + quantization_args.scale_dtype + if quantization_args.scale_dtype is not None + else scales.dtype + ) ) scales = torch.where( scales == 0, From 36c27b3868c8e6d9b385b1fbe28f6e27adcc0c3f Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 10 Dec 2025 15:03:36 +0000 Subject: [PATCH 10/22] remove script Signed-off-by: shanjiaz --- nvfp4_decompress.py | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 nvfp4_decompress.py diff --git a/nvfp4_decompress.py b/nvfp4_decompress.py deleted file mode 100644 index 9d8c621e1..000000000 --- a/nvfp4_decompress.py +++ /dev/null @@ -1,15 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer -from llmcompressor.utils import dispatch_for_generation - -#MODEL_ID = "nm-testing/TinyLlama-1.1B-Chat-v1.0-w4a16-asym-awq-e2e" -MODEL_ID = "nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4" - -model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -print("========== SAMPLE GENERATION ==============") -dispatch_for_generation(model) -input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device) -output = model.generate(input_ids, max_new_tokens=100) -print(tokenizer.decode(output[0]))'' -print("==========================================\n\n") From c301dedccf0e2f2534527554510a09633e3b9f54 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 10 Dec 2025 16:50:52 +0000 Subject: [PATCH 11/22] quality Signed-off-by: shanjiaz --- src/compressed_tensors/quantization/utils/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 0800ea8e4..3869b3405 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -476,7 +476,8 @@ def calculate_qparam_shape( :param quantization_args: quantization configuration :return: tuple of (num_groups, expected_shape) where: - num_groups: number of quantization groups - - expected_shape: shape for scale/zero_point tensors (weight_shape[0], num_groups) + - expected_shape: shape for scale/zero_point tensors + (weight_shape[0], num_groups) """ strategy = quantization_args.strategy From 2bf6e19e53dcb977c3843db71542388420d85d0f Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 10 Dec 2025 17:46:08 +0000 Subject: [PATCH 12/22] minimum diff Signed-off-by: shanjiaz --- .../compressors/model_compressors/model_compressor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 27cfec7fa..fde3e1954 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ 
b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -338,10 +338,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[format] = ( - BaseCompressor.load_from_registry( - format, config=quantization_config - ) + self.quantization_compressor[ + format + ] = BaseCompressor.load_from_registry( + format, config=quantization_config ) def get_missing_module_keys(self, model: Module) -> List[str]: From d492543eed030a66c80168a72dbead0504d35e5c Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 10 Dec 2025 19:16:28 +0000 Subject: [PATCH 13/22] added TODO Signed-off-by: shanjiaz --- src/compressed_tensors/compressors/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index be607f4f0..118af185c 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -190,6 +190,10 @@ def decompress_module(self, module: Module): for name, parameter in module.named_parameters(): compressed_data[name] = parameter + # NOTE: decompress_weight may modify compressed_data dict in-place + # This is subtle but allows us to update the module's qparams with + # the unpacked values. + # TODO: Consider refactoring to return modified qparams explicitly result = self.decompress_weight( compressed_data=compressed_data, quantization_args=quantization_args ).to(device) From 63c08ac57205002857090a459e4ac0a17e76c774 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 10 Dec 2025 22:02:51 +0000 Subject: [PATCH 14/22] address reviews Signed-off-by: shanjiaz --- .../compressors/quantized_compressors/base.py | 16 +++++++++++++ .../quantized_compressors/fp4_quantized.py | 24 ++----------------- .../quantization/lifecycle/initialize.py | 9 +++---- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 1700acc26..893dda3f2 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -85,6 +85,7 @@ def compress( """ uncompressed_names = list(model_state.keys()) compressed_dict = {} + compressed_prefixes = set() # compress values desc = "Compressing with quantization" @@ -119,11 +120,26 @@ def compress( device=compression_device, ) + compressed_prefixes.add(prefix) + # update state dict for key, value in compressed_values.items(): compressed_dict[prefix + key] = value.to(compression_device) else: + # Skip qparams already added by compress_weight + is_duplicate = any( + name.endswith(s) and name.removesuffix(s) in compressed_prefixes + for s in [ + "weight_scale", + "weight_zero_point", + "weight_global_scale", + "weight_g_idx", + ] + ) + if is_duplicate: + continue + # omit saving zero points for symmetric quantization if name.endswith("weight_zero_point"): module_path = name.rsplit(".", 1)[0] diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py index 3db28be96..c0a6f926b 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py @@ -21,9 +21,8 @@ BaseQuantizationCompressor, ) from compressed_tensors.config import CompressionFormat -from 
compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization import QuantizationArgs from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize -from compressed_tensors.quantization.utils import calculate_qparam_shape from torch import Tensor @@ -73,32 +72,13 @@ def compression_param_info( :param quantization_args: quantization parameters for the weight :return: dictionary mapping compressed parameter names to shape and dtype """ - output = { + return { "weight_packed": ( torch.Size((weight_shape[0], weight_shape[1] // 2)), torch.uint8, ), } - # Add weight_scale and weight_global_scale for NVFP4/MXFP4 - if quantization_args is not None and quantization_args.strategy in [ - QuantizationStrategy.GROUP.value, - QuantizationStrategy.TENSOR_GROUP.value, - ]: - # Use centralized calculation for consistency and correctness - num_groups, scale_shape = calculate_qparam_shape( - weight_shape, quantization_args - ) - output["weight_scale"] = (scale_shape, quantization_args.scale_dtype) - - if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP.value: - output["weight_global_scale"] = ( - torch.Size((1,)), - torch.float32, - ) - - return output - def compress_scale( self, scale: Tensor, diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 8c1b251c5..98c2f2143 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -35,7 +35,7 @@ from compressed_tensors.quantization.lifecycle.forward import ( wrap_module_forward_quantized, ) -from compressed_tensors.quantization.utils import strategy_cdiv +from compressed_tensors.quantization.utils import calculate_qparam_shape, strategy_cdiv from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, @@ -215,9 +215,10 @@ def initialize_qparams( if len(observed_shape) < 1: raise ValueError("Group quant requires at least 1 observed dimension") - group_size = quantization_args.group_size - num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy) - expected_shape = (*observed_shape[:-1], num_groups) + # Use shared calculation to avoid floor division bugs + _, expected_shape = calculate_qparam_shape( + torch.Size(observed_shape), quantization_args + ) # initialize activation ordering if applicable if actorder == ActivationOrdering.GROUP: From f9f31059d3332482e923638e9640e4e79f1d5739 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Wed, 10 Dec 2025 22:20:25 +0000 Subject: [PATCH 15/22] fix compressed params tracking Signed-off-by: shanjiaz --- .../compressors/quantized_compressors/base.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 893dda3f2..e2b464554 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -85,7 +85,7 @@ def compress( """ uncompressed_names = list(model_state.keys()) compressed_dict = {} - compressed_prefixes = set() + compressed_param_names = set() # compress values desc = "Compressing with quantization" @@ -120,24 +120,15 @@ def compress( device=compression_device, ) - compressed_prefixes.add(prefix) - - # update state dict + # update state dict and track which params were added for key, value in 
compressed_values.items(): - compressed_dict[prefix + key] = value.to(compression_device) + full_name = prefix + key + compressed_dict[full_name] = value.to(compression_device) + compressed_param_names.add(full_name) else: # Skip qparams already added by compress_weight - is_duplicate = any( - name.endswith(s) and name.removesuffix(s) in compressed_prefixes - for s in [ - "weight_scale", - "weight_zero_point", - "weight_global_scale", - "weight_g_idx", - ] - ) - if is_duplicate: + if name in compressed_param_names: continue # omit saving zero points for symmetric quantization From c6e2d4b835ea48da43110b3b28c9243db3552451 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Thu, 11 Dec 2025 23:17:39 +0000 Subject: [PATCH 16/22] use helper in initialize Signed-off-by: shanjiaz --- .../quantization/lifecycle/initialize.py | 32 +++++++++---------- .../quantization/utils/helpers.py | 15 ++++----- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 98c2f2143..45b07baa5 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -198,27 +198,27 @@ def initialize_qparams( return # 1. Infer expected scale/zp shape - if strategy == QuantizationStrategy.TENSOR: - expected_shape = (1,) - - elif strategy == QuantizationStrategy.TOKEN: + if strategy == QuantizationStrategy.TOKEN: raise ValueError("Cannot perform static token quantization") - elif strategy == QuantizationStrategy.CHANNEL: - if len(observed_shape) < 2: + elif strategy in ( + QuantizationStrategy.TENSOR, + QuantizationStrategy.CHANNEL, + QuantizationStrategy.GROUP, + QuantizationStrategy.TENSOR_GROUP, + ): + # Validate shape requirements + if strategy == QuantizationStrategy.CHANNEL and len(observed_shape) < 2: raise ValueError("Channel quant requires at least 2 observed dimensions") - - expected_shape = (observed_shape[-2], 1) - - elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): - assert quantization_args.group_size is not None - if len(observed_shape) < 1: - raise ValueError("Group quant requires at least 1 observed dimension") + if strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + assert quantization_args.group_size is not None + if len(observed_shape) < 1: + raise ValueError("Group quant requires at least 1 observed dimension") # Use shared calculation to avoid floor division bugs - _, expected_shape = calculate_qparam_shape( - torch.Size(observed_shape), quantization_args - ) + # Note: observed_shape may contain None for dynamic dimensions (e.g., sequence length) + # but calculate_qparam_shape only accesses specific indices that are concrete + _, expected_shape = calculate_qparam_shape(observed_shape, quantization_args) # initialize activation ordering if applicable if actorder == ActivationOrdering.GROUP: diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 3869b3405..e7a4c756f 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -127,11 +127,9 @@ def calculate_qparams( # 5. 
Update any 0s with small values to # prevent div by 0 eps = _get_dtype_eps( - dtype=( - quantization_args.scale_dtype - if quantization_args.scale_dtype is not None - else scales.dtype - ) + dtype=quantization_args.scale_dtype + if quantization_args.scale_dtype is not None + else scales.dtype ) scales = torch.where( scales == 0, @@ -483,20 +481,19 @@ def calculate_qparam_shape( if strategy == QuantizationStrategy.TENSOR: num_groups = 1 - expected_shape = torch.Size((1,)) + expected_shape = (1,) elif strategy == QuantizationStrategy.CHANNEL: num_groups = 1 - expected_shape = torch.Size((weight_shape[0], 1)) + expected_shape = (weight_shape[0], 1) elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): group_size = quantization_args.group_size if group_size is None: raise ValueError(f"{strategy} quantization requires group_size to be set") - # Use strategy_cdiv for proper ceiling division and validation num_groups = strategy_cdiv(weight_shape[-1], group_size, strategy) - expected_shape = torch.Size((weight_shape[0], num_groups)) + expected_shape = (weight_shape[0], num_groups) else: raise ValueError( From 8cfb37507a7da807ee2dc1857d70ffb994070d00 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Thu, 11 Dec 2025 23:41:25 +0000 Subject: [PATCH 17/22] quality Signed-off-by: shanjiaz --- src/compressed_tensors/compressors/base.py | 16 +++++++++++++--- .../quantization/lifecycle/initialize.py | 4 +--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index 118af185c..663c384b0 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -190,6 +190,10 @@ def decompress_module(self, module: Module): for name, parameter in module.named_parameters(): compressed_data[name] = parameter + # Save references to original parameters before decompression + original_scale = compressed_data.get("weight_scale") + original_zp = compressed_data.get("weight_zero_point") + # NOTE: decompress_weight may modify compressed_data dict in-place # This is subtle but allows us to update the module's qparams with # the unpacked values. 
@@ -198,9 +202,15 @@ def decompress_module(self, module: Module): compressed_data=compressed_data, quantization_args=quantization_args ).to(device) - # Update module's parameters if they were unpacked/upcast during decompression - for param_name in ["weight_zero_point", "weight_scale"]: - if param_name in compressed_data and hasattr(module, param_name): + # Update module's parameters only if they were actually modified during decompression + for param_name, original_param in [ + ("weight_scale", original_scale), + ("weight_zero_point", original_zp), + ]: + if ( + param_name in compressed_data + and compressed_data[param_name] is not original_param + ): # Delete the old parameter and register the updated one delete_offload_parameter(module, param_name) offload_device = get_offloaded_device(module) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 45b07baa5..b2d4a0ed6 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -215,9 +215,7 @@ def initialize_qparams( if len(observed_shape) < 1: raise ValueError("Group quant requires at least 1 observed dimension") - # Use shared calculation to avoid floor division bugs - # Note: observed_shape may contain None for dynamic dimensions (e.g., sequence length) - # but calculate_qparam_shape only accesses specific indices that are concrete + # Use unified helper to calculate expected shape _, expected_shape = calculate_qparam_shape(observed_shape, quantization_args) # initialize activation ordering if applicable From be6f0a88c35608c4c061b9aca3bb8302934c4988 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Fri, 12 Dec 2025 00:11:17 +0000 Subject: [PATCH 18/22] addressed reviews Signed-off-by: shanjiaz --- src/compressed_tensors/compressors/base.py | 2 +- .../compressors/quantized_compressors/fp4_quantized.py | 9 ++++++--- .../compressors/quantized_compressors/pack_quantized.py | 3 +++ .../quantization/lifecycle/initialize.py | 4 ++++ 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index 663c384b0..839dfd4ee 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -202,7 +202,7 @@ def decompress_module(self, module: Module): compressed_data=compressed_data, quantization_args=quantization_args ).to(device) - # Update module's parameters only if they were actually modified during decompression + # Update module's parameters only if they were modified for param_name, original_param in [ ("weight_scale", original_scale), ("weight_zero_point", original_zp), diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py index c0a6f926b..98d5f394a 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py @@ -113,9 +113,12 @@ def compress_weight( scale=scale, quantization_args=quantization_args ) - # Include global_scale if provided (for TENSOR_GROUP strategy) - if global_scale is not None: - compressed_dict["weight_global_scale"] = global_scale + if global_scale is None: + raise ValueError( + "NVFP4 quantization requires global_scale (TENSOR_GROUP strategy). " + "Use TENSOR_GROUP strategy instead of GROUP for FP4 quantization." 
+ ) + compressed_dict["weight_global_scale"] = global_scale return compressed_dict diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 873051d3a..0bb25c981 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -250,6 +250,9 @@ def pack_to_int32( if packed_dim == 0: value = value.transpose(0, 1) + # Ensure contiguous memory for .view() operation + value = value.contiguous() + rows, cols = value.shape padded_cols = math.ceil(cols / pack_factor) * pack_factor pad_len = padded_cols - cols diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index b2d4a0ed6..b66d479be 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -256,6 +256,10 @@ def initialize_qparams( ]: scale_dtype = torch.float16 + # Set scale_dtype in quantization_args so calculate_qparams can use it + if quantization_args.scale_dtype is None: + quantization_args.scale_dtype = scale_dtype + # 3. Initializes scale/zp for the module init_scale = Parameter( torch.empty(expected_shape, dtype=scale_dtype, device=device), From 84c9a5087576ea05d970e4c5cc298b059b6863c1 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Fri, 12 Dec 2025 00:46:00 +0000 Subject: [PATCH 19/22] minimum diff Signed-off-by: shanjiaz --- src/compressed_tensors/quantization/lifecycle/initialize.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index b66d479be..b2d4a0ed6 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -256,10 +256,6 @@ def initialize_qparams( ]: scale_dtype = torch.float16 - # Set scale_dtype in quantization_args so calculate_qparams can use it - if quantization_args.scale_dtype is None: - quantization_args.scale_dtype = scale_dtype - # 3. 
     init_scale = Parameter(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),

From ae98316c151a733fa68774a6a9550e667dabe09d Mon Sep 17 00:00:00 2001
From: shanjiaz
Date: Fri, 12 Dec 2025 15:44:01 +0000
Subject: [PATCH 20/22] Address some comments

Signed-off-by: shanjiaz
---
 .../compressors/quantized_compressors/base.py | 3 ++-
 .../compressors/quantized_compressors/fp4_quantized.py | 1 -
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py
index e2b464554..5dc6f0128 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/base.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -131,7 +131,8 @@ def compress(
             if name in compressed_param_names:
                 continue
 
-            # omit saving zero points for symmetric quantization
+            # for symmetric quantization, omit zero_point
+            # manually because it wasn't handled in compress_weight
             if name.endswith("weight_zero_point"):
                 module_path = name.rsplit(".", 1)[0]
                 if (
diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
index 98d5f394a..aa6f5f33b 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
@@ -118,7 +118,6 @@ def compress_weight(
                 "NVFP4 quantization requires global_scale (TENSOR_GROUP strategy). "
                 "Use TENSOR_GROUP strategy instead of GROUP for FP4 quantization."
             )
-        compressed_dict["weight_global_scale"] = global_scale
 
         return compressed_dict
 
From 9f8dc8a38aa7a7b395b76c61fc879c8b77b15e9b Mon Sep 17 00:00:00 2001
From: shanjiaz
Date: Fri, 12 Dec 2025 20:12:03 +0000
Subject: [PATCH 21/22] remove unnecessary helpers

Signed-off-by: shanjiaz
---
 .../quantized_compressors/pack_quantized.py | 22 ++++-----
 .../quantization/lifecycle/initialize.py | 33 +++++++-------
 .../quantization/utils/helpers.py | 45 -------------------
 3 files changed, 28 insertions(+), 72 deletions(-)

diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
index 0bb25c981..2ede94cc8 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -22,7 +22,7 @@
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
-from compressed_tensors.quantization.utils import calculate_qparam_shape, can_quantize
+from compressed_tensors.quantization.utils import can_quantize
 from torch import Tensor
 
 
@@ -64,7 +64,6 @@ def compression_param_info(
         """
         pack_factor = 32 // quantization_args.num_bits
         packed_size = math.ceil(weight_shape[1] / pack_factor)
-        packed_size_zp = math.ceil(weight_shape[0] / pack_factor)
         output = {
             "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
             "weight_shape": (torch.Size((2,)), torch.int32),
@@ -75,17 +74,20 @@ def compression_param_info(
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            # Use centralized calculation for consistency and correctness
-            num_groups, scale_shape = calculate_qparam_shape(
-                weight_shape, quantization_args
+            scale_cols = (
+                1
+                if quantization_args.strategy == QuantizationStrategy.CHANNEL.value
+                else math.ceil(weight_shape[1] / quantization_args.group_size)
+            )
+            output["weight_scale"] = (
+                torch.Size((weight_shape[0], scale_cols)),
+                quantization_args.scale_dtype,
             )
-            output["weight_scale"] = (scale_shape, quantization_args.scale_dtype)
 
             # Add weight_zero_point for asymmetric quantization
-            # Zero point has same num_groups as scale, but with packed rows
             if not quantization_args.symmetric:
                 output["weight_zero_point"] = (
-                    torch.Size((packed_size_zp, num_groups)),
+                    torch.Size((math.ceil(weight_shape[0] / pack_factor), scale_cols)),
                     torch.int32,
                 )
 
@@ -201,9 +203,7 @@ def compress_zp(
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            return pack_to_int32(
-                zero_point, quantization_args.num_bits, packed_dim=0
-            ).contiguous()
+            return pack_to_int32(zero_point, quantization_args.num_bits, packed_dim=0)
 
         return zero_point
 
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index b2d4a0ed6..8c1b251c5 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -35,7 +35,7 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import calculate_qparam_shape, strategy_cdiv
+from compressed_tensors.quantization.utils import strategy_cdiv
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -198,25 +198,26 @@ def initialize_qparams(
         return
 
     # 1. Infer expected scale/zp shape
-    if strategy == QuantizationStrategy.TOKEN:
+    if strategy == QuantizationStrategy.TENSOR:
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.TOKEN:
         raise ValueError("Cannot perform static token quantization")
 
-    elif strategy in (
-        QuantizationStrategy.TENSOR,
-        QuantizationStrategy.CHANNEL,
-        QuantizationStrategy.GROUP,
-        QuantizationStrategy.TENSOR_GROUP,
-    ):
-        # Validate shape requirements
-        if strategy == QuantizationStrategy.CHANNEL and len(observed_shape) < 2:
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 2:
             raise ValueError("Channel quant requires at least 2 observed dimensions")
-        if strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
-            assert quantization_args.group_size is not None
-            if len(observed_shape) < 1:
-                raise ValueError("Group quant requires at least 1 observed dimension")
 
-        # Use unified helper to calculate expected shape
-        _, expected_shape = calculate_qparam_shape(observed_shape, quantization_args)
+        expected_shape = (observed_shape[-2], 1)
+
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
+        expected_shape = (*observed_shape[:-1], num_groups)
 
     # initialize activation ordering if applicable
     if actorder == ActivationOrdering.GROUP:
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index e7a4c756f..59c5f245a 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -53,7 +53,6 @@
"calculate_qparams", "generate_gparam", "strategy_cdiv", - "calculate_qparam_shape", ] # target the self_attn layer @@ -460,50 +459,6 @@ def strategy_cdiv( return dividend -def calculate_qparam_shape( - weight_shape: torch.Size, - quantization_args: QuantizationArgs, -) -> Tuple[int, torch.Size]: - """ - Calculate the number of groups and scale/zero_point shape for quantization. - - This centralizes the logic for determining quantization parameter shapes, - ensuring consistency with initialize_qparams and avoiding floor division bugs. - - :param weight_shape: shape of the weight tensor to be quantized - :param quantization_args: quantization configuration - :return: tuple of (num_groups, expected_shape) where: - - num_groups: number of quantization groups - - expected_shape: shape for scale/zero_point tensors - (weight_shape[0], num_groups) - """ - strategy = quantization_args.strategy - - if strategy == QuantizationStrategy.TENSOR: - num_groups = 1 - expected_shape = (1,) - - elif strategy == QuantizationStrategy.CHANNEL: - num_groups = 1 - expected_shape = (weight_shape[0], 1) - - elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): - group_size = quantization_args.group_size - if group_size is None: - raise ValueError(f"{strategy} quantization requires group_size to be set") - - num_groups = strategy_cdiv(weight_shape[-1], group_size, strategy) - expected_shape = (weight_shape[0], num_groups) - - else: - raise ValueError( - f"Unsupported quantization strategy: {strategy}. " - f"Supported strategies: TENSOR, CHANNEL, GROUP, TENSOR_GROUP" - ) - - return num_groups, expected_shape - - def _get_dtype_eps(dtype: torch.dtype) -> float: if dtype == FP8_E4M3_DATA.dtype: return 0.125 From cb09f979ad652843ffd04e531807eb21521b738a Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Fri, 12 Dec 2025 20:20:33 +0000 Subject: [PATCH 22/22] cleanup Signed-off-by: shanjiaz --- .../compressors/quantized_compressors/pack_quantized.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 2ede94cc8..13c797d54 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -74,11 +74,12 @@ def compression_param_info( QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value, ]: - scale_cols = ( - 1 - if quantization_args.strategy == QuantizationStrategy.CHANNEL.value - else math.ceil(weight_shape[1] / quantization_args.group_size) + shape_factor = ( + quantization_args.group_size + if quantization_args.strategy == QuantizationStrategy.GROUP.value + else weight_shape[-1] ) + scale_cols = math.ceil(weight_shape[-1] / shape_factor) output["weight_scale"] = ( torch.Size((weight_shape[0], scale_cols)), quantization_args.scale_dtype,