diff --git a/README.md b/README.md index 7f8ce4e..9058313 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ vllmini/ - Implemented **Quantization** using `bitsandbytes` for 4-bit NF4 quantisation. ~~Still experimenting so it has a seperate branch~~ (it has been merged with main now). - **Key insight**: 4-bit is a compressed encoding of the weights, we don't just truncate the weights to 4-bit (Turns out this is not correct, more on this later.) - **Dequantisation Formula**: - $$ dequantized\ weight = codebook[4\ bit\ index] × scale + zero\ point $$ + `` dequantized weight = codebook[4 bit index] × scale + zero point `` - **Basic idea**: so instead of storing weights as FP16/BF16 (2 bytes per parameter) or FP32 (4 bytes) which is full precision, we store each weight as a 4-bit index (0–15). This 4-bit value points to a specific float in a shared codebook (typically 16 values). So 16 weights share the same 16-entry codebook, meaning you need only 0.25 bytes per weight + 16 bytes per block for scale/zero-point. This reduces memory by ~8x vs FP16 and ~16x vs FP32. - **RoPE Sharing** : Optimised rotary embedding buffers to share vram across 32+ layers. diff --git a/benchmark.py b/benchmark.py index 57753b8..1ccde00 100644 --- a/benchmark.py +++ b/benchmark.py @@ -17,9 +17,9 @@ def print_separator(): print("-" * 50) @torch.inference_mode() -def benchmark(model_id:str, prompt:str, device:str): - print(f"Loading model {model_id} to {device}...") - model, config = load_hf_model(model_id, device=device) +def benchmark(model_id:str, prompt:str, device:str, quantize:bool): + print(f"Loading model {model_id} to {device} (quantize={quantize})...") + model, config = load_hf_model(model_id, device=device, quantize=quantize) tokenizer = AutoTokenizer.from_pretrained(model_id) sampler = Sampler(temperature=0.0) # Greedy for consistency @@ -107,9 +107,10 @@ def parse_args(): parser.add_argument("--model-id", type=str, default=MODEL_ID, help="Model ID") parser.add_argument("--prompt", type=str, default=PROMPT, help="Prompt") parser.add_argument("--device", type=str, default=DEVICE, help="Device to use") + parser.add_argument("--quantize", "-q", action="store_true", help="Enable 4-bit NF4 quantization") return parser.parse_args() if __name__ == "__main__": args = parse_args() - benchmark(args.model_id, args.prompt, args.device) + benchmark(args.model_id, args.prompt, args.device, args.quantize) diff --git a/models/qwen3.py b/models/qwen3.py index 3ec7b26..7e0707c 100644 --- a/models/qwen3.py +++ b/models/qwen3.py @@ -4,7 +4,7 @@ import torch.nn.functional as F import torch.nn as nn from models.llama import LlamaConfig, MLP, LlamaForCausalLM, TransformerBlock, RMSNorm, RotaryEmbedding -from models.attention import Attention as LlamaAttention, FlashAttention as LlamaFlashAttention, apply_rotary +from models.attention import FlashAttention as LlamaFlashAttention, apply_rotary class QwenAttention(LlamaFlashAttention): def __init__(self, config: LlamaConfig, rotary_emb: RotaryEmbedding):