 
 
 import logging
-import math
-import warnings
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization import (
     FP8_E4M3_DATA,
     ActivationOrdering,
+    DynamicType,
     KVCacheScaleType,
     QuantizationArgs,
     QuantizationMetadata,
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
+from compressed_tensors.quantization.utils import (
+    is_fp4,
+    is_kv_cache_quant_scheme,
+    strategy_cdiv,
+)
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
 __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
+    "initialize_qparams",
 ]
 
 
@@ -69,10 +73,8 @@ def initialize_module_for_quantization(
     :param force_zero_point: whether to force initialization of a zero point for
         symmetric quantization
     """
-    # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
-        # no scheme passed and layer not targeted for quantization - skip
         return
 
     QuantizationMetadata.clear_all_qparams(module)
@@ -82,38 +84,52 @@ def initialize_module_for_quantization(
         _initialize_attn_scales(module)
 
     else:
+        if not isinstance(module, torch.nn.Linear):
+            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")
+
+        # use weight to determine observed shapes and dtype
+        if hasattr(module, "weight"):
+            weight = module.weight
+            assert isinstance(weight, torch.Tensor)
+        else:
+            # Note that a weight is required for both weight and activation
+            # quantization in order to know the dtype of activation scales
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for quantization but "
+                f"has no attribute weight, skipping quantization for {type(module)}"
+            )
+            return
+
         if scheme.input_activations is not None:
-            _initialize_scale_zero_point(
+            initialize_qparams(
                 module,
                 "input",
                 scheme.input_activations,
+                observed_shape=weight.shape[-1:],
+                observed_dtype=weight.dtype,
                 force_zero_point=force_zero_point,
             )
 
         if scheme.weights is not None:
-            if hasattr(module, "weight"):
-                weight_shape = None
-                if isinstance(module, torch.nn.Linear):
-                    weight_shape = module.weight.shape
-                _initialize_scale_zero_point(
-                    module,
-                    "weight",
-                    scheme.weights,
-                    weight_shape=weight_shape,
-                    force_zero_point=force_zero_point,
-                )
-            else:
-                _LOGGER.warning(
-                    f"module type {type(module)} targeted for weight quantization but "
-                    "has no attribute weight, skipping weight quantization "
-                    f"for {type(module)}"
-                )
-
-        if scheme.output_activations is not None:
-            if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations
-                )
+            initialize_qparams(
+                module,
+                "weight",
+                scheme.weights,
+                observed_shape=weight.shape,
+                observed_dtype=weight.dtype,
+                force_zero_point=force_zero_point,
+            )
+
+        output_is_kv_cache = is_kv_cache_quant_scheme(scheme)
+        if scheme.output_activations is not None and not output_is_kv_cache:
+            initialize_qparams(
+                module,
+                "output",
+                scheme.output_activations,
+                observed_shape=weight.shape[:-1],
+                observed_dtype=weight.dtype,
+                force_zero_point=force_zero_point,
+            )
 
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
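
For reference, a minimal sketch of driving the updated initialization path end to end. The `QuantizationScheme`/`QuantizationArgs` field names used below (`targets`, `num_bits`, `strategy`, `group_size`) and the import paths are assumptions based on the library's public API and are not part of this diff:

# Hedged usage sketch; constructor fields and import paths are assumed, not verified.
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle import initialize_module_for_quantization

linear = torch.nn.Linear(4096, 11008)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
)
initialize_module_for_quantization(linear, scheme)

# With the default force_zero_point=True this registers "weight_scale" and
# "weight_zero_point"; for group quantization the scale shape is
# (out_features, ceil(in_features / group_size)) == (11008, 32).
print(linear.weight_scale.shape)
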
@@ -132,22 +148,40 @@ def is_attention_module(module: Module):
     )
 
 
-def _initialize_scale_zero_point(
+def initialize_qparams(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
+    observed_shape: Tuple[int, ...],
+    observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
-    if quantization_args.dynamic is True:
-        return
+    """
+    Initialize quantization parameters for a given base name according to the passed
+    quantization args. The shape and dtype of the observed weight/activation must also
+    be provided.
+
+    Scales will always be initialized. Global scales are initialized depending on args.
+    Zero points will be initialized if not symmetric or if `force_zero_point` is True.
+
+    :param module: module to register qparams to
+    :param base_name: base name of qparams, for example "input", "weight", "k", "v"
+    :param quantization_args: arguments for quantization
+    :param observed_shape: last (right-most) known dimensions of the observed weight/act
+    :param observed_dtype: dtype of the observed weight/act
+    :param force_zero_point: force the zero_point parameter to be initialized
+    """
+    strategy = quantization_args.strategy
+    dynamic = quantization_args.dynamic
+    actorder = quantization_args.actorder
+    device = get_execution_device(module)  # avoid performing initialization ops on cpu
 
-    # initialize on execution device to avoid performing quantized ops on cpu
-    device = get_execution_device(module)
+    # Skip all initialization for fully dynamic quantization
+    if dynamic is True:
+        return
 
-    # 1. Create global_scales for tensor_group - generates
-    # a per tensor scale
-    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    # 0. Create global scale for tensor-group quantization
+    if strategy == QuantizationStrategy.TENSOR_GROUP:
         init_global_scale = Parameter(
             torch.empty(1, dtype=torch.float32, device=device),
             requires_grad=False,
@@ -156,56 +190,55 @@ def _initialize_scale_zero_point(
             module, f"{base_name}_global_scale", init_global_scale
         )
 
-    # 2. Infer expected scale/zero point shape
-    if quantization_args.strategy == QuantizationStrategy.TOKEN:
+    # Skip scale/zp initialization for locally dynamic quantization
+    if dynamic == DynamicType.LOCAL:
+        return
+
+    # 1. Infer expected scale/zp shape
+    if strategy == QuantizationStrategy.TENSOR:
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.TOKEN:
         expected_shape = (1, 1)
+
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 2:
+            raise ValueError("Channel quant requires at least 2 observed dimensions")
+
+        expected_shape = (observed_shape[-2], 1)
+
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
+        expected_shape = (*observed_shape[:-1], num_groups)
+
+        # initialize activation ordering if applicable
+        if actorder == ActivationOrdering.GROUP:
+            init_g_idx = Parameter(
+                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
+                requires_grad=False,
+            )
+            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
+
+    elif strategy == QuantizationStrategy.BLOCK:
+        assert quantization_args.block_structure is not None
+        if len(observed_shape) < 2:
+            raise ValueError("Block quant requires at least 2 observed dimensions")
+
+        block_structure = quantization_args.block_structure
+        num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
+        expected_shape = (num_rows, num_cols)
+
     else:
-        expected_shape = 1
-
-    if base_name == "weight" and weight_shape is not None:
-        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1) - only for weights
-            expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy in (
-            QuantizationStrategy.TENSOR_GROUP,
-            QuantizationStrategy.GROUP,
-        ):
-            # GROUP/TENSOR_GROUP for both weights and activations
-            num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
-            expected_shape = (weight_shape[0], max(num_groups, 1))
-        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # For block quantization, scale shape should match number of blocks - only
-            # for weights
-            if quantization_args.block_structure is None:
-                raise ValueError(
-                    "Block quantization requires block_structure to be specified"
-                )
-            block_height, block_width = quantization_args.block_structure
-            rows, cols = weight_shape[-2], weight_shape[-1]
-            num_rows_blocks = math.ceil(rows / block_height)
-            num_cols_blocks = math.ceil(cols / block_width)
-
-            # Warn if dimensions don't divide evenly
-            if rows % block_height != 0 or cols % block_width != 0:
-                warnings.warn(
-                    f"Block quantization: tensor shape {weight_shape} does not divide"
-                    f"evenly by block structure {quantization_args.block_structure}. "
-                    f"Some blocks will be incomplete which may affect quantization"
-                    "quality.",
-                    UserWarning,
-                )
-
-            expected_shape = (num_rows_blocks, num_cols_blocks)
-    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        warnings.warn(
-            f"BLOCK quantization not supported for {base_name} activations. "
-            f"Falling back to tensor-level quantization.",
-            UserWarning,
-        )
-        expected_shape = 1
+        assert False, f"Unknown strategy {strategy}"
 
-    # 3. Identify quantization scale and zp dtype
-    scale_dtype = module.weight.dtype
+    # 2. Identify quantization scale and zp dtype
+    scale_dtype = observed_dtype
 
     if is_fp4(quantization_args=quantization_args):
         scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
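
To make the new shape inference above concrete, here is a rough worked example of the expected scale shapes the branches produce for a (256, 512) observed weight. It assumes `strategy_cdiv` behaves like a ceiling division (possibly with strategy-specific divisibility checks); that helper is defined outside this diff:

# Illustration of the expected_shape branches above, assuming strategy_cdiv ~= ceil division.
import math

observed_shape = (256, 512)  # e.g. a Linear weight: (out_features, in_features)

# CHANNEL: one scale per output channel
channel_shape = (observed_shape[-2], 1)  # (256, 1)

# GROUP / TENSOR_GROUP with group_size=128: groups along the last dim
group_shape = (*observed_shape[:-1], math.ceil(observed_shape[-1] / 128))  # (256, 4)

# BLOCK with block_structure=[64, 128]: one scale per 64x128 tile
block_shape = (
    math.ceil(observed_shape[-2] / 64),
    math.ceil(observed_shape[-1] / 128),
)  # (4, 4)

print(channel_shape, group_shape, block_shape)  # (256, 1) (256, 4) (4, 4)
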
@@ -221,14 +254,12 @@
             scale_dtype = torch.bfloat16
         zp_dtype = quantization_args.pytorch_dtype()
 
-    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
-    # do not init scales for quantzation_args.dynamic == DynamicType.local
-    if not quantization_args.dynamic:
-        init_scale = Parameter(
-            torch.empty(expected_shape, dtype=scale_dtype, device=device),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    # 3. Initializes scale/zp for the module
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
@@ -237,16 +268,6 @@
         )
         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
 
-    # only grouped activation ordering has g_idx
-    if quantization_args.actorder == ActivationOrdering.GROUP:
-        g_idx_shape = (weight_shape[1],)
-        g_idx_dtype = torch.int
-        init_g_idx = Parameter(
-            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
-
 
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
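
Since `initialize_qparams` is now exported via `__all__`, it can also be called directly. Below is a hedged sketch of registering weight qparams on a bare module; the `QuantizationArgs` field names and the lifecycle import path are assumptions, as above:

# Sketch of calling the newly public helper directly; field names are assumptions.
import torch
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle import initialize_qparams

module = torch.nn.Linear(512, 256)
args = QuantizationArgs(num_bits=4, strategy="group", group_size=128, symmetric=True)
initialize_qparams(
    module,
    "weight",
    args,
    observed_shape=module.weight.shape,
    observed_dtype=module.weight.dtype,
    force_zero_point=False,
)

# Registers "weight_scale" with shape (256, 512 // 128) == (256, 4); no zero point is
# created because the args are symmetric and force_zero_point is False.
print(module.weight_scale.shape)
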