34 changes: 33 additions & 1 deletion src/compressed_tensors/compressors/base.py
@@ -20,6 +20,11 @@
from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
from compressed_tensors.registry import RegistryMixin
from compressed_tensors.utils import has_offloaded_params
from compressed_tensors.utils.offload import (
delete_offload_parameter,
get_offloaded_device,
register_offload_parameter,
)
from torch import Tensor
from torch.nn import Module

@@ -185,10 +190,37 @@ def decompress_module(self, module: Module):
for name, parameter in module.named_parameters():
compressed_data[name] = parameter

return self.decompress_weight(
# Save references to original parameters before decompression
original_scale = compressed_data.get("weight_scale")
original_zp = compressed_data.get("weight_zero_point")

# NOTE: decompress_weight may modify compressed_data dict in-place
# This is subtle but allows us to update the module's qparams with
# the unpacked values.
# TODO: Consider refactoring to return modified qparams explicitly
result = self.decompress_weight(
compressed_data=compressed_data, quantization_args=quantization_args
).to(device)

# Update module's parameters only if they were modified
for param_name, original_param in [
("weight_scale", original_scale),
("weight_zero_point", original_zp),
]:
if (
param_name in compressed_data
and compressed_data[param_name] is not original_param
):
# Delete the old parameter and register the updated one
delete_offload_parameter(module, param_name)
offload_device = get_offloaded_device(module)
param = torch.nn.Parameter(
compressed_data[param_name], requires_grad=False
)
register_offload_parameter(module, param_name, param, offload_device)

return result

def decompress_weight(
self, compressed_data: Dict[str, Tensor], **kwargs
) -> torch.Tensor:
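For reference, a minimal sketch (not part of this diff) of the parameter swap `decompress_module` now performs, assuming `module` already holds the named qparam; the helper names and the positional `register_offload_parameter(module, name, param, offload_device)` call mirror the imports and usage in `base.py` above.

```python
import torch
from compressed_tensors.utils.offload import (
    delete_offload_parameter,
    get_offloaded_device,
    register_offload_parameter,
)

def replace_offloaded_param(module: torch.nn.Module, name: str, value: torch.Tensor) -> None:
    # Drop the stale (possibly offloaded) parameter first, then register the
    # updated tensor so the module and its offload map stay in sync.
    delete_offload_parameter(module, name)
    offload_device = get_offloaded_device(module)
    param = torch.nn.Parameter(value, requires_grad=False)
    register_offload_parameter(module, name, param, offload_device)
```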
63 changes: 24 additions & 39 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -18,7 +18,7 @@

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils import (
get_nested_mappings_from_state_dict,
get_nested_weight_mappings,
@@ -85,6 +85,7 @@ def compress(
"""
uncompressed_names = list(model_state.keys())
compressed_dict = {}
compressed_param_names = set()

# compress values
desc = "Compressing with quantization"
@@ -119,54 +120,38 @@
device=compression_device,
)

# update state dict
# update state dict and track which params were added
for key, value in compressed_values.items():
compressed_dict[prefix + key] = value.to(compression_device)
full_name = prefix + key
compressed_dict[full_name] = value.to(compression_device)
compressed_param_names.add(full_name)

else:
# omit saving zero points for symmetric or packed quantization
if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
# Skip qparams already added by compress_weight
if name in compressed_param_names:
continue

if name.endswith("weight_scale") and self._skip_scale():
continue
# for symmetric quantization, omit zero_point
# manually because it wasn't handled in compress_weight
if name.endswith("weight_zero_point"):
module_path = name.rsplit(".", 1)[0]
if (
module_path in names_to_scheme
and names_to_scheme[module_path].weights.symmetric
):
continue
# Call compress_zp if available (for PackedQuantizationCompressor)
if module_path in names_to_scheme and hasattr(self, "compress_zp"):
value = self.compress_zp(
value, names_to_scheme[module_path].weights
)
if value is None:
continue

compressed_dict[name] = value.to(compression_device)

return compressed_dict

def _skip_scale(self):
from compressed_tensors.compressors import NVFP4PackedCompressor

return isinstance(self, NVFP4PackedCompressor)

def _skip_zp(
self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
) -> bool:
from compressed_tensors.compressors import PackedQuantizationCompressor

module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
scheme = names_to_scheme[module_name]

if zp_name == "weight_zero_point":
args = scheme.weights
if zp_name == "input_zero_point":
args = scheme.input_activations
if zp_name == "output_zero_point":
args = scheme.output_activations

symmetric = args.symmetric
packable_strategies = [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]
packed = (
isinstance(self, PackedQuantizationCompressor)
and args.strategy in packable_strategies
)

return symmetric or packed

def decompress(
self,
path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
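As a toy illustration (made-up layer name and tensors, not the real compressor API), this is the bookkeeping the rewritten compress() loop relies on: whatever compress_weight has already written is recorded in compressed_param_names and skipped when the remaining state-dict entries are copied through.

```python
import torch

model_state = {
    "linear.weight_zero_point": torch.zeros(128, dtype=torch.int8),  # raw qparam
    "linear.bias": torch.zeros(128),                                  # untouched tensor
}
compressed_dict, compressed_param_names = {}, set()

# Pretend compress_weight already emitted packed values for "linear".
for key, value in {
    "linear.weight_packed": torch.zeros(128, 16, dtype=torch.int32),
    "linear.weight_zero_point": torch.zeros(16, 1, dtype=torch.int32),
}.items():
    compressed_dict[key] = value
    compressed_param_names.add(key)

for name, value in model_state.items():
    if name in compressed_param_names:  # qparam already handled above
        continue
    compressed_dict[name] = value       # everything else passes through

print(sorted(compressed_dict))  # bias plus the packed weight and zero point
```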
(next file: NVFP4 packed quantization compressor)
@@ -56,7 +56,6 @@ def compression_param_names(self) -> Tuple[str]:
return (
"weight_packed",
"weight_scale",
"weight_zero_point",
"weight_global_scale",
)

@@ -73,13 +72,12 @@ def compression_param_info(
:param quantization_args: quantization parameters for the weight
:return: dictionary mapping compressed parameter names to shape and dtype
"""
output = {
return {
"weight_packed": (
torch.Size((weight_shape[0], weight_shape[1] // 2)),
torch.uint8,
),
}
return output

def compress_scale(
self,
@@ -114,6 +112,13 @@ def compress_weight(
compressed_dict["weight_scale"] = self.compress_scale(
scale=scale, quantization_args=quantization_args
)

if global_scale is None:
raise ValueError(
"NVFP4 quantization requires global_scale (TENSOR_GROUP strategy). "
"Use TENSOR_GROUP strategy instead of GROUP for FP4 quantization."
)

return compressed_dict

def decompress_weight(
@@ -127,6 +132,12 @@ def decompress_weight(
m, n = weight.shape
# TODO: use a user provided dequant dtype
unpacked = unpack_fp4_from_uint8(weight, m, n * 2)

# cast scale dtype to match unpacked dtype for dequantization
if scale.dtype != unpacked.dtype:
scale = scale.to(unpacked.dtype)
compressed_data["weight_scale"] = scale

decompressed_weight = dequantize(
x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
)
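A small sketch of the dtype alignment added in decompress_weight above, under the assumption that weight_scale was stored as FP8 (torch.float8_e4m3fn) while the unpacked FP4 values are materialized in bfloat16; the final line is a simplified per-group rescale standing in for the real dequantize(..., global_scale=...) call.

```python
import torch

unpacked = torch.randn(4, 32, dtype=torch.bfloat16)  # stand-in for unpack_fp4_from_uint8 output
scale = torch.rand(4, 2).to(torch.float8_e4m3fn)      # stand-in for the stored per-group scale

if scale.dtype != unpacked.dtype:
    scale = scale.to(unpacked.dtype)                   # align dtypes before dequantizing

groups = unpacked.reshape(4, 2, 16)                    # two groups of 16 values per row
dequantized = (groups * scale.unsqueeze(-1)).reshape(4, 32)
```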
(next file: packed quantization compressor)
@@ -64,25 +64,34 @@
"""
pack_factor = 32 // quantization_args.num_bits
packed_size = math.ceil(weight_shape[1] / pack_factor)
packed_size_zp = math.ceil(weight_shape[0] / pack_factor)
output = {
"weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
"weight_shape": (torch.Size((2,)), torch.int32),
}
if not quantization_args.symmetric and quantization_args.strategy in [

# Add weight_scale - always needed for quantization
if quantization_args.strategy in [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]:
zp_factor = (
shape_factor = (
quantization_args.group_size
if quantization_args.strategy == QuantizationStrategy.GROUP.value
else weight_shape[-1]
)

output["weight_zero_point"] = (
torch.Size((packed_size_zp, weight_shape[-1] // zp_factor)),
torch.int32,
scale_cols = math.ceil(weight_shape[-1] / shape_factor)
output["weight_scale"] = (
torch.Size((weight_shape[0], scale_cols)),
quantization_args.scale_dtype,
)

# Add weight_zero_point for asymmetric quantization
if not quantization_args.symmetric:
output["weight_zero_point"] = (
torch.Size((math.ceil(weight_shape[0] / pack_factor), scale_cols)),
torch.int32,
)

return output

def compress_weight(
@@ -175,13 +184,29 @@ def decompress_weight(
zero_point = unpack_from_int32(
zero_point, num_bits, original_zp_shape, packed_dim=0
)
# Update the compressed_data dict with the unpacked zero_point
compressed_data["weight_zero_point"] = zero_point

decompressed_weight = dequantize(
x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
)

return decompressed_weight

def compress_zp(
self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None
) -> Optional[Tensor]:
if zero_point is None or quantization_args.symmetric:
return None
if zero_point.dtype == torch.int32:
return zero_point
if quantization_args.strategy in [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]:
return pack_to_int32(zero_point, quantization_args.num_bits, packed_dim=0)
return zero_point


def pack_to_int32(
value: torch.Tensor,
@@ -226,6 +251,9 @@ def pack_to_int32(
if packed_dim == 0:
value = value.transpose(0, 1)

# Ensure contiguous memory for .view() operation
value = value.contiguous()

rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols
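The .contiguous() call added above matters because .view() requires contiguous memory, and when packed_dim == 0 the preceding transpose(0, 1) leaves a non-contiguous strided view; a quick standalone illustration:

```python
import torch

t = torch.arange(6, dtype=torch.int32).reshape(2, 3).transpose(0, 1)
print(t.is_contiguous())  # False: transpose only swaps strides
# t.view(-1)              # would raise a RuntimeError on this non-contiguous tensor
t = t.contiguous()        # materialize a contiguous copy
flat = t.view(-1)         # now succeeds
```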