diff --git a/auto_round/alg_ext.py b/auto_round/alg_ext.py index 4d74f8bd9..dc9136cae 100644 --- a/auto_round/alg_ext.py +++ b/auto_round/alg_ext.py @@ -72,7 +72,7 @@ def get_abs_top_percent_mask(x: torch.Tensor, percent: float = 1.0): inv_mask (torch.BoolTensor): Inverse of mask. """ flat = x.view(-1) - k = max(1, int(flat.numel() * percent / 1000)) # 至少选1个 + k = max(1, int(flat.numel() * percent / 1000)) _, idx = torch.topk(torch.abs(flat), k) mask = torch.zeros_like(flat, dtype=torch.bool) @@ -612,7 +612,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u # iscale_new = factor / (rmax - rmin + 1e-8) scale_new = (rmax - rmin) / factor iscale_new = get_reciprocal(scale_new) - quant_data_new = torch.clamp(torch.round(iscale_new * (data - rmin)), minq, maxq) + quant_data_new = torch.clamp(torch.round(iscale_new * (data - rmin) + v), minq, maxq) mul_weights_quant_data = weights * quant_data_new sum_l = torch.sum(mul_weights_quant_data, dim=-1, keepdim=True)