@@ -198,27 +198,27 @@ def initialize_qparams(
198198 return
199199
200200 # 1. Infer expected scale/zp shape
201- if strategy == QuantizationStrategy .TENSOR :
202- expected_shape = (1 ,)
203-
204- elif strategy == QuantizationStrategy .TOKEN :
201+ if strategy == QuantizationStrategy .TOKEN :
205202 raise ValueError ("Cannot perform static token quantization" )
206203
207- elif strategy == QuantizationStrategy .CHANNEL :
208- if len (observed_shape ) < 2 :
204+ elif strategy in (
205+ QuantizationStrategy .TENSOR ,
206+ QuantizationStrategy .CHANNEL ,
207+ QuantizationStrategy .GROUP ,
208+ QuantizationStrategy .TENSOR_GROUP ,
209+ ):
210+ # Validate shape requirements
211+ if strategy == QuantizationStrategy .CHANNEL and len (observed_shape ) < 2 :
209212 raise ValueError ("Channel quant requires at least 2 observed dimensions" )
210-
211- expected_shape = (observed_shape [- 2 ], 1 )
212-
213- elif strategy in (QuantizationStrategy .GROUP , QuantizationStrategy .TENSOR_GROUP ):
214- assert quantization_args .group_size is not None
215- if len (observed_shape ) < 1 :
216- raise ValueError ("Group quant requires at least 1 observed dimension" )
213+ if strategy in (QuantizationStrategy .GROUP , QuantizationStrategy .TENSOR_GROUP ):
214+ assert quantization_args .group_size is not None
215+ if len (observed_shape ) < 1 :
216+ raise ValueError ("Group quant requires at least 1 observed dimension" )
217217
218218 # Use shared calculation to avoid floor division bugs
219- _ , expected_shape = calculate_qparam_shape (
220- torch . Size ( observed_shape ), quantization_args
221- )
219+ # Note: observed_shape may contain None for dynamic dimensions (e.g., sequence length)
220+ # but calculate_qparam_shape only accesses specific indices that are concrete
221+ _ , expected_shape = calculate_qparam_shape ( observed_shape , quantization_args )
222222
223223 # initialize activation ordering if applicable
224224 if actorder == ActivationOrdering .GROUP :
0 commit comments