From 7456c8c6d25364a6f8f62085d2e83c88a57bb7be Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Thu, 2 Apr 2026 17:44:41 +0800 Subject: [PATCH 1/2] Update alg_ext.py --- auto_round/alg_ext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/alg_ext.py b/auto_round/alg_ext.py index 4d74f8bd9..52a11c28d 100644 --- a/auto_round/alg_ext.py +++ b/auto_round/alg_ext.py @@ -72,7 +72,7 @@ def get_abs_top_percent_mask(x: torch.Tensor, percent: float = 1.0): inv_mask (torch.BoolTensor): Inverse of mask. """ flat = x.view(-1) - k = max(1, int(flat.numel() * percent / 1000)) # 至少选1个 + k = max(1, int(flat.numel() * percent / 1000)) _, idx = torch.topk(torch.abs(flat), k) mask = torch.zeros_like(flat, dtype=torch.bool) @@ -612,7 +612,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u # iscale_new = factor / (rmax - rmin + 1e-8) scale_new = (rmax - rmin) / factor iscale_new = get_reciprocal(scale_new) - quant_data_new = torch.clamp(torch.round(iscale_new * (data - rmin)), minq, maxq) + quant_data_new = torch.clamp(torch.round(iscale_new * (data - rmin) + v), minq, maxq) mul_weights_quant_data = weights * quant_data_new sum_l = torch.sum(mul_weights_quant_data, dim=-1, keepdim=True) From c52fa386b1d7a63fc76b38182c7b0c224a8ba1c5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Apr 2026 09:45:20 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/alg_ext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/alg_ext.py b/auto_round/alg_ext.py index 52a11c28d..dc9136cae 100644 --- a/auto_round/alg_ext.py +++ b/auto_round/alg_ext.py @@ -72,7 +72,7 @@ def get_abs_top_percent_mask(x: torch.Tensor, percent: float = 1.0): inv_mask (torch.BoolTensor): Inverse of mask. """ flat = x.view(-1) - k = max(1, int(flat.numel() * percent / 1000)) + k = max(1, int(flat.numel() * percent / 1000)) _, idx = torch.topk(torch.abs(flat), k) mask = torch.zeros_like(flat, dtype=torch.bool)