Skip to content

Commit c6e2d4b

Browse files
committed
use helper in initialize
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
1 parent f9f3105 commit c6e2d4b

File tree

2 files changed: +22 additions, -25 deletions

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -198,27 +198,27 @@ def initialize_qparams(
198198
return
199199

200200
# 1. Infer expected scale/zp shape
201-
if strategy == QuantizationStrategy.TENSOR:
202-
expected_shape = (1,)
203-
204-
elif strategy == QuantizationStrategy.TOKEN:
201+
if strategy == QuantizationStrategy.TOKEN:
205202
raise ValueError("Cannot perform static token quantization")
206203

207-
elif strategy == QuantizationStrategy.CHANNEL:
208-
if len(observed_shape) < 2:
204+
elif strategy in (
205+
QuantizationStrategy.TENSOR,
206+
QuantizationStrategy.CHANNEL,
207+
QuantizationStrategy.GROUP,
208+
QuantizationStrategy.TENSOR_GROUP,
209+
):
210+
# Validate shape requirements
211+
if strategy == QuantizationStrategy.CHANNEL and len(observed_shape) < 2:
209212
raise ValueError("Channel quant requires at least 2 observed dimensions")
210-
211-
expected_shape = (observed_shape[-2], 1)
212-
213-
elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
214-
assert quantization_args.group_size is not None
215-
if len(observed_shape) < 1:
216-
raise ValueError("Group quant requires at least 1 observed dimension")
213+
if strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
214+
assert quantization_args.group_size is not None
215+
if len(observed_shape) < 1:
216+
raise ValueError("Group quant requires at least 1 observed dimension")
217217

218218
# Use shared calculation to avoid floor division bugs
219-
_, expected_shape = calculate_qparam_shape(
220-
torch.Size(observed_shape), quantization_args
221-
)
219+
# Note: observed_shape may contain None for dynamic dimensions (e.g., sequence length)
220+
# but calculate_qparam_shape only accesses specific indices that are concrete
221+
_, expected_shape = calculate_qparam_shape(observed_shape, quantization_args)
222222

223223
# initialize activation ordering if applicable
224224
if actorder == ActivationOrdering.GROUP:

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,9 @@ def calculate_qparams(
127127
# 5. Update any 0s with small values to
128128
# prevent div by 0
129129
eps = _get_dtype_eps(
130-
dtype=(
131-
quantization_args.scale_dtype
132-
if quantization_args.scale_dtype is not None
133-
else scales.dtype
134-
)
130+
dtype=quantization_args.scale_dtype
131+
if quantization_args.scale_dtype is not None
132+
else scales.dtype
135133
)
136134
scales = torch.where(
137135
scales == 0,
@@ -483,20 +481,19 @@ def calculate_qparam_shape(
483481

484482
if strategy == QuantizationStrategy.TENSOR:
485483
num_groups = 1
486-
expected_shape = torch.Size((1,))
484+
expected_shape = (1,)
487485

488486
elif strategy == QuantizationStrategy.CHANNEL:
489487
num_groups = 1
490-
expected_shape = torch.Size((weight_shape[0], 1))
488+
expected_shape = (weight_shape[0], 1)
491489

492490
elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
493491
group_size = quantization_args.group_size
494492
if group_size is None:
495493
raise ValueError(f"{strategy} quantization requires group_size to be set")
496494

497-
# Use strategy_cdiv for proper ceiling division and validation
498495
num_groups = strategy_cdiv(weight_shape[-1], group_size, strategy)
499-
expected_shape = torch.Size((weight_shape[0], num_groups))
496+
expected_shape = (weight_shape[0], num_groups)
500497

501498
else:
502499
raise ValueError(

0 commit comments

Comments (0)