
Commit fd88038 (parent 4f366d7)

fix tests

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

3 files changed (+27, -6)

src/llmcompressor/entrypoints/weights_ptq/__init__.py

Lines changed: 2 additions & 3 deletions

@@ -91,10 +91,9 @@ def _process_file(
 
     for name in list(tensors.keys()):
         module_name, param_name = name.rsplit(".", 1)
+        is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
         is_ignored = any(_match_name(module_name, ign) for ign in ignore)
-        is_weight = param_name == "weight"
-        if is_ignored or not is_weight:
-            print(f"skip {name}")
+        if not is_linear_weight or is_ignored:
             continue
 
         # 1. initialize module with qparams (on device)
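The new predicate drops the noisy per-tensor print and excludes norm weights by module name instead of relying on the ignore list. A minimal sketch of how it classifies typical checkpoint tensor names; the `matches` helper below is a simplified stand-in for the real `_match_name`, which also understands "re:" patterns:

def matches(module_name: str, pattern: str) -> bool:
    # simplified stand-in: exact-name match only
    return module_name == pattern

ignore = ["model.embed_tokens", "lm_head"]
names = [
    "model.layers.0.self_attn.q_proj.weight",  # quantize: linear weight
    "model.layers.0.input_layernorm.weight",   # skip: module name ends with "norm"
    "model.embed_tokens.weight",               # skip: explicitly ignored
    "model.layers.0.self_attn.q_proj.bias",    # skip: not a "weight" param
]

for name in names:
    module_name, param_name = name.rsplit(".", 1)
    is_linear_weight = param_name == "weight" and not module_name.endswith("norm")
    is_ignored = any(matches(module_name, ign) for ign in ignore)
    print(f"{name}: {'quantize' if is_linear_weight and not is_ignored else 'skip'}")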

src/llmcompressor/entrypoints/weights_ptq/lifecycle.py

Lines changed: 7 additions & 1 deletion

@@ -15,7 +15,11 @@
     update_weight_zp_scale,
 )
 
-__all__ = ["initialize_quantized_linear", "calibrate_weights", "compress_module"]
+__all__ = [
+    "initialize_quantized_linear",
+    "calibrate_weights",
+    "compress_module",
+]
 
 
 def initialize_quantized_linear(

@@ -58,6 +62,8 @@ def compress_module(module: torch.nn.Linear):
         global_scale=getattr(module, "weight_global_scale", None),
     )
 
+    # `compress_weight` is a messy api
+    delattr(module, "weight")
     for key, value in data.items():
         if hasattr(module, key):
             getattr(module, key).data = value
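Dropping the original `weight` parameter before writing the compressed tensors keeps the stale fp weight from shadowing the packed representation. A minimal sketch of the pattern; the tensor names, shapes, and the else-branch registration here are hypothetical illustrations, not the compressor's actual output:

import torch

module = torch.nn.Linear(64, 64)
data = {  # hypothetical compressed output
    "weight_packed": torch.zeros(64, 32, dtype=torch.int8),
    "weight_scale": torch.ones(64, 4),
}

delattr(module, "weight")  # remove the now-stale fp parameter
for key, value in data.items():
    if hasattr(module, key):
        getattr(module, key).data = value
    else:  # illustrative: attach tensors the module does not yet carry
        module.register_parameter(
            key, torch.nn.Parameter(value, requires_grad=False)
        )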

tests/llmcompressor/pipelines/test_ptq_weights.py

Lines changed: 18 additions & 2 deletions

@@ -11,6 +11,20 @@
 from tests.testing_utils import requires_gpu
 
 
+def _get_tiny_w4a16_quant():
+    return QuantizationScheme(
+        targets=["Linear"],
+        weights=QuantizationArgs(
+            num_bits=4,
+            type="int",
+            strategy="group",
+            group_size=16,
+            symmetric=True,
+            dynamic=False,
+        ),
+    )
+
+
 def _get_tiny_block_quant():
     return QuantizationScheme(
         targets=["Linear"],

@@ -26,10 +40,12 @@ def _get_tiny_block_quant():
 
 
 @requires_gpu
-@pytest.mark.parametrize("scheme", ["FP8_dynamic", _get_tiny_block_quant()])
+@pytest.mark.parametrize(
+    "scheme", [_get_tiny_w4a16_quant(), "FP8_dynamic", _get_tiny_block_quant()]
+)
 def test_weights_ptq_e2e(scheme, tmp_path):
     model = "nm-testing/tinysmokellama-3.2"
-    ignore = ["model.embed_tokens", "lm_head", "re:.*norm$"]
+    ignore = ["model.embed_tokens", "lm_head"]
     device = "cuda:0"
 
     ptq_outdir = tmp_path / "weights_out"
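The new W4A16 case exercises symmetric int4 weights with one scale per group of 16 input channels. A back-of-the-envelope sketch (not llmcompressor code; the absmax/7 scale is one common symmetric-int4 convention) of the quantize/dequantize round trip that scheme implies:

import torch

w = torch.randn(8, 64)            # (out_features, in_features)
groups = w.reshape(8, -1, 16)     # group_size=16 along the input dim
scale = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 7
q = torch.clamp(torch.round(groups / scale), -8, 7)  # int4 range
w_hat = (q * scale).reshape(8, 64)
print("max abs error:", (w - w_hat).abs().max().item())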
