From 121145592d1ed2b15b2582490d81b44bea4955a2 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 28 Oct 2025 20:56:55 +0000 Subject: [PATCH 01/31] update --- .../quantization/lifecycle/initialize.py | 25 +++++----- .../quantization/quant_args.py | 4 +- .../quantization/quant_scheme.py | 7 +++ .../quantization/utils/helpers.py | 33 ++++--------- .../test_quantization/lifecycle/test_apply.py | 47 ++++++++++++++++--- .../lifecycle/test_initialize.py | 16 ++++++- .../lifecycle/test_static_lifecycle.py | 24 ++-------- 7 files changed, 90 insertions(+), 66 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 4bd75a2b3..4e751ae34 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -36,7 +36,7 @@ from compressed_tensors.quantization.lifecycle.forward import ( wrap_module_forward_quantized, ) -from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv +from compressed_tensors.quantization.utils import strategy_cdiv from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, @@ -249,32 +249,31 @@ def initialize_qparams( assert False, f"Unknown strategy {strategy}" # 2. Identify quantization scale and zp dtype - scale_dtype = observed_dtype - - if is_fp4(quantization_args=quantization_args): - scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype - else: - # TODO: consider erroring out in the future as if the dtype if not one of these, - # there is likely bug - if scale_dtype not in [ + if quantization_args.scale_dtype is None: + if observed_dtype not in [ torch.float16, torch.bfloat16, torch.float32, torch.float64, ]: - scale_dtype = torch.bfloat16 - zp_dtype = quantization_args.pytorch_dtype() + observed_dtype = torch.float16 + quantization_args.scale_dtype = observed_dtype + + if quantization_args.zp_dtype is None: + quantization_args.zp_dtype = quantization_args.pytorch_dtype() # 3. 
Initializes scale/zp for the module init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), + torch.empty(expected_shape, dtype=quantization_args.scale_dtype, device=device), requires_grad=False, ) register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: init_zero_point = Parameter( - torch.zeros(expected_shape, device=device, dtype=zp_dtype), + torch.zeros( + expected_shape, device=device, dtype=quantization_args.zp_dtype + ), requires_grad=False, ) register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 8bfbc41f2..504fd3f28 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -174,6 +174,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True): block_structure: Optional[List[int]] = None dynamic: Union[DynamicType, bool] = False actorder: Union[ActivationOrdering, bool, None] = None + scale_dtype: Optional[torch.dtype] = None + zp_dtype: Optional[torch.dtype] = None observer: Optional[str] = Field( default=None, description=( @@ -378,7 +380,7 @@ def pytorch_dtype(self) -> torch.dtype: def get_observer(self) -> str: return self.observer - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) def round_to_quantized_type( diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 79db8d28a..d31e133e8 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -18,6 +18,7 @@ from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_args import ( + FP8_E4M3_DATA, DynamicType, QuantizationArgs, QuantizationStrategy, @@ -160,6 +161,8 @@ def is_preset_scheme(name: str) -> bool: symmetric=True, dynamic=False, group_size=16, + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, ) ) @@ -173,6 +176,8 @@ def is_preset_scheme(name: str) -> bool: dynamic=False, group_size=16, observer="static_minmax", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, ), input_activations=QuantizationArgs( num_bits=4, @@ -182,6 +187,8 @@ def is_preset_scheme(name: str) -> bool: dynamic=DynamicType.LOCAL, group_size=16, observer="static_minmax", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, ), ) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 0099b088b..1bbe573b2 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -46,7 +46,6 @@ "calculate_range", "calculate_qparams", "generate_gparam", - "is_fp4", "strategy_cdiv", ] @@ -57,13 +56,6 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -def is_fp4(quantization_args: QuantizationArgs): - return ( - quantization_args.num_bits == 4 - and quantization_args.type == QuantizationType.FLOAT - ) - - def calculate_qparams( min_vals: Tensor, max_vals: Tensor, @@ -92,22 +84,20 @@ def calculate_qparams( bit_min, bit_max = calculate_range(quantization_args, device) bit_range = bit_max - bit_min - if is_fp4(quantization_args=quantization_args): - zp_dtype = FP8_E4M3_DATA.dtype - else: - zp_dtype = quantization_args.pytorch_dtype() + 
zp_dtype = quantization_args.zp_dtype if quantization_args.symmetric: max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) + scales = max_val_pos / (float(bit_range) / 2) - if is_fp4(quantization_args=quantization_args) and global_scale is not None: + if global_scale is not None: # Conditionally scale the generated local scale by a global_scale - scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max) - scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min) - scales = scales.to(FP8_E4M3_DATA.dtype) - - else: - scales = max_val_pos / (float(bit_range) / 2) + scales = torch.clamp( + scales, + max=torch.finfo(quantization_args.scale_dtype).max, + min=torch.finfo(quantization_args.scale_dtype).min, + ) + scales = scales.to(quantization_args.scale_dtype) # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped if scales.dtype == FP8_E4M3_DATA.dtype: @@ -123,11 +113,6 @@ def calculate_qparams( zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: - if is_fp4(quantization_args=quantization_args): - raise NotImplementedError( - "Asymmetric Quantization is not supported for FP4" - ) - scales = (max_vals - min_vals) / float(bit_range) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = bit_min - (min_vals / scales) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 8d57ca403..caf679781 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -22,6 +22,7 @@ from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_METHOD, + FP8_E4M3_DATA, QuantizationArgs, QuantizationConfig, QuantizationScheme, @@ -153,7 +154,11 @@ def test_apply_quantization_config_tinyllama(): "linear": QuantizationScheme( targets=["Linear"], input_activations=QuantizationArgs( - num_bits=8, type="float", strategy="tensor" + num_bits=8, + type="float", + strategy="tensor", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=torch.float, ), ) } @@ -163,7 +168,11 @@ def test_apply_quantization_config_tinyllama(): "linear": QuantizationScheme( targets=["Linear"], input_activations=QuantizationArgs( - num_bits=8, type="float", strategy="tensor" + num_bits=8, + type="float", + strategy="tensor", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=torch.float, ), ) }, @@ -176,7 +185,11 @@ def test_apply_quantization_config_tinyllama(): QuantizationConfig( config_groups={}, kv_cache_scheme=QuantizationArgs( - num_bits=8, type="float", strategy="tensor" + num_bits=8, + type="float", + strategy="tensor", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=torch.float, ), ), QuantizationConfig( @@ -184,12 +197,20 @@ def test_apply_quantization_config_tinyllama(): "attention": QuantizationScheme( targets=["LlamaAttention"], input_activations=QuantizationArgs( - num_bits=8, type="float", strategy="tensor" + num_bits=8, + type="float", + strategy="tensor", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=torch.float, ), ) }, kv_cache_scheme=QuantizationArgs( - num_bits=8, type="float", strategy="tensor" + num_bits=8, + type="float", + strategy="tensor", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=torch.float, ), ), ], @@ -448,7 +469,13 @@ def test_apply_kv_cache(): with init_empty_weights(): model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M") - args = QuantizationArgs(num_bits=8, type="float", strategy="tensor") + args = 
QuantizationArgs( + num_bits=8, + type="float", + strategy="tensor", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=torch.float, + ) config = QuantizationConfig(config_groups={}, kv_cache_scheme=args) apply_quantization_config(model, config) @@ -468,7 +495,13 @@ def test_apply_attention(): scheme = QuantizationScheme( targets=["LlamaAttention"], - input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"), + input_activations=QuantizationArgs( + num_bits=8, + type="float", + strategy="tensor", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=torch.float, + ), ) config = QuantizationConfig(config_groups={"attention": scheme}) diff --git a/tests/test_quantization/lifecycle/test_initialize.py b/tests/test_quantization/lifecycle/test_initialize.py index f6db66971..cbeed83e7 100644 --- a/tests/test_quantization/lifecycle/test_initialize.py +++ b/tests/test_quantization/lifecycle/test_initialize.py @@ -156,13 +156,23 @@ def test_initialize_module_for_quantization_offloaded( ), ( QuantizationArgs( - strategy="tensor_group", group_size=16, type="float", num_bits=4 + strategy="tensor_group", + group_size=16, + type="float", + num_bits=4, + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, ), None, ), ( QuantizationArgs( - strategy="tensor_group", group_size=16, type="float", num_bits=4 + strategy="tensor_group", + group_size=16, + type="float", + num_bits=4, + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, ), QuantizationArgs( strategy="tensor_group", @@ -170,6 +180,8 @@ def test_initialize_module_for_quantization_offloaded( type="float", num_bits=4, dynamic="local", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, ), ), ( diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py b/tests/test_quantization/lifecycle/test_static_lifecycle.py index 36a857d13..bf007eef5 100644 --- a/tests/test_quantization/lifecycle/test_static_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py @@ -15,6 +15,7 @@ import pytest import torch from compressed_tensors.quantization import ( + FP8_E4M3_DATA, QuantizationScheme, forward_quantize, initialize_module_for_quantization, @@ -96,6 +97,8 @@ symmetric=True, strategy="tensor_group", # requires float4 group_size=3, + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, ), torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]), torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]), @@ -176,25 +179,6 @@ def test_static_weight_quantization( @pytest.mark.parametrize( "args,exp_min_val,exp_max_val,exp_quant,exp_loss", [ - ( - QuantizationArgs( - num_bits=4, - type="int", - symmetric=True, - strategy="tensor", - ), - torch.tensor([0.0]), - torch.tensor([11.0]), - torch.tensor( - [ - [ - [0.0000, 1.4688, 1.4688, 2.9375, 4.4062, 4.4062], - [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500], - ] - ] - ), - 0.2, - ), # static token is not supported # channel is not supported # group is not supported @@ -206,6 +190,8 @@ def test_static_weight_quantization( strategy="tensor_group", dynamic="local", group_size=3, + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, ), None, None, From 41aa0fc856687552ab52d8b4f9fb50f841216511 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 28 Oct 2025 20:59:48 +0000 Subject: [PATCH 02/31] add back test --- .../lifecycle/test_static_lifecycle.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_quantization/lifecycle/test_static_lifecycle.py 
b/tests/test_quantization/lifecycle/test_static_lifecycle.py index bf007eef5..fa38c8ed1 100644 --- a/tests/test_quantization/lifecycle/test_static_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_static_lifecycle.py @@ -179,6 +179,25 @@ def test_static_weight_quantization( @pytest.mark.parametrize( "args,exp_min_val,exp_max_val,exp_quant,exp_loss", [ + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="tensor", + ), + torch.tensor([0.0]), + torch.tensor([11.0]), + torch.tensor( + [ + [ + [0.0000, 1.4688, 1.4688, 2.9375, 4.4062, 4.4062], + [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500], + ] + ] + ), + 0.2, + ), # static token is not supported # channel is not supported # group is not supported From de9f16a14131abf6e969022d768bffb662f528da Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 28 Oct 2025 21:02:43 +0000 Subject: [PATCH 03/31] update --- src/compressed_tensors/quantization/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 1bbe573b2..561b3ef42 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -92,7 +92,7 @@ def calculate_qparams( if global_scale is not None: # Conditionally scale the generated local scale by a global_scale - scales = torch.clamp( + scales = global_scale * torch.clamp( scales, max=torch.finfo(quantization_args.scale_dtype).max, min=torch.finfo(quantization_args.scale_dtype).min, From c02000de659d7019e33d7a9639a7b9464090b99b Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 28 Oct 2025 21:46:30 +0000 Subject: [PATCH 04/31] update --- .../model_compressors/model_compressor.py | 1 + .../quantization/quant_config.py | 24 ++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index fde3e1954..9f141728c 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -790,6 +790,7 @@ def update_config(self, save_directory: str): config_data = {} # serialize configs into json + breakpoint() qconfig_data = ( self.quantization_config.model_dump(exclude=["quant_method"]) if self.quantization_config is not None diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index bed6078fa..08e2ac601 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-
+import torch
 from collections import defaultdict
 from enum import Enum
 from typing import Annotated, Any, Dict, List, Optional, Set, Union
@@ -279,5 +279,27 @@ def requires_calibration_data(self):
 
         return False
 
+    def model_dump(self, *args, **kwargs):
+        # Call the parent dump first
+        data = super().model_dump(*args, **kwargs)
+
+        # Convert any torch.dtype to string
+        schemes = ["config_groups", "kv_cache_scheme"]
+        for scheme in schemes:
+            if data.get(scheme) is not None:
+                for _, v in data[scheme].items():
+                    weight = v.get("weights")
+                    input = v.get("input_activations")
+                    output = v.get("output_activations")
+
+                    args = [weight, input, output]
+                    for arg in args:
+                        for key, value in arg.items():
+                            if isinstance(value, torch.dtype):
+                                data[key] = str(value).replace("torch.", "")
+
+        breakpoint()
+        return data
+
     # TODO set `extra="forbid"` when upstream transformers is compatible
     model_config = ConfigDict(extra="ignore")

From fbccd400e80c975ff37ffd1c0dff9ee69cd9a5b5 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 29 Oct 2025 19:43:55 +0000
Subject: [PATCH 05/31] fix serialization

---
 .../model_compressors/model_compressor.py |  1 -
 .../quantization/quant_config.py           | 21 ++++++++++++-------
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 9f141728c..fde3e1954 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -790,7 +790,6 @@ def update_config(self, save_directory: str):
             config_data = {}
 
         # serialize configs into json
-        breakpoint()
         qconfig_data = (
             self.quantization_config.model_dump(exclude=["quant_method"])
             if self.quantization_config is not None
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 08e2ac601..309fbe9dd 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch from collections import defaultdict from enum import Enum from typing import Annotated, Any, Dict, List, Optional, Set, Union +import torch from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs from compressed_tensors.quantization.quant_scheme import ( @@ -283,7 +283,17 @@ def model_dump(self, *args, **kwargs): # Call the parent dump first data = super().model_dump(*args, **kwargs) - # Convert any torch.dtype to string + def _convert_dtypes_in_dict(d): + for k, v in d.items(): + if isinstance(v, torch.dtype): + if k == "zp_dtype" and d.get("symmetric"): + d[k] = None + else: + d[k] = str(v).replace("torch.", "") + elif isinstance(v, dict): + _convert_dtypes_in_dict(v) + return d + schemes = ["config_groups", "kv_cache_scheme"] for scheme in schemes: if data.get(scheme) is not None: @@ -294,11 +304,8 @@ def model_dump(self, *args, **kwargs): args = [weight, input, output] for arg in args: - for key, value in arg.items(): - if isinstance(value, torch.dtype): - data[key] = str(value).replace("torch.", "") - - breakpoint() + if arg is not None: + _convert_dtypes_in_dict(arg) return data # TODO set `extra="forbid"` when upstream transformers is compatible From 2a2f2a325b79b0e746d29db15c0db9c459b4dfe5 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 29 Oct 2025 20:04:54 +0000 Subject: [PATCH 06/31] fix condition --- src/compressed_tensors/quantization/quant_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 309fbe9dd..08dfd0957 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -286,7 +286,9 @@ def model_dump(self, *args, **kwargs): def _convert_dtypes_in_dict(d): for k, v in d.items(): if isinstance(v, torch.dtype): - if k == "zp_dtype" and d.get("symmetric"): + if (k == "zp_dtype" and d.get("symmetric")) or ( + k == "scale_dtype" and d.get("dynamic") in (True, "local") + ): d[k] = None else: d[k] = str(v).replace("torch.", "") From cbd6d6614377c077d3ba6730d8a343d75239801a Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 29 Oct 2025 20:21:08 +0000 Subject: [PATCH 07/31] update --- src/compressed_tensors/quantization/lifecycle/initialize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 4e751ae34..cb3c1254b 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -24,7 +24,6 @@ QuantizedKVCache, ) from compressed_tensors.quantization import ( - FP8_E4M3_DATA, ActivationOrdering, DynamicType, QuantizationArgs, From 6fca61f6b1d9bc0c4ae33f495544a4b07e356102 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 29 Oct 2025 20:37:05 +0000 Subject: [PATCH 08/31] update --- .../quantization/utils/helpers.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 561b3ef42..5a6812ea9 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -92,12 +92,14 @@ def calculate_qparams( if global_scale is not None: # Conditionally scale the generated local scale by a 
global_scale - scales = global_scale * torch.clamp( - scales, - max=torch.finfo(quantization_args.scale_dtype).max, - min=torch.finfo(quantization_args.scale_dtype).min, - ) - scales = scales.to(quantization_args.scale_dtype) + scales = global_scale * scales + if quantization_args.scale_dtype is not None: + scales = torch.clamp( + scales, + max=torch.finfo(quantization_args.scale_dtype).max, + min=torch.finfo(quantization_args.scale_dtype).min, + ) + scales = scales.to(quantization_args.scale_dtype) # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped if scales.dtype == FP8_E4M3_DATA.dtype: From e53bf78374140f45a0c404b7b12359964493bd0d Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 29 Oct 2025 20:40:27 +0000 Subject: [PATCH 09/31] update --- .../quantization/utils/helpers.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 5a6812ea9..d22d642db 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -93,13 +93,14 @@ def calculate_qparams( if global_scale is not None: # Conditionally scale the generated local scale by a global_scale scales = global_scale * scales - if quantization_args.scale_dtype is not None: - scales = torch.clamp( - scales, - max=torch.finfo(quantization_args.scale_dtype).max, - min=torch.finfo(quantization_args.scale_dtype).min, - ) - scales = scales.to(quantization_args.scale_dtype) + + if quantization_args.scale_dtype is not None: + scales = torch.clamp( + scales, + max=torch.finfo(quantization_args.scale_dtype).max, + min=torch.finfo(quantization_args.scale_dtype).min, + ) + scales = scales.to(quantization_args.scale_dtype) # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped if scales.dtype == FP8_E4M3_DATA.dtype: From dec2b2c3d04615de69264a35df8a470ca63bf0b2 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 31 Oct 2025 02:36:58 +0000 Subject: [PATCH 10/31] update --- .../quantization/quant_config.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 08dfd0957..987344f63 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -296,18 +296,23 @@ def _convert_dtypes_in_dict(d): _convert_dtypes_in_dict(v) return d - schemes = ["config_groups", "kv_cache_scheme"] - for scheme in schemes: - if data.get(scheme) is not None: - for _, v in data[scheme].items(): - weight = v.get("weights") - input = v.get("input_activations") - output = v.get("output_activations") - - args = [weight, input, output] - for arg in args: - if arg is not None: - _convert_dtypes_in_dict(arg) + scheme = "config_groups" + if data.get(scheme): + for _, v in data[scheme].items(): + weight = v.get("weights") + input = v.get("input_activations") + output = v.get("output_activations") + + args = [weight, input, output] + for arg in args: + if arg is not None: + _convert_dtypes_in_dict(arg) + + scheme = "kv_cache_scheme" + kv_cache_data = data.get(scheme) + if kv_cache_data: + _convert_dtypes_in_dict(kv_cache_data) + return data # TODO set `extra="forbid"` when upstream transformers is compatible From 8b7181c37b559bc9accf26c3d0a3f26af6825a40 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 3 Nov 2025 
13:27:39 -0500 Subject: [PATCH 11/31] update --- src/compressed_tensors/quantization/quant_args.py | 7 ++++--- src/compressed_tensors/quantization/quant_config.py | 11 ++++------- src/compressed_tensors/quantization/utils/helpers.py | 11 ++++++++--- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 504fd3f28..b8179ccee 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -19,6 +19,7 @@ import torch from compressed_tensors.utils import Aliasable from compressed_tensors.utils.helpers import deprecated +from compressed_tensors.utils.type import TorchDtype from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator @@ -174,8 +175,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True): block_structure: Optional[List[int]] = None dynamic: Union[DynamicType, bool] = False actorder: Union[ActivationOrdering, bool, None] = None - scale_dtype: Optional[torch.dtype] = None - zp_dtype: Optional[torch.dtype] = None + scale_dtype: Optional[TorchDtype] = None + zp_dtype: Optional[TorchDtype] = None observer: Optional[str] = Field( default=None, description=( @@ -380,7 +381,7 @@ def pytorch_dtype(self) -> torch.dtype: def get_observer(self) -> str: return self.observer - model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + model_config = ConfigDict(extra="forbid") def round_to_quantized_type( diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 987344f63..14fe4da43 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -285,13 +285,10 @@ def model_dump(self, *args, **kwargs): def _convert_dtypes_in_dict(d): for k, v in d.items(): - if isinstance(v, torch.dtype): - if (k == "zp_dtype" and d.get("symmetric")) or ( - k == "scale_dtype" and d.get("dynamic") in (True, "local") - ): - d[k] = None - else: - d[k] = str(v).replace("torch.", "") + if (k == "zp_dtype" and d.get("symmetric")) or ( + k == "scale_dtype" and d.get("dynamic") in (True, "local") + ): + d[k] = None elif isinstance(v, dict): _convert_dtypes_in_dict(v) return d diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index d22d642db..512728b1d 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -84,8 +84,6 @@ def calculate_qparams( bit_min, bit_max = calculate_range(quantization_args, device) bit_range = bit_max - bit_min - zp_dtype = quantization_args.zp_dtype - if quantization_args.symmetric: max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) scales = max_val_pos / (float(bit_range) / 2) @@ -116,6 +114,13 @@ def calculate_qparams( zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: + if ( + quantization_args.num_bits == 4 + and quantization_args.type == QuantizationType.FLOAT + ): + raise NotImplementedError( + "Asymmetric Quantization is not supported for FP4" + ) scales = (max_vals - min_vals) / float(bit_range) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = bit_min - (min_vals / scales) @@ -125,7 +130,7 @@ def calculate_qparams( # if casting to int, use round instead of truncate if quantization_args.type == QuantizationType.INT: zero_points = 
torch.round(zero_points) - zero_points = zero_points.to(zp_dtype) + zero_points = zero_points.to(quantization_args.zp_dtype) if scales.ndim == 0: scales = scales.reshape(1) From 9bd90401bdf6a5bda82c8595b559eefe596f99ff Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 3 Nov 2025 13:36:03 -0500 Subject: [PATCH 12/31] remove torch --- src/compressed_tensors/quantization/quant_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 14fe4da43..f856f9783 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -15,7 +15,6 @@ from enum import Enum from typing import Annotated, Any, Dict, List, Optional, Set, Union -import torch from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs from compressed_tensors.quantization.quant_scheme import ( From ecb7d7f735850b52659ef57bc0d3b15fa5486679 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 3 Nov 2025 13:51:34 -0500 Subject: [PATCH 13/31] update --- .../quantization/lifecycle/initialize.py | 16 +--------------- .../quantization/quant_args.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index cb3c1254b..00369de26 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -247,21 +247,7 @@ def initialize_qparams( else: assert False, f"Unknown strategy {strategy}" - # 2. Identify quantization scale and zp dtype - if quantization_args.scale_dtype is None: - if observed_dtype not in [ - torch.float16, - torch.bfloat16, - torch.float32, - torch.float64, - ]: - observed_dtype = torch.float16 - quantization_args.scale_dtype = observed_dtype - - if quantization_args.zp_dtype is None: - quantization_args.zp_dtype = quantization_args.pytorch_dtype() - - # 3. Initializes scale/zp for the module + # 2. Initializes scale/zp for the module init_scale = Parameter( torch.empty(expected_shape, dtype=quantization_args.scale_dtype, device=device), requires_grad=False, diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index b8179ccee..ef81406cb 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -269,6 +269,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": dynamic = model.dynamic observer = model.observer dynamic = model.dynamic + scale_dtype = model.scale_dtype + zp_dtype = model.zp_dtype # infer strategy if strategy is None: @@ -356,9 +358,18 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": # default to minmax for non-dynamic cases observer = "minmax" + if zp_dtype is None: + zp_dtype = model.pytorch_dtype() + + # 2. 
Identify quantization scale and zp dtype + if scale_dtype is None: + scale_dtype = torch.bfloat16 + # write back modified values model.strategy = strategy model.observer = observer + model.zp_dtype = zp_dtype + model.scale_dtype = scale_dtype return model def pytorch_dtype(self) -> torch.dtype: From 933c624e4d2ad08d19bcf14917aa670477181767 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 3 Nov 2025 14:05:24 -0500 Subject: [PATCH 14/31] update --- src/compressed_tensors/quantization/quant_args.py | 1 - src/compressed_tensors/quantization/utils/helpers.py | 11 +++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index ef81406cb..83bbfff9f 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -361,7 +361,6 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": if zp_dtype is None: zp_dtype = model.pytorch_dtype() - # 2. Identify quantization scale and zp dtype if scale_dtype is None: scale_dtype = torch.bfloat16 diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 512728b1d..4b4508d35 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -93,10 +93,17 @@ def calculate_qparams( scales = global_scale * scales if quantization_args.scale_dtype is not None: + if torch.is_floating_point( + torch.empty((), dtype=quantization_args.scale_dtype) + ): + info = torch.finfo(quantization_args.scale_dtype) + else: + info = torch.iinfo(quantization_args.scale_dtype) + scales = torch.clamp( scales, - max=torch.finfo(quantization_args.scale_dtype).max, - min=torch.finfo(quantization_args.scale_dtype).min, + min=info.min, + max=info.max, ) scales = scales.to(quantization_args.scale_dtype) From ee742c0fae43ff1f7c08cd25a9c0256516d7bfc8 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 3 Nov 2025 14:32:36 -0500 Subject: [PATCH 15/31] update tests --- .../test_quantization/lifecycle/test_dynamic_lifecycle.py | 2 +- tests/test_quantization/lifecycle/test_forward.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py index 3ac91e851..ca7a5c47c 100644 --- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py @@ -79,7 +79,7 @@ def _test_layer_dynamic_quantization_status( def get_tinyllama_model(): return AutoModelForCausalLM.from_pretrained( "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", - torch_dtype="auto", + torch_dtype=torch.bfloat16, ) diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index f3321cd40..8d9b48a83 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -64,8 +64,12 @@ def test_forward_quantize( num_bits = 8 quantization_scheme = create_quantization_scheme( targets=["*"], - weights=QuantizationArgs(num_bits=num_bits, symmetric=True), - input_activations=QuantizationArgs(num_bits=num_bits, symmetric=True), + weights=QuantizationArgs( + num_bits=num_bits, symmetric=True, scale_dtype=torch.float + ), + input_activations=QuantizationArgs( + num_bits=num_bits, symmetric=True, 
scale_dtype=torch.float + ), ) quantization_args = QuantizationArgs(num_bits=num_bits, symmetric=True) layer = Linear(4, 4) From e7475d23a853095cbcc982afbf202f266316495f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 3 Nov 2025 14:39:02 -0500 Subject: [PATCH 16/31] update --- tests/test_linear/test_compressed_linear.py | 2 +- tests/test_quantization/lifecycle/test_enabled.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_linear/test_compressed_linear.py b/tests/test_linear/test_compressed_linear.py index fcf329369..1c1099c5b 100644 --- a/tests/test_linear/test_compressed_linear.py +++ b/tests/test_linear/test_compressed_linear.py @@ -34,7 +34,7 @@ def test_model_forward_pass(model_stub): """ # Load model model = AutoModelForCausalLM.from_pretrained( - model_stub, torch_dtype=torch.float16, device_map="auto" + model_stub, torch_dtype=torch.bfloat16, device_map="auto" ) # Load tokenizer diff --git a/tests/test_quantization/lifecycle/test_enabled.py b/tests/test_quantization/lifecycle/test_enabled.py index 87e25d551..24be64bf3 100644 --- a/tests/test_quantization/lifecycle/test_enabled.py +++ b/tests/test_quantization/lifecycle/test_enabled.py @@ -26,8 +26,8 @@ def test_quantization_enabled_disabled(): - inp = torch.randn(16) - model = Linear(16, 16) + inp = torch.randn(16, dtype=torch.bfloat16) + model = Linear(16, 16, dtype=torch.bfloat16) quantized_model = deepcopy(model) apply_quantization_config( model=quantized_model, From e7d6b529261ab76b61b881948d4a1ba139ddf6a7 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 3 Nov 2025 14:54:13 -0500 Subject: [PATCH 17/31] update --- .../model_compressors/test_model_compressor.py | 15 +++++++++++++-- .../test_quantization/lifecycle/test_lifecycle.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index 4592642af..942fd5283 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -22,6 +22,7 @@ from compressed_tensors.compressors import ModelCompressor from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig from compressed_tensors.quantization import ( + FP8_E4M3_DATA, QuantizationArgs, QuantizationConfig, QuantizationScheme, @@ -425,8 +426,18 @@ def test_multiple_quant_compressors(): format=CompressionFormat.float_quantized.value, ) - input_activations = QuantizationArgs(num_bits=4, type="float") - weights = QuantizationArgs(num_bits=4, type="float") + input_activations = QuantizationArgs( + num_bits=4, + type="float", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, + ) + weights = QuantizationArgs( + num_bits=4, + type="float", + scale_dtype=FP8_E4M3_DATA.dtype, + zp_dtype=FP8_E4M3_DATA.dtype, + ) scheme_nvfp4 = QuantizationScheme( targets=["Linear"], diff --git a/tests/test_quantization/lifecycle/test_lifecycle.py b/tests/test_quantization/lifecycle/test_lifecycle.py index 8f3e2dd01..3aefa4d24 100644 --- a/tests/test_quantization/lifecycle/test_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_lifecycle.py @@ -32,7 +32,7 @@ def test_lifecyle(mock_per_tensor_calibration, create_quantization_scheme): targets=["*"], ) - layer = Linear(4, 4) + layer = Linear(4, 4, dtype=torch.bfloat16) layer.weight.data *= 100 # updated layer keys check From 1970b264b5306e39740045615f431ff98981d03b Mon Sep 17 00:00:00 
2001 From: Dipika Sikka Date: Mon, 3 Nov 2025 14:55:51 -0500 Subject: [PATCH 18/31] fix comment --- src/compressed_tensors/quantization/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 4b4508d35..fa68f20f4 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -107,7 +107,7 @@ def calculate_qparams( ) scales = scales.to(quantization_args.scale_dtype) - # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped + # Clamp any potential 0s if scales.dtype == FP8_E4M3_DATA.dtype: # torch.clamp not supported for FP8 # use the next largest fp8 value from 0 From e8107e561b97509aab7f7d5ddd717eeb85506032 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 3 Nov 2025 20:22:15 -0500 Subject: [PATCH 19/31] update --- .../quantization/lifecycle/initialize.py | 13 ++++++++++++- src/compressed_tensors/quantization/quant_args.py | 5 ----- tests/test_linear/test_compressed_linear.py | 2 +- tests/test_quantization/lifecycle/test_forward.py | 8 ++------ 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 00369de26..7354a3a07 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -247,7 +247,18 @@ def initialize_qparams( else: assert False, f"Unknown strategy {strategy}" - # 2. Initializes scale/zp for the module + # 2. Identify quantization scale and zp dtype + if quantization_args.scale_dtype is None: + if observed_dtype not in [ + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + ]: + observed_dtype = torch.float16 + quantization_args.scale_dtype = observed_dtype + + # 3. 
Initializes scale/zp for the module init_scale = Parameter( torch.empty(expected_shape, dtype=quantization_args.scale_dtype, device=device), requires_grad=False, diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 83bbfff9f..813daded3 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -269,7 +269,6 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": dynamic = model.dynamic observer = model.observer dynamic = model.dynamic - scale_dtype = model.scale_dtype zp_dtype = model.zp_dtype # infer strategy @@ -361,14 +360,10 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": if zp_dtype is None: zp_dtype = model.pytorch_dtype() - if scale_dtype is None: - scale_dtype = torch.bfloat16 - # write back modified values model.strategy = strategy model.observer = observer model.zp_dtype = zp_dtype - model.scale_dtype = scale_dtype return model def pytorch_dtype(self) -> torch.dtype: diff --git a/tests/test_linear/test_compressed_linear.py b/tests/test_linear/test_compressed_linear.py index 1c1099c5b..fcf329369 100644 --- a/tests/test_linear/test_compressed_linear.py +++ b/tests/test_linear/test_compressed_linear.py @@ -34,7 +34,7 @@ def test_model_forward_pass(model_stub): """ # Load model model = AutoModelForCausalLM.from_pretrained( - model_stub, torch_dtype=torch.bfloat16, device_map="auto" + model_stub, torch_dtype=torch.float16, device_map="auto" ) # Load tokenizer diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index 8d9b48a83..f3321cd40 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -64,12 +64,8 @@ def test_forward_quantize( num_bits = 8 quantization_scheme = create_quantization_scheme( targets=["*"], - weights=QuantizationArgs( - num_bits=num_bits, symmetric=True, scale_dtype=torch.float - ), - input_activations=QuantizationArgs( - num_bits=num_bits, symmetric=True, scale_dtype=torch.float - ), + weights=QuantizationArgs(num_bits=num_bits, symmetric=True), + input_activations=QuantizationArgs(num_bits=num_bits, symmetric=True), ) quantization_args = QuantizationArgs(num_bits=num_bits, symmetric=True) layer = Linear(4, 4) From e571a36dc9bbb9215d2e19765042d4c1565da779 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 4 Nov 2025 14:47:39 -0500 Subject: [PATCH 20/31] update --- src/compressed_tensors/config/format.py | 2 + .../quantization/lifecycle/forward.py | 7 +++ .../quantization/quant_scheme.py | 33 ++++++++++ .../quantization/utils/helpers.py | 4 ++ .../quantization/utils/mxfp4_utils.py | 62 ++++++++++++++----- .../test_utils/test_mxfp4_utils.py | 20 +++++- 6 files changed, 109 insertions(+), 19 deletions(-) diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py index 4f6610de3..5d0c11436 100644 --- a/src/compressed_tensors/config/format.py +++ b/src/compressed_tensors/config/format.py @@ -50,6 +50,8 @@ def _get_quant_compression_format( is_weight_only = weight_args is not None and input_args is None if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value: + if weight_args.group_size == 32: + return CompressionFormat.mxfp4_pack_quantized return CompressionFormat.nvfp4_pack_quantized if is_weight_only: # w4a16 and w8a16 diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py 
b/src/compressed_tensors/quantization/lifecycle/forward.py index cfa3338dc..01d5338e7 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -28,6 +28,7 @@ from compressed_tensors.quantization.utils import ( calculate_range, compute_dynamic_scales_and_zp, + maybe_convert_from_mxfp4_scale, ) from torch.nn import Module @@ -255,6 +256,7 @@ def _process_quantization( scale=sb, zero_point=zb, global_scale=global_scale, + args=args, ) # restore original shape output = x_blocks.transpose(1, 2).reshape(original_shape) @@ -321,6 +323,7 @@ def _process_quantization( scale=scale.unsqueeze(-1), zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None, global_scale=global_scale, + args=args, ) output = output.flatten(start_dim=-2) @@ -348,6 +351,7 @@ def _process_quantization( scale=scale, zero_point=zero_point, global_scale=global_scale, + args=args, ) return output @@ -468,6 +472,7 @@ def _quantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale + scale = maybe_convert_from_mxfp4_scale(args=args, scale=scale) scaled = x / scale if zero_point is not None: @@ -491,6 +496,7 @@ def _quantize( def _dequantize( x_q: torch.Tensor, scale: torch.Tensor, + args: QuantizationArgs, zero_point: torch.Tensor = None, dtype: Optional[torch.dtype] = None, global_scale: Optional[torch.Tensor] = None, @@ -501,6 +507,7 @@ def _dequantize( if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale + scale = maybe_convert_from_mxfp4_scale(args=args, scale=scale) dequant_value = x_q.to(scale.dtype) if zero_point is not None: diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index d31e133e8..4b468af61 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -192,6 +192,37 @@ def is_preset_scheme(name: str) -> bool: ), ) +MXFP4A16 = dict( + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.GROUP, + symmetric=True, + dynamic=False, + group_size=32, + ) +) + +MXFP4 = dict( + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.GROUP, + symmetric=True, + dynamic=False, + group_size=32, + ), + input_activations=QuantizationArgs( + num_bits=4, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.GROUP, + dynamic=True, + symmetric=True, + group_size=32, + ), +) + + # 8 bit integer weights and 8 bit activations quantization INT8_W8A8 = dict( weights=QuantizationArgs( @@ -343,4 +374,6 @@ def is_preset_scheme(name: str) -> bool: "FP8_BLOCK": FP8_BLOCK, "NVFP4A16": NVFP4A16, "NVFP4": NVFP4, + "MXFP4A16": MXFP4A16, + "MXFP4": MXFP4, } diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index fa68f20f4..613871c42 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -26,6 +26,9 @@ QuantizationType, ) from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.utils.mxfp4_utils import ( + maybe_convert_to_mxfp4_scales, +) from compressed_tensors.utils import deprecated from loguru import logger from torch import FloatTensor, IntTensor, Tensor @@ -92,6 +95,7 @@ def calculate_qparams( # Conditionally scale the generated local scale by a global_scale 
             scales = global_scale * scales
 
+        scales = maybe_convert_to_mxfp4_scales(args=quantization_args, scales=scales)
         if quantization_args.scale_dtype is not None:
             if torch.is_floating_point(
                 torch.empty((), dtype=quantization_args.scale_dtype)
diff --git a/src/compressed_tensors/quantization/utils/mxfp4_utils.py b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
index 17821ae72..7d04a1e92 100644
--- a/src/compressed_tensors/quantization/utils/mxfp4_utils.py
+++ b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
@@ -13,16 +13,25 @@
 # limitations under the License.
 
 import torch
-from compressed_tensors.quantization.quant_args import BFLOAT16_DATA, FP4_E2M1_DATA
+from compressed_tensors.quantization.quant_args import (
+    BFLOAT16_DATA,
+    FP4_E2M1_DATA,
+    QuantizationArgs,
+)
 
 
-__all__ = ["convert_mxfp4_exp_scale", "generate_mxfp4_scales", "round_to_power_2"]
+__all__ = [
+    "maybe_convert_from_mxfp4_scale",
+    "generate_mxfp4_scales",
+    "round_to_power_2",
+    "maybe_convert_to_mxfp4_scales",
+]
 
 
 # Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py  # noqa: E501
-def convert_mxfp4_exp_scale(
-    scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+def maybe_convert_from_mxfp4_scale(
+    args: QuantizationArgs, scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
 ) -> torch.Tensor:
     """
     Converts mxfp4 scales. Scales are powers of 2, with the
@@ -32,10 +41,29 @@ def convert_mxfp4_exp_scale(
     :param scale: uint8 exponent scale
     :param dtype: dense dtype
     """
-    assert scale.dtype == torch.uint8
-    scale_exp = scale.to(torch.int32) - 127
-    scale = 2.00 ** (scale_exp.to(torch.float))
-    return scale.to(dtype)
+    is_mxfp4 = args.num_bits == 4 and args.type == "float" and args.group_size == 32
+    if is_mxfp4:
+        assert scale.dtype == torch.uint8
+        scale_exp = scale.to(torch.int32) - 127
+        scale = 2.00 ** (scale_exp.to(torch.float))
+        return scale.to(dtype)
+    return scale
+
+
+def maybe_convert_to_mxfp4_scales(
+    args: QuantizationArgs, scales: torch.Tensor
+) -> torch.Tensor:
+    """
+    Convert the scales to mxfp4-compatible scales, if the quant args are FP4 with group_size 32.
+    If not, return the original scales.
+
+    :param args: quantization args
+    :param scales: scales to update
+    """
+    is_mxfp4 = args.num_bits == 4 and args.type == "float" and args.group_size == 32
+    if is_mxfp4:
+        return generate_mxfp4_scales(x=scales)
+    return scales
 
 
 def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
@@ -72,12 +100,12 @@ def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
     return block_max_uint.to(torch.uint16).view(torch.bfloat16)
 
 
-def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
+def generate_mxfp4_scales(x: torch.Tensor, clamp: bool = False) -> torch.Tensor:
     """
     Generate mxfp4 scales. The scales require the following steps
     1. Round to the closest power of 2
     2. Convert to exponent
-    3. Store in uint8
+    3. Optionally, store in uint8
 
     Called when calculating qparams using observers.
@@ -89,9 +117,11 @@ def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
     # Convert to exponent
     scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
     # Clamp and store in uint8, as expected by mxfp4
-    scale_exp = torch.clamp(
-        scale_exp,
-        max=torch.iinfo(torch.uint8).max,
-        min=torch.iinfo(torch.uint8).min,
-    )
-    return scale_exp.to(torch.uint8)
+    if clamp:
+        scale_exp = torch.clamp(
+            scale_exp,
+            max=torch.iinfo(torch.uint8).max,
+            min=torch.iinfo(torch.uint8).min,
+        )
+        return scale_exp.to(torch.uint8)
+    return scale_exp
diff --git a/tests/test_quantization/test_utils/test_mxfp4_utils.py b/tests/test_quantization/test_utils/test_mxfp4_utils.py
index 723228bec..a19109a7c 100644
--- a/tests/test_quantization/test_utils/test_mxfp4_utils.py
+++ b/tests/test_quantization/test_utils/test_mxfp4_utils.py
@@ -14,8 +14,8 @@
 
 import torch
 from compressed_tensors.quantization.utils import (
-    convert_mxfp4_exp_scale,
     generate_mxfp4_scales,
+    maybe_convert_from_mxfp4_scale,
     round_to_power_2,
 )
 
@@ -61,6 +61,12 @@ def test_round_power_2():
 
 
 def test_mxfp4_scales_e2e():
+    from compressed_tensors.quantization.quant_args import (
+        QuantizationArgs,
+        QuantizationStrategy,
+        QuantizationType,
+    )
+
     mock_weight = torch.normal(mean=0.0002, std=0.0576, size=(2880, 2880))
     x = mock_weight.reshape(*mock_weight.shape[:-1], -1, 32).to(torch.bfloat16)
 
@@ -71,8 +77,16 @@ def test_mxfp4_scales_e2e():
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     block_max = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-    scales_generated = generate_mxfp4_scales(block_max)
-    converted_ct = convert_mxfp4_exp_scale(scales_generated)
+    scales_generated = generate_mxfp4_scales(block_max, clamp=True)
+    args = QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=32,
+        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
+    )
+    converted_ct = maybe_convert_from_mxfp4_scale(args=args, scale=scales_generated)
 
     scales_exp = torch.log2(converted_ct)
     block_max_exp = torch.floor(torch.log2(round_to_power_2(block_max))) - 2

From 68771b1430de6c0df3a48d6aea229e498c2fe616 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 4 Nov 2025 14:57:28 -0500
Subject: [PATCH 21/31] update

---
 src/compressed_tensors/quantization/quant_scheme.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 4b468af61..00ab51a4a 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import warnings
 from copy import deepcopy
 from typing import List, Optional
 
+import torch
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
@@ -200,6 +200,8 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=32,
+        scale_dtype=torch.uint8,
+        scale_dtype=torch.uint8,
     )
 )
 
@@ -210,7 +212,8 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.GROUP,
         symmetric=True,
         dynamic=False,
-        group_size=32,
+        scale_dtype=torch.uint8,
+        scale_dtype=torch.uint8,
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -219,6 +222,8 @@ def is_preset_scheme(name: str) -> bool:
         dynamic=True,
         symmetric=True,
         group_size=32,
+        scale_dtype=torch.uint8,
+        scale_dtype=torch.uint8,
     ),
 )

From bf0d0c63d2aa5e37a7eff0a07ce14e560cc16a85 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 4 Nov 2025 15:11:35 -0500
Subject: [PATCH 22/31] fix typo

---
 src/compressed_tensors/quantization/quant_scheme.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 00ab51a4a..7b6936851 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -201,7 +201,7 @@ def is_preset_scheme(name: str) -> bool:
         dynamic=False,
         group_size=32,
         scale_dtype=torch.uint8,
-        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
     )
 )
 
@@ -213,7 +213,7 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         scale_dtype=torch.uint8,
-        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -223,7 +223,7 @@ def is_preset_scheme(name: str) -> bool:
         dynamic=True,
         symmetric=True,
         group_size=32,
         scale_dtype=torch.uint8,
-        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
     ),
 )

From da3ad9fe7271835fa484ccdbf3ee231984695bc9 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 4 Nov 2025 15:13:23 -0500
Subject: [PATCH 23/31] update

---
 src/compressed_tensors/quantization/quant_scheme.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 7b6936851..6e4e103c4 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -212,6 +212,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.GROUP,
         symmetric=True,
         dynamic=False,
+        group_size=32,
         scale_dtype=torch.uint8,
         zp_dtype=torch.uint8,
     ),

From 94dbb582b3372a1d665b0efe43fe76a1cb595b54 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 4 Nov 2025 15:26:31 -0500
Subject: [PATCH 24/31] updatE

---
 src/compressed_tensors/quantization/utils/helpers.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 613871c42..ffbaf74ee 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -89,13 +89,13 @@ def calculate_qparams(
 
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-        scales = max_val_pos / (float(bit_range) / 2)
+        #scales = max_val_pos / (float(bit_range) / 2)
 
         if global_scale is not None:
             # Conditionally scale the generated local scale by a global_scale
            scales = global_scale * scales
 
-        scales = maybe_convert_to_mxfp4_scales(args=quantization_args, scales=scales)
+        scales = maybe_convert_to_mxfp4_scales(args=quantization_args, scales=max_val_pos)
     if quantization_args.scale_dtype is not None:
         if torch.is_floating_point(
             torch.empty((), dtype=quantization_args.scale_dtype)
@@ -120,8 +120,8 @@
                 torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
                 scales,
             )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        #else:
+        #    scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
 
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:

From 204a3de4eac6dbd41fe79af637532b465b7c2951 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 10 Nov 2025 18:30:01 -0500
Subject: [PATCH 25/31] rebase fixes

---
 .../quantization/lifecycle/initialize.py | 2 +-
 .../quantization/quant_config.py | 33 -------------------
 2 files changed, 1 insertion(+), 34 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 7ae378026..8c1b251c5 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -259,7 +259,7 @@ def initialize_qparams(
 
     # 3. Initializes scale/zp for the module
     init_scale = Parameter(
-        torch.empty(expected_shape, dtype=quantization_args.scale_dtype, device=device),
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
     register_offload_parameter(module, f"{base_name}_scale", init_scale)
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index f856f9783..8ab4f1d7b 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -278,38 +278,5 @@ def requires_calibration_data(self):
 
         return False
 
-    def model_dump(self, *args, **kwargs):
-        # Call the parent dump first
-        data = super().model_dump(*args, **kwargs)
-
-        def _convert_dtypes_in_dict(d):
-            for k, v in d.items():
-                if (k == "zp_dtype" and d.get("symmetric")) or (
-                    k == "scale_dtype" and d.get("dynamic") in (True, "local")
-                ):
-                    d[k] = None
-                elif isinstance(v, dict):
-                    _convert_dtypes_in_dict(v)
-            return d
-
-        scheme = "config_groups"
-        if data.get(scheme):
-            for _, v in data[scheme].items():
-                weight = v.get("weights")
-                input = v.get("input_activations")
-                output = v.get("output_activations")
-
-                args = [weight, input, output]
-                for arg in args:
-                    if arg is not None:
-                        _convert_dtypes_in_dict(arg)
-
-        scheme = "kv_cache_scheme"
-        kv_cache_data = data.get(scheme)
-        if kv_cache_data:
-            _convert_dtypes_in_dict(kv_cache_data)
-
-        return data
-
     # TODO set `extra="forbid"` when upstream transformers is compatible
     model_config = ConfigDict(extra="ignore")

From 34bc8dfe119220d10b9b573f2b382be495dafc58 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 10 Nov 2025 18:31:07 -0500
Subject: [PATCH 26/31] more rebase fix

---
 tests/test_quantization/lifecycle/test_enabled.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_quantization/lifecycle/test_enabled.py b/tests/test_quantization/lifecycle/test_enabled.py
index 24be64bf3..87e25d551 100644
--- a/tests/test_quantization/lifecycle/test_enabled.py
+++ b/tests/test_quantization/lifecycle/test_enabled.py
@@ -26,8 +26,8 @@
 
 
 def test_quantization_enabled_disabled():
-    inp = torch.randn(16, dtype=torch.bfloat16)
-    model = Linear(16, 16, dtype=torch.bfloat16)
+    inp = torch.randn(16)
+    model = Linear(16, 16)
     quantized_model = deepcopy(model)
     apply_quantization_config(
         model=quantized_model,

From af074e58a20be8881d989ce293637e39417dd345 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 10 Nov 2025 19:32:32 -0500
Subject: [PATCH 27/31] update

---
 .../quantized_compressors/fp4_quantized.py | 21 +++++++++++++++++--
 .../quantization/lifecycle/forward.py | 5 -----
 .../quantization/utils/helpers.py | 7 ++++++-
 .../quantization/utils/mxfp4_utils.py | 1 -
 4 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
index b9bd8edef..e83133928 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
@@ -81,6 +81,14 @@ def compression_param_info(
         }
         return output
 
+    def compress_scale(
+        self,
+        scale: Tensor,
+        quantization_args: QuantizationArgs,
+    ) -> Dict[str, torch.Tensor]:
+        assert quantization_args.scale_dtype is not None
+        return scale.to(quantization_args.scale_dtype)
+
     def compress_weight(
         self,
         weight: Tensor,
@@ -103,7 +111,9 @@ def compress_weight(
         if device is not None:
             weight_packed = weight_packed.to(device)
         compressed_dict["weight_packed"] = weight_packed
-        compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
+        compressed_dict["weight_scale"] = self.compress_scale(
+            scale=scale, quantization_args=quantization_args
+        )
         return compressed_dict
 
     def decompress_weight(
@@ -130,7 +140,14 @@ class MXFP4PackedCompressor(NVFP4PackedCompressor):
     Alias for mxfp4 quantized models
     """
 
-    pass
+    def compress_scale(
+        self,
+        scale: Tensor,
+        quantization_args: QuantizationArgs,
+    ) -> Dict[str, torch.Tensor]:
+        assert quantization_args.scale_dtype is not None
+        scale_exp = 127 + torch.floor(torch.log2(scale)).to(torch.int32) - 2
+        return scale_exp.to(quantization_args.scale_dtype)
 
 
 @torch.compile(fullgraph=True, dynamic=True)
 def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index e93c1fc30..ee1f5fd59 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -256,7 +256,6 @@ def _process_quantization(
                 scale=sb,
                 zero_point=zb,
                 global_scale=global_scale,
-                args=args,
             )
             # restore original shape
             output = x_blocks.transpose(1, 2).reshape(original_shape)
@@ -323,7 +322,6 @@ def _process_quantization(
             scale=scale.unsqueeze(-1),
             zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None,
             global_scale=global_scale,
-            args=args,
         )
         output = output.flatten(start_dim=-2)
 
@@ -351,7 +349,6 @@ def _process_quantization(
             scale=scale,
             zero_point=zero_point,
             global_scale=global_scale,
-            args=args,
         )
 
     return output
@@ -493,7 +490,6 @@ def _quantize(
 def _dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    args: QuantizationArgs,
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
     global_scale: Optional[torch.Tensor] = None,
@@ -504,7 +500,6 @@ def _dequantize(
     if global_scale is not None:
         scale = scale / global_scale
-    scale = maybe_convert_from_mxfp4_scale(args=args, scale=scale)
 
     dequant_value = x_q.to(scale.dtype)
 
     if zero_point is not None:
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 9789da6f9..b330461ba 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -28,6 +28,7 @@
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils.mxfp4_utils import (
+    maybe_convert_from_mxfp4_scale,
     maybe_convert_to_mxfp4_scales,
 )
 from compressed_tensors.utils import deprecated
@@ -91,7 +92,8 @@ def calculate_qparams(
     # 1. Generate scale and zero-point
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-        scales = max_val_pos / (float(bit_range) / 2)
+        # scales = max_val_pos / (float(bit_range) / 2)
+        scales = maybe_convert_to_mxfp4_scales(max_val_pos)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         if (
@@ -115,6 +117,9 @@ def calculate_qparams(
             scales, dtype=quantization_args.scale_dtype
         )
 
+    # Optionally remove exponent
+    scales = maybe_convert_from_mxfp4_scale(quantization_args, scales)
+
     # 4. Update any 0s with small values to
     # prevent div by 0
     eps = _get_dtype_eps(
diff --git a/src/compressed_tensors/quantization/utils/mxfp4_utils.py b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
index 7d04a1e92..3eb08e8da 100644
--- a/src/compressed_tensors/quantization/utils/mxfp4_utils.py
+++ b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
@@ -43,7 +43,6 @@ def maybe_convert_from_mxfp4_scale(
     """
     is_mxfp4 = args.num_bits == 4 and args.type == "float" and args.group_size == 32
     if is_mxfp4:
-        assert scale.dtype == torch.uint8
         scale_exp = scale.to(torch.int32) - 127
         scale = 2.00 ** (scale_exp.to(torch.float))
     return scale.to(dtype)

From a6dc025e423c9f20cb16ce9f1a7b4ea662f85e75 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 10 Nov 2025 19:33:37 -0500
Subject: [PATCH 28/31] update

---
 src/compressed_tensors/quantization/lifecycle/forward.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index ee1f5fd59..fe637013f 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -469,7 +469,6 @@ def _quantize(
     if global_scale is not None:
         scale = scale / global_scale
-    scale = maybe_convert_from_mxfp4_scale(args=args, scale=scale)
 
     scaled = x / scale
 
     if zero_point is not None:

From c3b1e95ae23e691bb11cdaed01f79281583a3dbd Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 10 Nov 2025 19:34:22 -0500
Subject: [PATCH 29/31] update

---
 src/compressed_tensors/quantization/lifecycle/forward.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index fe637013f..573826e18 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -28,7 +28,6 @@
 from compressed_tensors.quantization.utils import (
     calculate_range,
     compute_dynamic_scales_and_zp,
-    maybe_convert_from_mxfp4_scale,
 )
 from torch.nn import Module
 

From 9dfb31c572b0fd1373833771e13976d3379b7625 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 17 Nov 2025 15:12:53 -0500
Subject: [PATCH 30/31] update

---
 .../quantization/utils/helpers.py | 19 ++++---
 .../quantization/utils/mxfp4_utils.py | 51 +++++--------------
 .../test_utils/test_mxfp4_utils.py | 10 ++--
 3 files changed, 32 insertions(+), 48 deletions(-)

diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index b330461ba..59c5f245a 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -28,8 +28,9 @@
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils.mxfp4_utils import (
-    maybe_convert_from_mxfp4_scale,
-    maybe_convert_to_mxfp4_scales,
+    generate_mxfp4_scales,
+    maybe_convert_from_mxfp4_exp,
+    should_generate_mxfp4_scales,
 )
 from compressed_tensors.utils import deprecated
 from loguru import logger
@@ -92,8 +93,10 @@ def calculate_qparams(
     # 1. Generate scale and zero-point
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-        # scales = max_val_pos / (float(bit_range) / 2)
-        scales = maybe_convert_to_mxfp4_scales(max_val_pos)
+        if should_generate_mxfp4_scales(args=quantization_args):
+            scales = generate_mxfp4_scales(x=max_val_pos)
+        else:
+            scales = max_val_pos / (float(bit_range) / 2)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         if (
@@ -117,10 +120,10 @@ def calculate_qparams(
             scales, dtype=quantization_args.scale_dtype
         )
 
-    # Optionally remove exponent
-    scales = maybe_convert_from_mxfp4_scale(quantization_args, scales)
+    # 4. Optionally remove exponent
+    scales = maybe_convert_from_mxfp4_exp(quantization_args, scales)
 
-    # 4. Update any 0s with small values to
+    # 5. Update any 0s with small values to
     # prevent div by 0
     eps = _get_dtype_eps(
         dtype=quantization_args.scale_dtype
@@ -133,7 +136,7 @@ def calculate_qparams(
         scales,
     )
 
-    # 5. Round the zp to zp_dtype
+    # 6. Round the zp to zp_dtype
     zero_points = round_to_quantized_type_dtype(
         zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
     )
diff --git a/src/compressed_tensors/quantization/utils/mxfp4_utils.py b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
index 3eb08e8da..21dd841fb 100644
--- a/src/compressed_tensors/quantization/utils/mxfp4_utils.py
+++ b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
@@ -21,17 +21,21 @@
 __all__ = [
-    "maybe_convert_from_mxfp4_scale",
+    "maybe_convert_from_mxfp4_exp",
     "generate_mxfp4_scales",
     "round_to_power_2",
-    "maybe_convert_to_mxfp4_scales",
+    "should_generate_mxfp4_scales",
 ]
 
 
 # Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py # noqa: E501
-def maybe_convert_from_mxfp4_scale(
-    args: QuantizationArgs, scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+def should_generate_mxfp4_scales(args: QuantizationArgs):
+    return args.num_bits == 4 and args.type == "float" and args.group_size == 32
+
+
+def maybe_convert_from_mxfp4_exp(
+    args: QuantizationArgs, scale: torch.Tensor
 ) -> torch.Tensor:
     """
     Converts mxfp4 scales. Scales are powers of 2, with the
@@ -41,30 +45,14 @@ def maybe_convert_from_mxfp4_scale(
     :param scale: uint8 exponent scale
     :param dtype: dense dtype
     """
-    is_mxfp4 = args.num_bits == 4 and args.type == "float" and args.group_size == 32
-    if is_mxfp4:
+    original_dtype = scale.dtype
+    if should_generate_mxfp4_scales(args):
         scale_exp = scale.to(torch.int32) - 127
         scale = 2.00 ** (scale_exp.to(torch.float))
-    return scale.to(dtype)
+    return scale.to(original_dtype)
     return scale
 
 
-def maybe_convert_to_mxfp4_scales(
-    args: QuantizationArgs, scales: torch.Tensor
-) -> torch.Tensor:
-    """
-    Conver the scales to be mxfp4 compatible scales, if quant args are FP4 with group_size 32.
-    If not, return original scales
-
-    :param args: quantization args
-    :param scales: scales to update
-    """
-    is_mxfp4 = args.num_bits == 4 and args.type == "float" and args.group_size == 32
-    if is_mxfp4:
-        return generate_mxfp4_scales(x=scales)
-    return scales
-
-
 def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
     """
     Round values to the closest power of 2.
@@ -99,28 +87,17 @@ def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
     return block_max_uint.to(torch.uint16).view(torch.bfloat16)
 
 
-def generate_mxfp4_scales(x: torch.Tensor, clamp: bool = False) -> torch.Tensor:
+def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
     """
     Generate mxfp4 scales. The scales require the following steps
     1. Round to the closest power of 2
     2. Convert to exponent
-    3. Optionally, store in uint8
 
     Called when calculating qparams using observers.
 
     :param x: tensor to round to closest power of 2
-    :returns uint8 scales as exponents
+    :returns scales as exponents
     """
     # Round to closest power of 2
     scale_power_2 = round_to_power_2(x)
-    # Convert to exponent
-    scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
-    # Clamp and store in uint8, as expected by mxfp4
-    if clamp:
-        scale_exp = torch.clamp(
-            scale_exp,
-            max=torch.iinfo(torch.uint8).max,
-            min=torch.iinfo(torch.uint8).min,
-        )
-        return scale_exp.to(torch.uint8)
-    return scale_exp
+    return 127 + torch.floor(torch.log2(scale_power_2)) - 2
diff --git a/tests/test_quantization/test_utils/test_mxfp4_utils.py b/tests/test_quantization/test_utils/test_mxfp4_utils.py
index a19109a7c..15ac84801 100644
--- a/tests/test_quantization/test_utils/test_mxfp4_utils.py
+++ b/tests/test_quantization/test_utils/test_mxfp4_utils.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import torch
+from compressed_tensors.quantization import round_to_quantized_type_dtype
 from compressed_tensors.quantization.utils import (
     generate_mxfp4_scales,
-    maybe_convert_from_mxfp4_scale,
+    maybe_convert_from_mxfp4_exp,
     round_to_power_2,
 )
 
@@ -77,7 +78,6 @@ def test_mxfp4_scales_e2e():
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     block_max = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-    scales_generated = generate_mxfp4_scales(block_max, clamp=True)
     args = QuantizationArgs(
         num_bits=4,
         type=QuantizationType.FLOAT,
@@ -86,7 +86,11 @@ def test_mxfp4_scales_e2e():
         scale_dtype=torch.uint8,
         zp_dtype=torch.uint8,
     )
-    converted_ct = maybe_convert_from_mxfp4_scale(args=args, scale=scales_generated)
+
+    scales = generate_mxfp4_scales(block_max)
+    scales = round_to_quantized_type_dtype(scales, dtype=args.scale_dtype)
+
+    converted_ct = maybe_convert_from_mxfp4_exp(args=args, scale=scales)
 
     scales_exp = torch.log2(converted_ct)
     block_max_exp = torch.floor(torch.log2(round_to_power_2(block_max))) - 2

From 7a2d7c818b65986d995c2cccb8a11560a7107f35 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 17 Nov 2025 16:03:49 -0500
Subject: [PATCH 31/31] dequant scales not support

---
 .../compressors/quantized_compressors/fp4_quantized.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
index e83133928..dd3c2a463 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
@@ -149,6 +149,13 @@ def compress_scale(
         scale_exp = 127 + torch.floor(torch.log2(scale)).to(torch.int32) - 2
         return scale_exp.to(quantization_args.scale_dtype)
 
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError("MXFP4 Decompression is currently not supported")
+
 
 @torch.compile(fullgraph=True, dynamic=True)
 def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
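
A minimal, standalone sketch of the mxfp4 scale round trip that the final patches converge on; it is not part of the patch series itself. In the diffs above, generate_mxfp4_scales rounds each per-group maximum to a power of two and returns a biased exponent (127 bias with an extra -2 offset), the MXFP4 compressor stores that exponent in the uint8 scale_dtype, and maybe_convert_from_mxfp4_exp turns a stored exponent back into a power-of-two scale via 2 ** (exp - 127). The helper names below carry a _sketch suffix because they are hypothetical stand-ins written only for illustration, not the library's API; the rounding here is a simplified floor, whereas round_to_power_2 in the patches rounds to the closest power of two through bfloat16 bit manipulation, and the block_max values are invented.

    import torch

    def round_to_power_2_sketch(x: torch.Tensor) -> torch.Tensor:
        # Simplification: round each positive value down to the nearest power of two.
        safe = x.clamp(min=torch.finfo(torch.float32).tiny)
        return torch.pow(2.0, torch.floor(torch.log2(safe)))

    def to_mxfp4_exponent_sketch(scale_pow2: torch.Tensor) -> torch.Tensor:
        # Store the power-of-two scale as a biased exponent (127 bias, minus 2),
        # clamped into the uint8 range the way the mxfp4 compressor stores it.
        exp = 127 + torch.floor(torch.log2(scale_pow2)).to(torch.int32) - 2
        return exp.clamp(0, 255).to(torch.uint8)

    def from_mxfp4_exponent_sketch(exp_u8: torch.Tensor) -> torch.Tensor:
        # Invert the biasing: the uint8 exponent becomes a power-of-two float scale.
        return torch.pow(2.0, exp_u8.to(torch.float32) - 127.0)

    # Hypothetical per-group maxima (one group of 32 weights per entry).
    block_max = torch.tensor([0.071, 0.0008, 1.9])
    p2 = round_to_power_2_sketch(block_max)          # 0.0625, 2**-11, 1.0
    exp_u8 = to_mxfp4_exponent_sketch(p2)            # 121, 114, 125
    recovered = from_mxfp4_exponent_sketch(exp_u8)   # p2 / 4, still a power of two per group

This mirrors the relationship exercised by test_mxfp4_scales_e2e above, where log2 of the converted scale is compared against floor(log2(round_to_power_2(block_max))) - 2.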