
Commit 9f8dc8a

remove unnecessary helpers

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent: ae98316

3 files changed: +28 -72 lines

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
Lines changed: 11 additions & 11 deletions

@@ -22,7 +22,7 @@
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
-from compressed_tensors.quantization.utils import calculate_qparam_shape, can_quantize
+from compressed_tensors.quantization.utils import can_quantize
 from torch import Tensor


@@ -64,7 +64,6 @@ def compression_param_info(
     """
     pack_factor = 32 // quantization_args.num_bits
     packed_size = math.ceil(weight_shape[1] / pack_factor)
-    packed_size_zp = math.ceil(weight_shape[0] / pack_factor)
     output = {
         "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
         "weight_shape": (torch.Size((2,)), torch.int32),
@@ -75,17 +74,20 @@ def compression_param_info(
         QuantizationStrategy.GROUP.value,
         QuantizationStrategy.CHANNEL.value,
     ]:
-        # Use centralized calculation for consistency and correctness
-        num_groups, scale_shape = calculate_qparam_shape(
-            weight_shape, quantization_args
+        scale_cols = (
+            1
+            if quantization_args.strategy == QuantizationStrategy.CHANNEL.value
+            else math.ceil(weight_shape[1] / quantization_args.group_size)
+        )
+        output["weight_scale"] = (
+            torch.Size((weight_shape[0], scale_cols)),
+            quantization_args.scale_dtype,
         )
-        output["weight_scale"] = (scale_shape, quantization_args.scale_dtype)

     # Add weight_zero_point for asymmetric quantization
-    # Zero point has same num_groups as scale, but with packed rows
     if not quantization_args.symmetric:
         output["weight_zero_point"] = (
-            torch.Size((packed_size_zp, num_groups)),
+            torch.Size((math.ceil(weight_shape[0] / pack_factor), scale_cols)),
             torch.int32,
         )
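For reference, a minimal sketch of the shape arithmetic the hunk above inlines. The concrete numbers (4-bit weights, group size 128, a 4096 x 11008 weight) are illustrative assumptions, not values from the diff:

import math

num_bits, group_size = 4, 128    # assumed example values
rows, cols = 4096, 11008         # assumed weight_shape

pack_factor = 32 // num_bits                   # 8 quantized values per int32 word
packed_size = math.ceil(cols / pack_factor)    # weight_packed columns: 1376
scale_cols = math.ceil(cols / group_size)      # one scale per group: 86 (1 for CHANNEL)

print((rows, packed_size))                          # weight_packed:     (4096, 1376)
print((rows, scale_cols))                           # weight_scale:      (4096, 86)
print((math.ceil(rows / pack_factor), scale_cols))  # weight_zero_point: (512, 86)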

@@ -201,9 +203,7 @@ def compress_zp(
         QuantizationStrategy.GROUP.value,
         QuantizationStrategy.CHANNEL.value,
     ]:
-        return pack_to_int32(
-            zero_point, quantization_args.num_bits, packed_dim=0
-        ).contiguous()
+        return pack_to_int32(zero_point, quantization_args.num_bits, packed_dim=0)
     return zero_point
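The pack_to_int32 call above packs quantized values into 32-bit words along dim 0 (the hunk also drops a .contiguous() call the commit treats as unnecessary). A simplified toy version of the packing idea, not the library's actual implementation:

import math
import torch

def toy_pack_to_int32(x: torch.Tensor, num_bits: int) -> torch.Tensor:
    # Toy sketch of bit-packing along dim 0; the real pack_to_int32 may differ.
    pack_factor = 32 // num_bits
    n_words = math.ceil(x.shape[0] / pack_factor)
    padded = torch.zeros(n_words * pack_factor, *x.shape[1:], dtype=torch.int32)
    padded[: x.shape[0]] = x.to(torch.int32) & ((1 << num_bits) - 1)
    padded = padded.reshape(n_words, pack_factor, *x.shape[1:])
    packed = torch.zeros(n_words, *x.shape[1:], dtype=torch.int32)
    for i in range(pack_factor):
        packed |= padded[:, i] << (num_bits * i)  # slot each value into its bit field
    return packed

zp = torch.randint(0, 16, (10, 3), dtype=torch.int32)  # 4-bit zero points
print(toy_pack_to_int32(zp, num_bits=4).shape)         # (2, 3): ceil(10 / 8) rows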

src/compressed_tensors/quantization/lifecycle/initialize.py
Lines changed: 17 additions & 16 deletions

@@ -35,7 +35,7 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import calculate_qparam_shape, strategy_cdiv
+from compressed_tensors.quantization.utils import strategy_cdiv
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -198,25 +198,26 @@ def initialize_qparams(
         return

     # 1. Infer expected scale/zp shape
-    if strategy == QuantizationStrategy.TOKEN:
+    if strategy == QuantizationStrategy.TENSOR:
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.TOKEN:
         raise ValueError("Cannot perform static token quantization")

-    elif strategy in (
-        QuantizationStrategy.TENSOR,
-        QuantizationStrategy.CHANNEL,
-        QuantizationStrategy.GROUP,
-        QuantizationStrategy.TENSOR_GROUP,
-    ):
-        # Validate shape requirements
-        if strategy == QuantizationStrategy.CHANNEL and len(observed_shape) < 2:
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 2:
             raise ValueError("Channel quant requires at least 2 observed dimensions")
-        if strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
-            assert quantization_args.group_size is not None
-            if len(observed_shape) < 1:
-                raise ValueError("Group quant requires at least 1 observed dimension")

-        # Use unified helper to calculate expected shape
-        _, expected_shape = calculate_qparam_shape(observed_shape, quantization_args)
+        expected_shape = (observed_shape[-2], 1)
+
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
+        expected_shape = (*observed_shape[:-1], num_groups)

     # initialize activation ordering if applicable
     if actorder == ActivationOrdering.GROUP:
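A hypothetical walk-through of the new branches for a 2-D observed_shape. The concrete values, and the use of math.ceil as a stand-in for strategy_cdiv (which may additionally validate divisibility), are assumptions for illustration:

import math

observed_shape = (4096, 11008)   # assumed example
group_size = 128                 # assumed example

# TENSOR: one scale/zero point for the whole tensor
tensor_shape = (1,)
# CHANNEL: one qparam per output channel, kept 2-D for broadcasting
channel_shape = (observed_shape[-2], 1)                  # (4096, 1)
# GROUP / TENSOR_GROUP: one qparam per group along the last dim
num_groups = math.ceil(observed_shape[-1] / group_size)  # strategy_cdiv stand-in
group_shape = (*observed_shape[:-1], num_groups)         # (4096, 86)

print(tensor_shape, channel_shape, group_shape)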

src/compressed_tensors/quantization/utils/helpers.py
Lines changed: 0 additions & 45 deletions

@@ -53,7 +53,6 @@
     "calculate_qparams",
     "generate_gparam",
     "strategy_cdiv",
-    "calculate_qparam_shape",
 ]

 # target the self_attn layer
@@ -460,50 +459,6 @@ def strategy_cdiv(
     return dividend


-def calculate_qparam_shape(
-    weight_shape: torch.Size,
-    quantization_args: QuantizationArgs,
-) -> Tuple[int, torch.Size]:
-    """
-    Calculate the number of groups and scale/zero_point shape for quantization.
-
-    This centralizes the logic for determining quantization parameter shapes,
-    ensuring consistency with initialize_qparams and avoiding floor division bugs.
-
-    :param weight_shape: shape of the weight tensor to be quantized
-    :param quantization_args: quantization configuration
-    :return: tuple of (num_groups, expected_shape) where:
-        - num_groups: number of quantization groups
-        - expected_shape: shape for scale/zero_point tensors
-          (weight_shape[0], num_groups)
-    """
-    strategy = quantization_args.strategy
-
-    if strategy == QuantizationStrategy.TENSOR:
-        num_groups = 1
-        expected_shape = (1,)
-
-    elif strategy == QuantizationStrategy.CHANNEL:
-        num_groups = 1
-        expected_shape = (weight_shape[0], 1)
-
-    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
-        group_size = quantization_args.group_size
-        if group_size is None:
-            raise ValueError(f"{strategy} quantization requires group_size to be set")
-
-        num_groups = strategy_cdiv(weight_shape[-1], group_size, strategy)
-        expected_shape = (weight_shape[0], num_groups)
-
-    else:
-        raise ValueError(
-            f"Unsupported quantization strategy: {strategy}. "
-            f"Supported strategies: TENSOR, CHANNEL, GROUP, TENSOR_GROUP"
-        )
-
-    return num_groups, expected_shape
-
-
 def _get_dtype_eps(dtype: torch.dtype) -> float:
     if dtype == FP8_E4M3_DATA.dtype:
         return 0.125
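The removed docstring mentions avoiding floor division bugs; the arithmetic at stake is ceiling vs. floor division when the last group is partial. A small illustration (the real strategy_cdiv may additionally validate divisibility per strategy):

import math

cols, group_size = 11013, 128        # 5 leftover columns past the last full group

print(cols // group_size)            # 86 -> floor division drops the partial group
print(math.ceil(cols / group_size))  # 87 -> ceiling division keeps it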
