
Commit 6fe9db9

ptq_weights
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent c254c19 commit 6fe9db9

9 files changed, +534 −6 lines changed

src/llmcompressor/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -26,4 +26,4 @@
     create_session,
     reset_session,
 )
-from llmcompressor.entrypoints import Oneshot, oneshot, train
+from llmcompressor.entrypoints import Oneshot, oneshot, train, ptq_weights

src/llmcompressor/entrypoints/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,4 +9,5 @@

 from .oneshot import Oneshot, oneshot
 from .train import train
+from .weights_ptq import ptq_weights
 from .utils import post_process, pre_process

src/llmcompressor/entrypoints/utils.py

Lines changed: 0 additions & 5 deletions
@@ -34,7 +34,6 @@
 from llmcompressor.pytorch.model_load.helpers import parse_dtype
 from llmcompressor.transformers.compression.compressed_tensors_utils import (
     modify_save_pretrained,
-    untie_word_embeddings,
 )
 from llmcompressor.transformers.utils.helpers import (
     detect_last_checkpoint,
@@ -93,10 +92,6 @@ def pre_process(
             f"`oneshot`/`train`.\nInitialization Error: {e}"
         )

-    # untie tie_word_embeddings weights
-    if not model_args.tie_word_embeddings:
-        untie_word_embeddings(model_args.model)
-
     # wrap model.save_pretrained
     modify_save_pretrained(model_args.model)

src/llmcompressor/entrypoints/weights_ptq/__init__.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
import os
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional

import torch
import tqdm
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils.match import _match_name
from loguru import logger
from safetensors.torch import load_file, save_file

from llmcompressor.entrypoints.weights_ptq.helpers import (
    gpu_if_available,
    validate_scheme,
)
from llmcompressor.entrypoints.weights_ptq.lifecycle import (
    calibrate_weights,
    compress_module,
    initialize_quantized_linear,
)
from llmcompressor.entrypoints.weights_ptq.model_utils import (
    get_checkpoint_files,
    is_weights_file,
)
from llmcompressor.entrypoints.weights_ptq.save_utils import (
    update_config,
    update_safetensors_index,
)

__all__ = ["ptq_weights"]


def ptq_weights(
    model_stub: str | os.PathLike,
    save_directory: str | os.PathLike,
    scheme: QuantizationScheme | str,
    ignore: Optional[list[str]] = None,
    max_workers: int = 1,
    device: Optional[torch.device | str] = None,
):
    # validate arguments
    model_files = get_checkpoint_files(model_stub)
    scheme_name, scheme = validate_scheme(scheme)
    device = gpu_if_available(device)
    ignore = ignore or []  # guard against iterating over None in _process_file

    # 0. collect safetensors files, copy files
    jobs = []
    for file_path, resolved_path in model_files:
        save_path = Path(save_directory) / file_path

        if file_path.endswith("safetensors"):
            jobs.append(
                (_process_file, resolved_path, save_path, scheme, ignore, device)
            )

        else:
            if is_weights_file(file_path):
                logger.warning(f"Skipping weights file {file_path}")
            save_path.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Copying {file_path} to {save_path}")
            shutil.copyfile(resolved_path, save_path)

    # 1-4. quantize and compress weights
    with ThreadPoolExecutor(max_workers) as executor:
        futures = [executor.submit(*job) for job in jobs]

        total_size = 0
        weight_map = dict()
        for future in tqdm.tqdm(
            as_completed(futures), total=len(futures), desc="Quantizing"
        ):
            _total_size, _weight_map = future.result()
            total_size += _total_size
            weight_map.update(_weight_map)

    # 5. update config and safetensors index
    update_config(save_directory, scheme_name, scheme, ignore)
    update_safetensors_index(save_directory, total_size, weight_map)


def _process_file(
    file_path: str | os.PathLike,
    save_path: str | os.PathLike,
    scheme: QuantizationScheme,
    ignore: str | list[str],
    device: str | torch.device,
) -> tuple[int, dict[str, str]]:
    tensors = load_file(file_path)

    for name in list(tensors.keys()):
        module_name, param_name = name.rsplit(".", 1)
        is_ignored = any(_match_name(module_name, ign) for ign in ignore)
        is_weight = param_name == "weight"
        if is_ignored or not is_weight:
            logger.debug(f"Skipping {name}")
            continue

        # 1. initialize module with qparams (on device)
        module = initialize_quantized_linear(tensors[name], scheme, device)

        # 2. calibrate weight qparams
        calibrate_weights(module)

        # 3. compress module using qparams
        compress_module(module)

        # 4. save compressed data (on cpu)
        del tensors[name]
        prefix = module_name + "."
        for key, value in module.state_dict(prefix=prefix).items():
            tensors[key] = value.to("cpu")

    save_file(tensors, save_path)
    total_size = sum(tensor.nbytes for tensor in tensors.values())
    weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
    return total_size, weight_map
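
For orientation, a minimal sketch of how this new entrypoint might be invoked; the model stub, save directory, and worker count below are illustrative, not taken from this commit:

from llmcompressor import ptq_weights

# quantize the weights of every Linear layer except the lm_head, copying all
# non-weights checkpoint files (config, tokenizer, ...) alongside the result
ptq_weights(
    model_stub="meta-llama/Llama-3.2-1B-Instruct",  # illustrative stub
    save_directory="Llama-3.2-1B-Instruct-W4A16",
    scheme="W4A16",
    ignore=["lm_head"],
    max_workers=4,
)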
src/llmcompressor/entrypoints/weights_ptq/helpers.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
from typing import Optional

import torch
from compressed_tensors.quantization import QuantizationScheme, preset_name_to_scheme
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.match import _match_name
from loguru import logger

__all__ = ["validate_scheme", "gpu_if_available", "is_match_name"]


def validate_scheme(scheme: QuantizationScheme) -> tuple[str, QuantizationScheme]:
    # treat strings as preset schemes
    if isinstance(scheme, str):
        scheme_name, scheme = scheme, preset_name_to_scheme(scheme, [])
    else:
        scheme_name = "config_group_0"

    # weight quantization must be provided
    if scheme.weights is None:
        raise ValueError(
            "Must provide a weights quantization scheme to perform weights-only PTQ"
        )

    # activation quantization must be dynamic
    input_dynamic = getattr_chain(scheme, "input_activations.dynamic", True)
    output_dynamic = getattr_chain(scheme, "output_activations.dynamic", True)
    if input_dynamic is not True or output_dynamic is not True:
        raise ValueError(
            "Weights-only PTQ cannot calibrate activations. "
            "Please use `oneshot` instead."
        )

    # override with static observers
    if scheme.weights.observer in ("minmax", "mse"):
        new_observer = f"static_{scheme.weights.observer}"
        logger.warning(
            f"Scheme uses {scheme.weights.observer} weight observer. "
            f"Using {new_observer} instead"
        )
        scheme.weights.observer = new_observer

    # target all modules; filter by ignore list
    # technically this should be "re:.*", but vllm's
    # ct moe layer has a hard-coded check for "Linear"
    scheme.targets = ["Linear"]
    return scheme_name, scheme


def gpu_if_available(device: torch.device | str | None) -> torch.device:
    if device is not None:
        return torch.device(device)

    elif torch.cuda.is_available():
        return torch.device("cuda:0")

    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu:0")

    else:
        logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
        return torch.device("cpu")


def is_match_name(
    name: str, targets: list[str], ignore: Optional[str | list[str]] = None
) -> bool:
    if ignore is None:
        ignore = []  # avoid matching against a literal [None]
    targets = targets if isinstance(targets, list) else [targets]
    ignore = ignore if isinstance(ignore, list) else [ignore]

    matches_target = any(_match_name(name, target) for target in targets)
    matches_ignore = any(_match_name(name, ign) for ign in ignore)

    return matches_target and not matches_ignore
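
As a quick illustration of `validate_scheme`'s contract, a sketch; the exact preset defaults come from compressed-tensors and may differ:

from llmcompressor.entrypoints.weights_ptq.helpers import validate_scheme

# preset strings resolve via preset_name_to_scheme; explicit
# QuantizationScheme objects are named "config_group_0" instead
scheme_name, scheme = validate_scheme("W4A16")
print(scheme_name)              # "W4A16"
print(scheme.targets)           # ["Linear"], forced as noted in the comment above
print(scheme.weights.observer)  # "static_minmax", if the preset defaulted to "minmax"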
src/llmcompressor/entrypoints/weights_ptq/lifecycle.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
import torch
from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.config.format import _get_quant_compression_format
from compressed_tensors.quantization import (
    QuantizationScheme,
    QuantizationStrategy,
    initialize_module_for_quantization,
)

from llmcompressor.modifiers.quantization.calibration import (
    apply_calibration_status,
    freeze_module_quantization,
    initialize_observer,
    update_weight_global_scale,
    update_weight_zp_scale,
)

__all__ = ["initialize_quantized_linear", "calibrate_weights", "compress_module"]


def initialize_quantized_linear(
    weight: torch.Tensor, scheme: QuantizationScheme, device: str | torch.device
) -> torch.nn.Module:
    out_features, in_features = weight.shape
    module = torch.nn.Linear(
        in_features, out_features, bias=False, device=device, dtype=weight.dtype
    )
    module.weight.data.copy_(weight)
    initialize_module_for_quantization(module, scheme, force_zero_point=False)

    return module


def calibrate_weights(module: torch.nn.Linear):
    scheme: QuantizationScheme = getattr(module, "quantization_scheme")
    initialize_observer(module, "weight")

    apply_calibration_status(module)
    if scheme.weights.strategy == QuantizationStrategy.TENSOR_GROUP:
        update_weight_global_scale(module)
    update_weight_zp_scale(module)

    freeze_module_quantization(module)


def compress_module(module: torch.nn.Linear):
    scheme: QuantizationScheme = getattr(module, "quantization_scheme")

    format = _get_quant_compression_format(scheme.input_activations, scheme.weights)
    scheme.format = format.value

    compressor = BaseCompressor.load_from_registry(format.value)
    data = compressor.compress_weight(
        module.weight,
        quantization_args=scheme.weights,
        scale=getattr(module, "weight_scale"),
        zero_point=getattr(module, "weight_zero_point", None),
        global_scale=getattr(module, "weight_global_scale", None),
    )

    for key, value in data.items():
        if hasattr(module, key):
            getattr(module, key).data = value
        else:
            module.register_parameter(
                key, torch.nn.Parameter(value, requires_grad=False)
            )
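
Putting the three lifecycle helpers together on a single weight tensor, a minimal sketch; the random weight shape and the W4A16 preset are illustrative:

import torch
from compressed_tensors.quantization import preset_name_to_scheme

from llmcompressor.entrypoints.weights_ptq.lifecycle import (
    calibrate_weights,
    compress_module,
    initialize_quantized_linear,
)

scheme = preset_name_to_scheme("W4A16", ["Linear"])
weight = torch.randn(256, 512, dtype=torch.bfloat16)  # illustrative weight

# 1. attach qparams, 2. observe scales/zero points, 3. pack into compressed form
module = initialize_quantized_linear(weight, scheme, device="cpu")
calibrate_weights(module)
compress_module(module)

# the module now carries compressed parameters (e.g. weight_scale and the
# compressor's packed representation) alongside its quantization scheme
print(sorted(module.state_dict().keys()))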
src/llmcompressor/entrypoints/weights_ptq/model_utils.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import os

from huggingface_hub import list_repo_files
from transformers.utils.hub import cached_file

__all__ = ["get_checkpoint_files", "is_weights_file"]

weights_files = [
    ".bin",
    ".safetensors",
    ".pth",
    ".msgpack",
    ".pt",
]


def is_weights_file(file_name: str) -> bool:
    return any(file_name.endswith(suffix) for suffix in weights_files)


def get_checkpoint_files(model_stub: str | os.PathLike) -> list[tuple[str, str]]:
    # In the future, this function can accept and pass download kwargs to cached_file

    if os.path.exists(model_stub):
        file_paths = walk_file_paths(model_stub, ignore=".cache")
    else:
        file_paths = list_repo_files(model_stub)

    return [(file_path, cached_file(model_stub, file_path)) for file_path in file_paths]


def walk_file_paths(root_dir: str, ignore: str | None = None) -> list[str]:
    all_files = []
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:
            rel_path = os.path.relpath(os.path.join(dirpath, filename), root_dir)
            if not (ignore and rel_path.startswith(ignore)):
                all_files.append(rel_path)
    return all_files


# distinguish relative file paths from absolute/resolved file paths
# relative file paths are used to find the save path
# resolved file paths are used to load data
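
A brief sketch of how these helpers pair relative and resolved paths; the hub stub is illustrative, and a remote stub will download each listed file into the cache on first use:

from llmcompressor.entrypoints.weights_ptq.model_utils import (
    get_checkpoint_files,
    is_weights_file,
)

# local directories are walked directly; hub stubs are listed, then cached per file
for file_path, resolved in get_checkpoint_files("meta-llama/Llama-3.2-1B-Instruct"):
    print(f"{file_path} -> {resolved} (weights file: {is_weights_file(file_path)})")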
