Commit 4ce1bdf
[Quantization] Attention/KV Cache Refactor (#1651)
## Purpose ##
* Support fully expressive attention and KV cache quantization
* Support running KV cache quantization evals with HF transformers
* Resolves #1949
* Resolves #1928
```python3
recipe = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="tensor"
            ),
        )
    }
)
```
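When saved with the model, a recipe like this is serialized into the model's `quantization_config`, similar to the example below (which also shows a `kv_cache_scheme` entry):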
```json
{
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "format": null,
        "input_activations": {
          "dynamic": false,
          "num_bits": 8,
          "observer": "minmax",
          "strategy": "tensor",
          "symmetric": true,
          "type": "float"
        },
        "output_activations": null,
        "targets": [
          "LlamaAttention"
        ],
        "weights": null
      }
    },
    "format": "dense",
    "ignore": [],
    "kv_cache_scheme": {
      "dynamic": false,
      "group_size": null,
      "num_bits": 8,
      "observer": "minmax",
      "strategy": "tensor",
      "symmetric": true,
      "type": "float"
    },
    "quant_method": "compressed-tensors",
    "quantization_status": "frozen"
  }
}
```
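For reference, the `kv_cache_scheme` entry above can also be produced by passing a scheme directly to the modifier instead of (or in addition to) `config_groups`; a minimal sketch mirroring the values in the JSON above (see also `compress.py` below):

```python
from compressed_tensors.quantization import QuantizationArgs
from llmcompressor.modifiers.quantization import QuantizationModifier

# Sketch: static FP8 per-tensor quantization of the KV cache only
kv_recipe = QuantizationModifier(
    kv_cache_scheme=QuantizationArgs(
        num_bits=8, type="float", strategy="tensor", symmetric=True
    ),
)
```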
## Prerequisites ##
* Must be merged at the same time as
vllm-project/compressed-tensors#436
## Changes ##
* Replace hooks
  * Remove `calibrate_kv_cache_input_hook`, `calibrate_kv_cache_output_hook`, and `initialize_quantized_kv_cache`
  * Add `calibrate_query_hook`, `calibrate_key_hook`, and `calibrate_value_hook` (a rough sketch of one of these hooks is shown after this list)
* QuantizationMixin now initializes "q", "k", and "v" observers ([depending on the attached submodules](https://github.com/vllm-project/llm-compressor/pull/1651/files#diff-33303ae48e185b2fbb14dc45c2052805837deb5723248367b9579321c4c4e974R263-R270)) and adds the appropriate hooks
* Miscellaneous
  * Fix minor shape bug in `_flatten_attention`
  * Add support for "attn_head" strategy in `_flatten_attention`
* Tests
  * Removed old QuantizationKVCache tests (these classes are now tested [here](https://github.com/neuralmagic/compressed-tensors/pull/436/files#diff-6e33ff48047dc4f7c9d969293f87e32e4d5ec3f3e8b741ea757780c8c0aab775))
  * Updated scale names to avoid using enums
  * Avoid unnecessary tokenization to reduce runtime
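For illustration, a rough sketch of what one of the new calibration hooks does, based on the description above; the observer attribute and quantization parameter names here (`k_observer`, `k_scale`, `k_zero_point`) are assumptions for the sketch, not necessarily the names used in the implementation:

```python
import torch

def calibrate_key_hook(module: torch.nn.Module, key_states: torch.Tensor) -> None:
    # Sketch only: during calibration, feed the observed key activations to the
    # "k" observer attached to the attention module and refresh its scale.
    observer = module.k_observer              # assumed attribute name
    scale, zero_point = observer(key_states)  # minmax observer yields qparams
    module.k_scale.data.copy_(scale)          # assumed parameter names
    module.k_zero_point.data.copy_(zero_point)
```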
## Testing ##
* KV cache regression tests pass
* Able to quantize attention with scripts (will add to examples once
loadable in vllm)
* kylesayrs/Llama-3.2-1B-Instruct-attention-fp8-head
* kylesayrs/Llama-3.2-1B-Instruct-attention-nvfp4-head
* Nightly passes (in progress)
## Evaluation ##
<details><summary>eval.py</summary>
```python
import sys

import lm_eval

model_id = sys.argv[1]
print(model_id)

results = lm_eval.simple_evaluate(
    # evaluate the HF-serialized (compressed) model
    model="hf",
    model_args={
        "pretrained": model_id,
        "add_bos_token": False,
        "dtype": "auto",
        "device_map": "cuda",
        # "max_length": 128000,
    },
    device="cuda",
    # tasks=["gsm8k_platinum", "mmlu_llama", "longbench2_single"],
    tasks=["gsm8k_platinum"],
    batch_size=64,
    apply_chat_template=True,
    fewshot_as_multiturn=True,
)
print(model_id)
print(lm_eval.utils.make_table(results))
```
</details>
<details><summary>compress.py</summary>
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# Select model and load it.
# model_id = "Qwen/Qwen2.5-14B-Instruct-1M"
model_id = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
args = QuantizationArgs(
    num_bits=8,
    type="float",
    strategy="attn_head",
    symmetric=True,
    observer="static_minmax",
)
recipe = QuantizationModifier(
    # config_groups={
    #     "attention": QuantizationScheme(
    #         # targets=["Qwen2Attention"],
    #         targets=["LlamaAttention"],
    #         input_activations=args,
    #     )
    # },
    kv_cache_scheme=args,
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + f"-KV-FP8-{args.strategy}-{args.observer}"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
</details>
Model | GSM8K
-- | --
nm-testing/Llama-3.1-8B-Instruct | 0.8337
nm-testing/Llama-3.1-8B-Instruct-KV-FP8-Tensor | 0.8271
nm-testing/Llama-3.1-8B-Instruct-KV-FP8-Head | 0.8354
nm-testing/Llama-3.1-8B-Instruct-QKV-FP8-Tensor | 0.8321
nm-testing/Llama-3.1-8B-Instruct-QKV-FP8-Head | 0.8238
---------
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>