diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py
index 3a8a97eb..839dfd4e 100644
--- a/src/compressed_tensors/compressors/base.py
+++ b/src/compressed_tensors/compressors/base.py
@@ -20,6 +20,11 @@
 from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
 from compressed_tensors.registry import RegistryMixin
 from compressed_tensors.utils import has_offloaded_params
+from compressed_tensors.utils.offload import (
+    delete_offload_parameter,
+    get_offloaded_device,
+    register_offload_parameter,
+)
 from torch import Tensor
 from torch.nn import Module
@@ -185,10 +190,37 @@ def decompress_module(self, module: Module):
         for name, parameter in module.named_parameters():
             compressed_data[name] = parameter
 
-        return self.decompress_weight(
+        # Save references to original parameters before decompression
+        original_scale = compressed_data.get("weight_scale")
+        original_zp = compressed_data.get("weight_zero_point")
+
+        # NOTE: decompress_weight may modify compressed_data dict in-place.
+        # This is subtle but allows us to update the module's qparams with
+        # the unpacked values.
+        # TODO: Consider refactoring to return modified qparams explicitly
+        result = self.decompress_weight(
             compressed_data=compressed_data, quantization_args=quantization_args
         ).to(device)
 
+        # Update module's parameters only if they were modified
+        for param_name, original_param in [
+            ("weight_scale", original_scale),
+            ("weight_zero_point", original_zp),
+        ]:
+            if (
+                param_name in compressed_data
+                and compressed_data[param_name] is not original_param
+            ):
+                # Delete the old parameter and register the updated one
+                delete_offload_parameter(module, param_name)
+                offload_device = get_offloaded_device(module)
+                param = torch.nn.Parameter(
+                    compressed_data[param_name], requires_grad=False
+                )
+                register_offload_parameter(module, param_name, param, offload_device)
+
+        return result
+
     def decompress_weight(
         self, compressed_data: Dict[str, Tensor], **kwargs
     ) -> torch.Tensor:
diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py
index 19f6c9c0..5dc6f012 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/base.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -18,7 +18,7 @@
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
+from compressed_tensors.quantization import QuantizationScheme
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
@@ -85,6 +85,7 @@ def compress(
         """
         uncompressed_names = list(model_state.keys())
         compressed_dict = {}
+        compressed_param_names = set()
 
         # compress values
         desc = "Compressing with quantization"
@@ -119,54 +120,38 @@ def compress(
                     device=compression_device,
                 )
 
-                # update state dict
+                # update state dict and track which params were added
                 for key, value in compressed_values.items():
-                    compressed_dict[prefix + key] = value.to(compression_device)
+                    full_name = prefix + key
+                    compressed_dict[full_name] = value.to(compression_device)
+                    compressed_param_names.add(full_name)
 
             else:
-                # omit saving zero points for symmetric or packed quantization
-                if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
+                # Skip qparams already added by compress_weight
+                if name in compressed_param_names:
                     continue
 
-                if name.endswith("weight_scale") and self._skip_scale():
-                    continue
+                # for symmetric quantization, omit zero_point
+                # manually because it wasn't handled in compress_weight
+                if name.endswith("weight_zero_point"):
+                    module_path = name.rsplit(".", 1)[0]
+                    if (
+                        module_path in names_to_scheme
+                        and names_to_scheme[module_path].weights.symmetric
+                    ):
+                        continue
+                    # Call compress_zp if available (for PackedQuantizationCompressor)
+                    if module_path in names_to_scheme and hasattr(self, "compress_zp"):
+                        value = self.compress_zp(
+                            value, names_to_scheme[module_path].weights
+                        )
+                        if value is None:
+                            continue
 
                 compressed_dict[name] = value.to(compression_device)
 
         return compressed_dict
 
-    def _skip_scale(self):
-        from compressed_tensors.compressors import NVFP4PackedCompressor
-
-        return isinstance(self, NVFP4PackedCompressor)
-
-    def _skip_zp(
-        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
-    ) -> bool:
-        from compressed_tensors.compressors import PackedQuantizationCompressor
-
-        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
-        scheme = names_to_scheme[module_name]
-
-        if zp_name == "weight_zero_point":
-            args = scheme.weights
-        if zp_name == "input_zero_point":
-            args = scheme.input_activations
-        if zp_name == "output_zero_point":
-            args = scheme.output_activations
-
-        symmetric = args.symmetric
-        packable_strategies = [
-            QuantizationStrategy.GROUP.value,
-            QuantizationStrategy.CHANNEL.value,
-        ]
-        packed = (
-            isinstance(self, PackedQuantizationCompressor)
-            and args.strategy in packable_strategies
-        )
-
-        return symmetric or packed
-
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
index dd3c2a46..aa6f5f33 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
@@ -56,7 +56,6 @@ def compression_param_names(self) -> Tuple[str]:
         return (
             "weight_packed",
             "weight_scale",
-            "weight_zero_point",
             "weight_global_scale",
         )
 
@@ -73,13 +72,12 @@ def compression_param_info(
         :param quantization_args: quantization parameters for the weight
         :return: dictionary mapping compressed parameter names to shape and dtype
         """
-        output = {
+        return {
             "weight_packed": (
                 torch.Size((weight_shape[0], weight_shape[1] // 2)),
                 torch.uint8,
             ),
         }
-        return output
 
     def compress_scale(
         self,
@@ -114,6 +112,13 @@ def compress_weight(
         compressed_dict["weight_scale"] = self.compress_scale(
             scale=scale, quantization_args=quantization_args
         )
+
+        if global_scale is None:
+            raise ValueError(
+                "NVFP4 quantization requires global_scale (TENSOR_GROUP strategy). "
+                "Use TENSOR_GROUP strategy instead of GROUP for FP4 quantization."
+            )
+
         return compressed_dict
 
     def decompress_weight(
@@ -127,6 +132,12 @@ def decompress_weight(
         m, n = weight.shape
         # TODO: use a user provided dequant dtype
        unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
+
+        # cast scale dtype to match unpacked dtype for dequantization
+        if scale.dtype != unpacked.dtype:
+            scale = scale.to(unpacked.dtype)
+            compressed_data["weight_scale"] = scale
+
         decompressed_weight = dequantize(
             x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
         )
diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
index 07da50c7..13c797d5 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -64,25 +64,34 @@ def compression_param_info(
         """
         pack_factor = 32 // quantization_args.num_bits
         packed_size = math.ceil(weight_shape[1] / pack_factor)
-        packed_size_zp = math.ceil(weight_shape[0] / pack_factor)
         output = {
             "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
             "weight_shape": (torch.Size((2,)), torch.int32),
         }
-        if not quantization_args.symmetric and quantization_args.strategy in [
+
+        # Add weight_scale - always needed for quantization
+        if quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            zp_factor = (
+            shape_factor = (
                 quantization_args.group_size
                 if quantization_args.strategy == QuantizationStrategy.GROUP.value
                 else weight_shape[-1]
             )
-
-            output["weight_zero_point"] = (
-                torch.Size((packed_size_zp, weight_shape[-1] // zp_factor)),
-                torch.int32,
+            scale_cols = math.ceil(weight_shape[-1] / shape_factor)
+            output["weight_scale"] = (
+                torch.Size((weight_shape[0], scale_cols)),
+                quantization_args.scale_dtype,
             )
+
+            # Add weight_zero_point for asymmetric quantization
+            if not quantization_args.symmetric:
+                output["weight_zero_point"] = (
+                    torch.Size((math.ceil(weight_shape[0] / pack_factor), scale_cols)),
+                    torch.int32,
+                )
+
         return output
 
     def compress_weight(
@@ -175,6 +184,8 @@ def decompress_weight(
             zero_point = unpack_from_int32(
                 zero_point, num_bits, original_zp_shape, packed_dim=0
             )
+            # Update the compressed_data dict with the unpacked zero_point
+            compressed_data["weight_zero_point"] = zero_point
 
         decompressed_weight = dequantize(
             x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
         )
@@ -182,6 +193,20 @@ def decompress_weight(
 
         return decompressed_weight
 
+    def compress_zp(
+        self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None
+    ) -> Optional[Tensor]:
+        if zero_point is None or quantization_args.symmetric:
+            return None
+        if zero_point.dtype == torch.int32:
+            return zero_point
+        if quantization_args.strategy in [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]:
+            return pack_to_int32(zero_point, quantization_args.num_bits, packed_dim=0)
+        return zero_point
+
 
 def pack_to_int32(
     value: torch.Tensor,
@@ -226,6 +251,9 @@ def pack_to_int32(
     if packed_dim == 0:
         value = value.transpose(0, 1)
 
+    # Ensure contiguous memory for .view() operation
+    value = value.contiguous()
+
     rows, cols = value.shape
     padded_cols = math.ceil(cols / pack_factor) * pack_factor
     pad_len = padded_cols - cols
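
Reviewer note (not part of the patch): a minimal, untested sketch of the shapes the updated compression_param_info in pack_quantized.py would report after this change. The 4-bit, group-size-128, 4096x4096 weight below is a hypothetical example; only the formulas (pack_factor, ceil divisions, packed_dim=0 zero points) are taken from the diff above.

import math

import torch

# Hypothetical example inputs (illustrative only, not from the patch)
num_bits, group_size = 4, 128
rows, cols = 4096, 4096  # original weight_shape

pack_factor = 32 // num_bits                 # 8 int4 values per int32
packed_size = math.ceil(cols / pack_factor)  # packed weight columns
scale_cols = math.ceil(cols / group_size)    # one scale column per group

expected = {
    "weight_packed": torch.Size((rows, packed_size)),  # (4096, 512), torch.int32
    "weight_shape": torch.Size((2,)),                   # stores the original (rows, cols)
    "weight_scale": torch.Size((rows, scale_cols)),     # (4096, 32), for GROUP/CHANNEL strategies
    # asymmetric only: zero points packed along dim 0, matching pack_to_int32(..., packed_dim=0)
    "weight_zero_point": torch.Size((math.ceil(rows / pack_factor), scale_cols)),  # (512, 32)
}
print(expected)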