34 changes: 33 additions & 1 deletion src/compressed_tensors/compressors/base.py
@@ -20,6 +20,11 @@
from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
from compressed_tensors.registry import RegistryMixin
from compressed_tensors.utils import has_offloaded_params
from compressed_tensors.utils.offload import (
delete_offload_parameter,
get_offloaded_device,
register_offload_parameter,
)
from torch import Tensor
from torch.nn import Module

@@ -185,10 +190,37 @@ def decompress_module(self, module: Module):
for name, parameter in module.named_parameters():
compressed_data[name] = parameter

return self.decompress_weight(
# Save references to original parameters before decompression
original_scale = compressed_data.get("weight_scale")
original_zp = compressed_data.get("weight_zero_point")

# NOTE: decompress_weight may modify compressed_data dict in-place
# This is subtle but allows us to update the module's qparams with
# the unpacked values.
# TODO: Consider refactoring to return modified qparams explicitly
result = self.decompress_weight(
compressed_data=compressed_data, quantization_args=quantization_args
).to(device)

# Update module's parameters only if they were modified
for param_name, original_param in [
("weight_scale", original_scale),
("weight_zero_point", original_zp),
]:
if (
param_name in compressed_data
and compressed_data[param_name] is not original_param
):
# Delete the old parameter and register the updated one
delete_offload_parameter(module, param_name)
offload_device = get_offloaded_device(module)
param = torch.nn.Parameter(
compressed_data[param_name], requires_grad=False
)
register_offload_parameter(module, param_name, param, offload_device)

return result

def decompress_weight(
self, compressed_data: Dict[str, Tensor], **kwargs
) -> torch.Tensor:
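
A minimal sketch, outside the compressor classes, of the identity check that decompress_module relies on above; the tensor name and dtypes are illustrative, not the module's real parameters:

import torch

# stand-in for the dict built from the module's named_parameters()
compressed_data = {"weight_scale": torch.ones(4, 1, dtype=torch.float16)}
original_scale = compressed_data.get("weight_scale")

# decompress_weight-style in-place update: the entry is replaced with a new
# tensor object (here, a dtype cast), so identity with the original is lost
compressed_data["weight_scale"] = compressed_data["weight_scale"].to(torch.bfloat16)

if compressed_data["weight_scale"] is not original_scale:
    # in base.py this branch deletes the old offloaded parameter and registers
    # a fresh torch.nn.Parameter via register_offload_parameter
    print("weight_scale was swapped out; the module parameter gets refreshed")
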
63 changes: 24 additions & 39 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -18,7 +18,7 @@

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils import (
get_nested_mappings_from_state_dict,
get_nested_weight_mappings,
@@ -85,6 +85,7 @@ def compress(
"""
uncompressed_names = list(model_state.keys())
compressed_dict = {}
compressed_param_names = set()

# compress values
desc = "Compressing with quantization"
@@ -119,54 +120,38 @@
device=compression_device,
)

# update state dict
# update state dict and track which params were added
for key, value in compressed_values.items():
compressed_dict[prefix + key] = value.to(compression_device)
full_name = prefix + key
compressed_dict[full_name] = value.to(compression_device)
compressed_param_names.add(full_name)

else:
# omit saving zero points for symmetric or packed quantization
if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
# Skip qparams already added by compress_weight
if name in compressed_param_names:
continue

if name.endswith("weight_scale") and self._skip_scale():
continue
# for symmetric quantization, omit zero_point
# manually because it wasn't handled in compress_weight
if name.endswith("weight_zero_point"):
module_path = name.rsplit(".", 1)[0]
if (
module_path in names_to_scheme
and names_to_scheme[module_path].weights.symmetric
):
continue
# Call compress_zp if available (for PackedQuantizationCompressor)
if module_path in names_to_scheme and hasattr(self, "compress_zp"):
value = self.compress_zp(
value, names_to_scheme[module_path].weights
)
if value is None:
continue

compressed_dict[name] = value.to(compression_device)

return compressed_dict

def _skip_scale(self):
from compressed_tensors.compressors import NVFP4PackedCompressor

return isinstance(self, NVFP4PackedCompressor)

def _skip_zp(
self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
) -> bool:
from compressed_tensors.compressors import PackedQuantizationCompressor

module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
scheme = names_to_scheme[module_name]

if zp_name == "weight_zero_point":
args = scheme.weights
if zp_name == "input_zero_point":
args = scheme.input_activations
if zp_name == "output_zero_point":
args = scheme.output_activations

symmetric = args.symmetric
packable_strategies = [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]
packed = (
isinstance(self, PackedQuantizationCompressor)
and args.strategy in packable_strategies
)

return symmetric or packed

def decompress(
self,
path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
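
A minimal sketch of the new bookkeeping in compress(), using plain Python stand-ins rather than real tensors or schemes; should_keep and the module names below are illustrative, not part of the library:

def should_keep(name, compressed_param_names, names_to_symmetric):
    """Mirrors the fallback branch: drop qparams already emitted by
    compress_weight, and drop zero points for symmetric schemes."""
    if name in compressed_param_names:
        return False
    if name.endswith("weight_zero_point"):
        module_path = name.rsplit(".", 1)[0]
        if names_to_symmetric.get(module_path, False):
            return False
    return True

already_compressed = {"model.layers.0.self_attn.q_proj.weight_scale"}
symmetric_modules = {"model.layers.0.mlp.down_proj": True}

# skipped: emitted by compress_weight earlier in the loop
assert not should_keep(
    "model.layers.0.self_attn.q_proj.weight_scale", already_compressed, symmetric_modules
)
# skipped: zero point of a symmetric scheme
assert not should_keep(
    "model.layers.0.mlp.down_proj.weight_zero_point", already_compressed, symmetric_modules
)
# kept: an ordinary, unquantized parameter
assert should_keep("model.embed_tokens.weight", already_compressed, symmetric_modules)
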
@@ -56,7 +56,6 @@ def compression_param_names(self) -> Tuple[str]:
return (
"weight_packed",
"weight_scale",
"weight_zero_point",
"weight_global_scale",
)

@@ -73,13 +72,12 @@ def compression_param_info(
:param quantization_args: quantization parameters for the weight
:return: dictionary mapping compressed parameter names to shape and dtype
"""
output = {
return {
"weight_packed": (
torch.Size((weight_shape[0], weight_shape[1] // 2)),
torch.uint8,
),
}
return output

def compress_scale(
self,
@@ -114,6 +112,13 @@ def compress_weight(
compressed_dict["weight_scale"] = self.compress_scale(
scale=scale, quantization_args=quantization_args
)

if global_scale is None:
raise ValueError(
"NVFP4 quantization requires global_scale (TENSOR_GROUP strategy). "
"Use TENSOR_GROUP strategy instead of GROUP for FP4 quantization."
)

return compressed_dict

def decompress_weight(
@@ -127,6 +132,12 @@ def decompress_weight(
m, n = weight.shape
# TODO: use a user provided dequant dtype
unpacked = unpack_fp4_from_uint8(weight, m, n * 2)

# cast scale dtype to match unpacked dtype for dequantization
if scale.dtype != unpacked.dtype:
scale = scale.to(unpacked.dtype)
compressed_data["weight_scale"] = scale

decompressed_weight = dequantize(
x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
)
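
A minimal sketch of why the scale cast matters before the elementwise dequantize multiply; the dtypes are illustrative (the stored scale may be FP8), and without the cast PyTorch either promotes the result to the scale's dtype or, for unsupported dtype pairs, raises:

import torch

unpacked = torch.randn(4, 16, dtype=torch.bfloat16)  # stand-in for unpacked FP4 values
scale = torch.rand(4, 1, dtype=torch.float32)         # scale stored in a wider dtype

if scale.dtype != unpacked.dtype:
    scale = scale.to(unpacked.dtype)

dequantized = unpacked * scale
assert dequantized.dtype == torch.bfloat16  # stays in the requested dequant dtype
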
@@ -22,7 +22,7 @@
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.quantization.utils import can_quantize
from compressed_tensors.quantization.utils import calculate_qparam_shape, can_quantize
from torch import Tensor


@@ -69,20 +69,26 @@ def compression_param_info(
"weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
"weight_shape": (torch.Size((2,)), torch.int32),
}
if not quantization_args.symmetric and quantization_args.strategy in [

# Add weight_scale for group/channel quantization strategies
if quantization_args.strategy in [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]:
zp_factor = (
quantization_args.group_size
if quantization_args.strategy == QuantizationStrategy.GROUP.value
else weight_shape[-1]
# Use centralized calculation for consistency and correctness
num_groups, scale_shape = calculate_qparam_shape(
weight_shape, quantization_args
)
output["weight_scale"] = (scale_shape, quantization_args.scale_dtype)

# Add weight_zero_point for asymmetric quantization
# Zero point has same num_groups as scale, but with packed rows
if not quantization_args.symmetric:
output["weight_zero_point"] = (
torch.Size((packed_size_zp, num_groups)),
torch.int32,
)

output["weight_zero_point"] = (
torch.Size((packed_size_zp, weight_shape[-1] // zp_factor)),
torch.int32,
)
return output

def compress_weight(
@@ -175,13 +181,31 @@ def decompress_weight(
zero_point = unpack_from_int32(
zero_point, num_bits, original_zp_shape, packed_dim=0
)
# Update the compressed_data dict with the unpacked zero_point
compressed_data["weight_zero_point"] = zero_point

decompressed_weight = dequantize(
x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
)

return decompressed_weight

def compress_zp(
self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None
) -> Optional[Tensor]:
if zero_point is None or quantization_args.symmetric:
return None
if zero_point.dtype == torch.int32:
return zero_point
if quantization_args.strategy in [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]:
return pack_to_int32(
zero_point, quantization_args.num_bits, packed_dim=0
).contiguous()
return zero_point


def pack_to_int32(
value: torch.Tensor,
@@ -226,6 +250,9 @@ def pack_to_int32(
if packed_dim == 0:
value = value.transpose(0, 1)

# Ensure contiguous memory for .view() operation
value = value.contiguous()

rows, cols = value.shape
padded_cols = math.ceil(cols / pack_factor) * pack_factor
pad_len = padded_cols - cols
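
A minimal sketch of the failure mode the added .contiguous() guards against: packing with packed_dim=0 transposes the tensor, which leaves it non-contiguous, and Tensor.view() refuses to reinterpret non-contiguous memory:

import torch

value = torch.arange(12, dtype=torch.int8).reshape(3, 4)
value = value.transpose(0, 1)  # shape (4, 3), non-contiguous after the transpose

try:
    value.view(-1)  # raises RuntimeError: view is incompatible with the tensor's strides
except RuntimeError:
    flat = value.contiguous().view(-1)  # materializing the layout first makes view() legal
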
33 changes: 16 additions & 17 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -35,7 +35,7 @@
from compressed_tensors.quantization.lifecycle.forward import (
wrap_module_forward_quantized,
)
from compressed_tensors.quantization.utils import strategy_cdiv
from compressed_tensors.quantization.utils import calculate_qparam_shape, strategy_cdiv
from compressed_tensors.utils import (
disable_hf_hook,
get_execution_device,
@@ -198,26 +198,25 @@ def initialize_qparams(
return

# 1. Infer expected scale/zp shape
if strategy == QuantizationStrategy.TENSOR:
expected_shape = (1,)

elif strategy == QuantizationStrategy.TOKEN:
if strategy == QuantizationStrategy.TOKEN:
raise ValueError("Cannot perform static token quantization")

elif strategy == QuantizationStrategy.CHANNEL:
if len(observed_shape) < 2:
elif strategy in (
QuantizationStrategy.TENSOR,
QuantizationStrategy.CHANNEL,
QuantizationStrategy.GROUP,
QuantizationStrategy.TENSOR_GROUP,
):
# Validate shape requirements
if strategy == QuantizationStrategy.CHANNEL and len(observed_shape) < 2:
raise ValueError("Channel quant requires at least 2 observed dimensions")
if strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
assert quantization_args.group_size is not None
if len(observed_shape) < 1:
raise ValueError("Group quant requires at least 1 observed dimension")

expected_shape = (observed_shape[-2], 1)

elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
assert quantization_args.group_size is not None
if len(observed_shape) < 1:
raise ValueError("Group quant requires at least 1 observed dimension")

group_size = quantization_args.group_size
num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
expected_shape = (*observed_shape[:-1], num_groups)
# Use unified helper to calculate expected shape
_, expected_shape = calculate_qparam_shape(observed_shape, quantization_args)

# initialize activation ordering if applicable
if actorder == ActivationOrdering.GROUP:
45 changes: 45 additions & 0 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -53,6 +53,7 @@
"calculate_qparams",
"generate_gparam",
"strategy_cdiv",
"calculate_qparam_shape",
]

# target the self_attn layer
@@ -459,6 +460,50 @@ def strategy_cdiv(
return dividend


def calculate_qparam_shape(
weight_shape: torch.Size,
quantization_args: QuantizationArgs,
) -> Tuple[int, torch.Size]:
"""
Calculate the number of groups and scale/zero_point shape for quantization.

This centralizes the logic for determining quantization parameter shapes,
ensuring consistency with initialize_qparams and avoiding floor-division errors
when a dimension is not an exact multiple of the group size.

:param weight_shape: shape of the weight tensor to be quantized
:param quantization_args: quantization configuration
:return: tuple of (num_groups, expected_shape) where:
- num_groups: number of quantization groups
- expected_shape: shape for the scale/zero_point tensors, e.g.
(weight_shape[0], num_groups) for group strategies
"""
strategy = quantization_args.strategy

if strategy == QuantizationStrategy.TENSOR:
num_groups = 1
expected_shape = (1,)

elif strategy == QuantizationStrategy.CHANNEL:
num_groups = 1
expected_shape = (weight_shape[0], 1)

elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
group_size = quantization_args.group_size
if group_size is None:
raise ValueError(f"{strategy} quantization requires group_size to be set")

num_groups = strategy_cdiv(weight_shape[-1], group_size, strategy)
expected_shape = (weight_shape[0], num_groups)

else:
raise ValueError(
f"Unsupported quantization strategy: {strategy}. "
f"Supported strategies: TENSOR, CHANNEL, GROUP, TENSOR_GROUP"
)

return num_groups, expected_shape


def _get_dtype_eps(dtype: torch.dtype) -> float:
if dtype == FP8_E4M3_DATA.dtype:
return 0.125
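
A hedged usage sketch of the new helper; the QuantizationArgs keyword fields used here (num_bits, symmetric, strategy, group_size) are assumed from how they appear elsewhere in this diff, not from the full constructor signature:

import torch
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.utils import calculate_qparam_shape

weight_shape = torch.Size((4096, 11008))

# group quantization: one scale per 128-column group
group_args = QuantizationArgs(num_bits=4, symmetric=False, strategy="group", group_size=128)
num_groups, scale_shape = calculate_qparam_shape(weight_shape, group_args)
# num_groups == 86, scale_shape == (4096, 86)  since 11008 / 128 == 86

# channel quantization: one scale per output row
channel_args = QuantizationArgs(num_bits=8, strategy="channel")
num_groups, scale_shape = calculate_qparam_shape(weight_shape, channel_args)
# num_groups == 1, scale_shape == (4096, 1)
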