From 23b596eb2202b121fadb5cc30a764c554968f2fc Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 19 Nov 2025 23:00:09 +0000
Subject: [PATCH] implement _clamp_scale_values

Signed-off-by: Kyle Sayers
---
 .../quantization/quant_args.py        |  1 +
 .../quantization/utils/helpers.py     | 36 +++++++++----------
 .../test_utils/test_helpers.py        | 17 +++++++++
 3 files changed, 35 insertions(+), 19 deletions(-)

diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index e6e0def3b..5a37572a3 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -84,6 +84,7 @@ class FP8_E4M3_DATA(FloatArgs):
     max = torch.finfo(torch.float8_e4m3fn).max
     min = torch.finfo(torch.float8_e4m3fn).min
     dtype = torch.float8_e4m3fn
+    eps = 0.125
 
 
 class BFLOAT16_DATA(FloatArgs):
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 45a4ef83c..477ba57a7 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -79,6 +79,7 @@ def calculate_qparams(
     # 0.0 must always be representable within the quantized range
     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    observed_dtype = min_vals.dtype
 
     device = min_vals.device
 
@@ -114,16 +115,8 @@
 
     # 4. Update any 0s with small values to
     # prevent div by 0
-    eps = _get_dtype_eps(
-        dtype=quantization_args.scale_dtype
-        if quantization_args.scale_dtype is not None
-        else scales.dtype
-    )
-    scales = torch.where(
-        scales == 0,
-        torch.tensor(eps, dtype=scales.dtype, device=device),
-        scales,
-    )
+    scale_dtype = quantization_args.scale_dtype or observed_dtype
+    scales = _clamp_scale_values(scales, scale_dtype)
 
     # 5. Round the zp to zp_dtype
     zero_points = round_to_quantized_type_dtype(
@@ -422,6 +415,7 @@ def generate_gparam(
     max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
     max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
     global_scale = scale_data.max * quant_data.max / max_val_pos
+    global_scale = _clamp_scale_values(global_scale, dtype)
 
     return global_scale.to(dtype).reshape([1])
 
@@ -448,12 +442,16 @@ def strategy_cdiv(
     return dividend
 
 
-def _get_dtype_eps(dtype: torch.dtype) -> float:
-    if dtype == FP8_E4M3_DATA.dtype:
-        return 0.125
-    elif dtype == FP4_E2M1_DATA.dtype:
-        return 0.25
-    elif torch.is_floating_point(torch.tensor([], dtype=dtype)):
-        return torch.finfo(dtype).eps
-    else:
-        return 1
+def _clamp_scale_values(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    # note that scales always have a torch dtype (FP4 scales are not supported atm)
+    assert dtype.is_floating_point, "Non-floating point dtypes are not supported"
+    info = torch.finfo(dtype)
+    tensor = torch.nan_to_num(
+        tensor,
+        nan=info.eps,
+        posinf=info.max,
+        neginf=info.min,
+    )
+    tensor = torch.where(tensor == 0, info.eps, tensor)
+
+    return tensor
diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py
index c7fb4b452..4ec82dcb0 100644
--- a/tests/test_quantization/test_utils/test_helpers.py
+++ b/tests/test_quantization/test_utils/test_helpers.py
@@ -25,6 +25,7 @@
     compute_dynamic_scales_and_zp,
     generate_gparam,
 )
+from compressed_tensors.quantization.utils.helpers import _clamp_scale_values
 
 
 @pytest.mark.parametrize(
@@ -105,3 +106,19 @@ def test_compute_dynamic_scales_and_zp_group(shape, group_size, exp_shape):
     scale, zp = compute_dynamic_scales_and_zp(value, args, module=torch.nn.Module())
     assert scale.shape == exp_shape
     assert zp.shape == exp_shape
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("fp_dtype", [torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32, torch.float8_e4m3fn])
+def test_clamp_scale_values(fp_dtype, dtype):
+    info = torch.finfo(dtype)
+    value = torch.tensor(
+        [1.0, -1.0, 0.0, torch.inf, -torch.inf, torch.nan], dtype=fp_dtype
+    )
+    exp = torch.tensor(
+        [1.0, -1.0, info.eps, info.max, info.min, info.eps], dtype=fp_dtype
+    )
+
+    clamped = _clamp_scale_values(value, dtype)
+    assert torch.equal(clamped, exp)
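
Note: a minimal standalone sketch (not part of the patch) of the clamping
behavior that _clamp_scale_values performs, using only public torch APIs.
For torch.float8_e4m3fn, torch.finfo reports eps = 0.125 (2**-3, from the
3 mantissa bits, matching the eps field added to FP8_E4M3_DATA) and
max = 448.0. The input values here are illustrative only.

    import torch

    scales = torch.tensor([1.0, 0.0, torch.inf, torch.nan])
    info = torch.finfo(torch.float8_e4m3fn)

    # replace non-finite entries with representable finite values
    scales = torch.nan_to_num(scales, nan=info.eps, posinf=info.max, neginf=info.min)
    # lift exact zeros up to eps so later divisions by the scale are safe
    scales = torch.where(scales == 0, info.eps, scales)

    print(scales)  # -> 1.0, 0.125, 448.0, 0.125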