Merged
2 changes: 1 addition & 1 deletion README.md
@@ -73,7 +73,7 @@ vllmini/
- Implemented **Quantization** using `bitsandbytes` for 4-bit NF4 quantisation. ~~Still experimenting so it has a separate branch~~ (it has been merged into main now).
- **Key insight**: 4-bit is a compressed encoding of the weights; we don't just truncate the weights to 4 bits. (Turns out this is not correct, more on this later.)
- **Dequantisation Formula**:
-   $$ dequantized\ weight = codebook[4\ bit\ index] × scale + zero\ point $$
+   `` dequantized weight = codebook[4 bit index] × scale + zero point ``
- **Basic idea**: instead of storing each weight as FP32 (4 bytes per parameter) or FP16/BF16 (2 bytes per parameter), we store it as a 4-bit index (0–15). That index points to a specific float in a shared 16-entry codebook. Every weight uses the same codebook, and each block of weights shares one scale/zero-point, so the cost is 0.5 bytes per weight plus a small per-block overhead for the scale/zero-point. This reduces memory by roughly 4x vs FP16 and 8x vs FP32.

- **RoPE Sharing**: Optimised the rotary embedding buffers so that they share VRAM across 32+ layers.
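
The dequantisation formula and the memory arithmetic in the quantisation notes above can be sketched in plain Python. This is only an illustration, not the `bitsandbytes` implementation: the evenly spaced `CODEBOOK`, the absmax scaling, the helper names, and the block size of 64 with a 4-byte scale are all assumptions made for the example (real NF4 uses a codebook derived from normal-distribution quantiles).

```python
# Illustrative 16-entry codebook, evenly spaced in [-1, 1].
# Assumption: real NF4 uses normal-distribution quantiles instead.
CODEBOOK = [-1.0 + 2.0 * i / 15 for i in range(16)]


def quantize_block(weights):
    """Return (indices, scale) for one block, using absmax scaling."""
    scale = max(abs(w) for w in weights) or 1.0  # avoid dividing by zero
    indices = [
        min(range(16), key=lambda i: abs(CODEBOOK[i] - w / scale))
        for w in weights
    ]
    return indices, scale


def dequantize_block(indices, scale, zero_point=0.0):
    """Apply codebook[index] * scale + zero_point to every index."""
    return [CODEBOOK[i] * scale + zero_point for i in indices]


def pack_nibbles(indices):
    """Pack two 4-bit indices into each byte (assumes an even count)."""
    return bytes((hi << 4) | lo for hi, lo in zip(indices[0::2], indices[1::2]))


def bytes_per_weight(block_size, scale_bytes):
    """0.5 bytes for the 4-bit index plus the amortised per-block scale."""
    return 0.5 + scale_bytes / block_size


block = [0.8, -0.4, 0.05, -0.8, 0.2, 0.6, -0.1, 0.0]
indices, scale = quantize_block(block)
restored = dequantize_block(indices, scale)
packed = pack_nibbles(indices)

# Hypothetical layout: 64 weights per block sharing one 4-byte scale.
cost = bytes_per_weight(block_size=64, scale_bytes=4)
print(f"{len(block)} weights -> {len(packed)} packed bytes")
print(f"{cost} bytes/weight, {2 / cost:.2f}x smaller than FP16")
```

Under these assumed numbers the per-block scale adds 0.0625 bytes per weight, so the saving against FP16 lands a little under 4x rather than exactly 4x.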
9 changes: 5 additions & 4 deletions benchmark.py
@@ -17,9 +17,9 @@ def print_separator():
     print("-" * 50)

 @torch.inference_mode()
-def benchmark(model_id:str, prompt:str, device:str):
-    print(f"Loading model {model_id} to {device}...")
-    model, config = load_hf_model(model_id, device=device)
+def benchmark(model_id:str, prompt:str, device:str, quantize:bool):
+    print(f"Loading model {model_id} to {device} (quantize={quantize})...")
+    model, config = load_hf_model(model_id, device=device, quantize=quantize)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     sampler = Sampler(temperature=0.0) # Greedy for consistency

@@ -107,9 +107,10 @@ def parse_args():
     parser.add_argument("--model-id", type=str, default=MODEL_ID, help="Model ID")
     parser.add_argument("--prompt", type=str, default=PROMPT, help="Prompt")
     parser.add_argument("--device", type=str, default=DEVICE, help="Device to use")
+    parser.add_argument("--quantize", "-q", action="store_true", help="Enable 4-bit NF4 quantization")

     return parser.parse_args()

 if __name__ == "__main__":
     args = parse_args()
-    benchmark(args.model_id, args.prompt, args.device)
+    benchmark(args.model_id, args.prompt, args.device, args.quantize)
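
For reviewers unfamiliar with `action="store_true"`, the following minimal sketch shows how the new `--quantize` / `-q` flag behaves on its own. It is a stand-alone reduction of `parse_args()`, not the file's actual code: the `argv` parameter and the parser description are added here for illustration so the parser can be exercised without a real command line.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Benchmark flag sketch")
    # store_true: the flag takes no value; absent -> False, present -> True.
    parser.add_argument("--quantize", "-q", action="store_true",
                        help="Enable 4-bit NF4 quantization")
    return parser.parse_args(argv)


default_args = parse_args([])        # flag omitted
quantized_args = parse_args(["-q"])  # short form of the flag

print(default_args.quantize, quantized_args.quantize)
```

Because the flag defaults to `False`, existing invocations of the benchmark that omit it keep the unquantized path.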
2 changes: 1 addition & 1 deletion models/qwen3.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 import torch.nn as nn
 from models.llama import LlamaConfig, MLP, LlamaForCausalLM, TransformerBlock, RMSNorm, RotaryEmbedding
-from models.attention import Attention as LlamaAttention, FlashAttention as LlamaFlashAttention, apply_rotary
+from models.attention import FlashAttention as LlamaFlashAttention, apply_rotary

 class QwenAttention(LlamaFlashAttention):
     def __init__(self, config: LlamaConfig, rotary_emb: RotaryEmbedding):
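
The constructor above receives a `rotary_emb` argument, which fits the README's note about sharing rotary embedding buffers across layers. A framework-free sketch of that sharing pattern might look like the following. `RotaryTable` and `AttentionLayer` are hypothetical stand-ins (plain Python lists instead of tensors), not the repository's `RotaryEmbedding` or attention classes, and the dimensions are arbitrary.

```python
import math


class RotaryTable:
    """Hypothetical stand-in for a rotary-embedding module.

    Precomputes the cos/sin tables once for every position.
    """

    def __init__(self, head_dim, max_positions, base=10000.0):
        # Standard RoPE frequencies: base^(-2i/d) for i in [0, d/2).
        inv_freq = [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
        self.cos = [[math.cos(p * f) for f in inv_freq] for p in range(max_positions)]
        self.sin = [[math.sin(p * f) for f in inv_freq] for p in range(max_positions)]


class AttentionLayer:
    """Hypothetical layer that keeps a reference to the shared table."""

    def __init__(self, rotary):
        self.rotary = rotary  # a reference, not a copy


shared = RotaryTable(head_dim=8, max_positions=16)
layers = [AttentionLayer(shared) for _ in range(32)]

# Every layer points at the same object, so the tables exist only once.
print(all(layer.rotary is shared for layer in layers))
```

Building the table once and passing it into each layer keeps a single copy of the cos/sin buffers alive, instead of one copy per layer.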