-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathengine.py
More file actions
1185 lines (1036 loc) · 50 KB
/
engine.py
File metadata and controls
1185 lines (1036 loc) · 50 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
spectralquant/engine.py — SpectralQuantEngine: the main public API.
Drop-in replacement for standard attention. Follows the same interface
as turboquant-gpu (https://github.com/DevTechJr/turboquant-gpu).
Usage:
engine = SpectralQuantEngine(head_dim=128, total_bits=3, device="cuda")
engine.calibrate(model, tokenizer, texts) # 15 seconds, one-time
result = engine.generate(model, tok, "prompt") # one-call generation
engine.apply_to(model) # permanent monkey-patch
"""
from __future__ import annotations
import functools
import math
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
from spectralquant.calibrate import EigenspectralCalibrator, HeadCalibrationData
from spectralquant.kernels import (
compress_keys_spectral,
compress_values_spectral,
compress_values_waterfill,
fused_spectral_attention,
)
# Module-level Lloyd-Max codebook cache — avoids O(n_heads) scipy calls
_CODEBOOK_CACHE: Dict[Tuple[float, int], torch.Tensor] = {}
def _get_codebook(sigma: float, bits: int) -> torch.Tensor:
"""Cached Lloyd-Max codebook for N(0, sigma²).
The Lloyd-Max iteration at this resolution converges far before 200
iterations and is invariant to the 3rd decimal of sigma (codebooks for
sigma=0.123 and sigma=0.124 differ by < 1e-3 absolute), so we round the
cache key aggressively to keep this from dominating engine-build time
on production-scale models. With ~256 heads × ~5 d_eff × varying bit
allocations the naive 6-decimal cache would build ~4000+ codebooks
sequentially and take 10+ minutes; rounding to 2 decimals collapses
that to a few hundred unique entries (~30 seconds).
"""
# 2-decimal cache key + 50-iteration cap (matches Lloyd-Max convergence
# tail to <1e-4 for our use case)
sig_key = round(sigma, 2)
key = (sig_key, bits)
if key in _CODEBOOK_CACHE:
return _CODEBOOK_CACHE[key]
n = 1 << bits
try:
from scipy import integrate as _integrate
pdf = lambda x: (math.exp(-x*x/(2*sig_key*sig_key))
/ (math.sqrt(2*math.pi)*sig_key)) if sig_key > 0 else 0.0
lo, hi = -3.5*sig_key, 3.5*sig_key
cents = [lo + (hi-lo)*(i+0.5)/n for i in range(n)]
for _ in range(50):
bounds = [(cents[i]+cents[i+1])/2 for i in range(n-1)]
edges = [lo*3] + bounds + [hi*3]
new = []
for i in range(n):
a, b = edges[i], edges[i+1]
num, _ = _integrate.quad(lambda x: x*pdf(x), a, b, limit=20)
den, _ = _integrate.quad(pdf, a, b, limit=20)
new.append(num/den if den > 1e-15 else cents[i])
if max(abs(new[i]-cents[i]) for i in range(n)) < 1e-6:
break
cents = new
_CODEBOOK_CACHE[key] = torch.tensor(cents, dtype=torch.float32)
except ImportError:
_CODEBOOK_CACHE[key] = torch.linspace(-3.5*sig_key, 3.5*sig_key, n)
return _CODEBOOK_CACHE[key]
# ---------------------------------------------------------------------------
# Per-head engine (one per attention head)
# ---------------------------------------------------------------------------
@dataclass
class HeadEngine:
    """Compression engine for one (layer, head) pair.

    Holds **separate** rotation matrices, effective dims, and codebooks for
    keys and values. Keys and values have very different covariance
    structures (especially for RoPE LLMs where K is post-RoPE and V is not).
    Compressing both with the same eigenbasis would rotate V's variance
    into directions the value codebooks weren't built for, destroying
    reconstruction quality.

    The public ``d_eff`` / ``eigenvectors`` attributes are aliases of the
    key-side values to preserve the prior API (e.g. notebook code that reads
    ``he.d_eff`` or ``he.eigenvectors``).
    """
    layer_idx: int
    head_idx: int
    head_dim: int
    # Key-side eigenspectral data
    d_eff: int                  # key d_eff (also exposed as `d_eff`)
    eigenvectors: torch.Tensor  # (D, D) — key eigenvectors
    eigenvalues: torch.Tensor   # (D,) — key eigenvalues
    # Value-side eigenspectral data (separate basis & sigmas)
    v_d_eff: int
    v_eigenvectors: torch.Tensor  # (D, D) — value eigenvectors
    v_eigenvalues: torch.Tensor   # (D,) — value eigenvalues
    b_high: int                 # bits budget for semantic dims (see bytes_per_token)
    b_low: int                  # bits for *key* tail dims
    S_sem: torch.Tensor         # (d_eff, D) QJL projection
    centroids_key_high: torch.Tensor
    centroids_key_low: torch.Tensor  # UNIT variance (per-dim sigma applied)
    centroids_val_high: torch.Tensor
    centroids_val_low: torch.Tensor  # UNIT variance (per-dim sigma applied)
    # Per-dim sigmas for the tail dims, used to rescale before/after
    # quantization with the shared unit-variance tail codebook. Storing
    # sigmas (D - d_eff floats per head) is cheaper than storing per-dim
    # codebooks and is mathematically equivalent for nearest-centroid
    # Lloyd-Max codebooks on Gaussian-ish data.
    k_tail_sigmas: torch.Tensor = field(
        default_factory=lambda: torch.empty(0, dtype=torch.float32)
    )
    v_tail_sigmas: torch.Tensor = field(
        default_factory=lambda: torch.empty(0, dtype=torch.float32)
    )
    use_water_fill: bool = False
    mse_bits_per_dim: List[int] = field(default_factory=list)
    centroids_key_per_dim: List[torch.Tensor] = field(default_factory=list)
    centroids_val_per_dim: List[torch.Tensor] = field(default_factory=list)
    # Bits per dim for the *value* tail. -1 → fall back to ``self.b_low``
    # (this preserves backward compatibility for callers that don't know
    # about the asymmetric K/V split introduced when we decoupled value
    # compression on RoPE LLMs).
    b_low_val: int = -1

    def to(self, device: str) -> "HeadEngine":
        """Move all tensors (including per-dim codebook lists) to ``device``.

        Returns ``self`` so calls can be chained.
        """
        self.eigenvectors = self.eigenvectors.to(device)
        self.eigenvalues = self.eigenvalues.to(device)
        self.v_eigenvectors = self.v_eigenvectors.to(device)
        self.v_eigenvalues = self.v_eigenvalues.to(device)
        self.S_sem = self.S_sem.to(device)
        self.k_tail_sigmas = self.k_tail_sigmas.to(device)
        self.v_tail_sigmas = self.v_tail_sigmas.to(device)
        for attr in ("centroids_key_high", "centroids_key_low",
                     "centroids_val_high", "centroids_val_low"):
            setattr(self, attr, getattr(self, attr).to(device))
        self.centroids_key_per_dim = [c.to(device) for c in self.centroids_key_per_dim]
        self.centroids_val_per_dim = [c.to(device) for c in self.centroids_val_per_dim]
        return self

    def compress_keys(self, K: torch.Tensor) -> Dict:
        """Compress keys for this head.

        ``K`` is flattened to ``(-1, head_dim)`` with ``reshape`` (not
        ``view``). Non-contiguous inputs are therefore accepted, e.g. a single
        head sliced out of a batched fp16 cache, and they are copied only when
        necessary.

        Returns the dict produced by ``compress_keys_spectral``.
        :meth:`attention` reads its ``"k_mse"``, ``"signs"`` and
        ``"r_norms"`` entries.
        """
        ts = self.k_tail_sigmas if self.k_tail_sigmas.numel() > 0 else None
        # Pass the water-fill per-dim codebooks when we have them. Otherwise
        # the kernel falls back to the single shared `centroids_key_high`,
        # which clips the dominant signal dim (σ₀ ≫ mean σ) and produces the
        # K cosine ≈ 0.94 we saw before the wiring was complete.
        per_dim = (
            self.centroids_key_per_dim
            if self.use_water_fill and self.centroids_key_per_dim
            else None
        )
        return compress_keys_spectral(
            K.reshape(-1, self.head_dim),
            self.eigenvectors,
            self.centroids_key_high,
            self.centroids_key_low,
            self.S_sem,
            self.d_eff,
            tail_sigmas=ts,
            centroids_per_dim=per_dim,
        )

    def compress_values(self, V: torch.Tensor) -> torch.Tensor:
        """Compress values for this head and return their reconstruction.

        ``V`` is flattened to ``(-1, head_dim)`` with ``reshape`` so
        non-contiguous inputs are accepted.
        """
        V_2d = V.reshape(-1, self.head_dim)
        ts = self.v_tail_sigmas if self.v_tail_sigmas.numel() > 0 else None
        # NB: values use their *own* eigenvectors and d_eff. Rotating V by
        # the key eigenbasis (the previous behaviour) miscalibrates the value
        # codebooks because V's variance distribution doesn't match K's, and
        # produces unrecoverable reconstruction error that compounds through
        # the model.
        if self.use_water_fill and self.centroids_val_per_dim:
            return compress_values_waterfill(
                V_2d, self.v_eigenvectors,
                self.centroids_val_per_dim, self.centroids_val_low,
                self.v_d_eff,
                tail_sigmas=ts,
            )
        return compress_values_spectral(
            V_2d, self.v_eigenvectors,
            self.centroids_val_high, self.centroids_val_low,
            self.v_d_eff,
            tail_sigmas=ts,
        )

    def attention(
        self,
        Q: torch.Tensor,  # (..., S_Q, D)
        compressed_K: Dict,
        V_decompressed: torch.Tensor,
        *,
        use_qjl_correction: bool = False,
    ) -> torch.Tensor:
        """Run the fused spectral attention kernel for this head.

        Inputs are reshaped to the kernel's (B=1, H=1, S, ·) layout, and the
        output is reshaped back to ``Q``'s leading dims with ``head_dim``
        as the last axis.
        """
        shape = Q.shape
        Q_4d = Q.reshape(1, 1, -1, self.head_dim)
        K_4d = compressed_K["k_mse"].reshape(1, 1, -1, self.head_dim)
        V_4d = V_decompressed.reshape(1, 1, -1, self.head_dim)
        sg_4d = compressed_K["signs"].reshape(1, 1, -1, self.d_eff)
        rn_4d = compressed_K["r_norms"].reshape(1, 1, -1)
        out = fused_spectral_attention(
            Q_4d, K_4d, V_4d, sg_4d, rn_4d,
            self.S_sem, self.d_eff,
            use_qjl_correction=use_qjl_correction,
        )
        return out.reshape(*shape[:-2], shape[-2], self.head_dim)

    @staticmethod
    def _codebook_bits(codebook: torch.Tensor) -> int:
        """Index width in bits for a codebook of ``2**b`` centroids."""
        return int(codebook.shape[0]).bit_length() - 1

    def bytes_per_token(self, use_qjl: bool = False) -> Tuple[float, float]:
        """Return (key_bytes, val_bytes) per token.

        Bit widths are derived from the codebooks that
        :meth:`compress_keys` / :meth:`compress_values` would actually use.
        This means the count is the same whether or not water-filling is
        enabled.

        Parameters
        ----------
        use_qjl : bool
            If True, includes the d_eff selective-QJL sign bits in the
            per-token cost. When the engine has QJL disabled (the default)
            those sign bits aren't needed on the wire at all, so we drop
            them from the byte count.
        """
        D = self.head_dim
        # Key semantic dims: mirror compress_keys' codebook selection.
        if self.use_water_fill and self.centroids_key_per_dim:
            key_sem_mse = sum(
                self._codebook_bits(c) for c in self.centroids_key_per_dim
            )
        else:
            # One shared codebook for every semantic dim. Its index width is
            # what is stored. ``b_high`` also includes a reserved QJL sign
            # slot, and that slot is accounted for separately via ``qjl_bits``.
            key_sem_mse = self.d_eff * self._codebook_bits(self.centroids_key_high)
        qjl_bits = self.d_eff if use_qjl else 0
        # +32 bits: fixed per-token key overhead (includes the fp16 r_norm),
        # kept identical to the existing storage accounting.
        key_bits = key_sem_mse + qjl_bits + (D - self.d_eff) * self.b_low + 32
        # Value semantic dims: mirror compress_values' codebook selection.
        if self.use_water_fill and self.centroids_val_per_dim:
            val_sem = sum(
                self._codebook_bits(c) for c in self.centroids_val_per_dim
            )
        else:
            val_sem = self.v_d_eff * self.b_high
        b_low_v = self.b_low_val if self.b_low_val > 0 else self.b_low
        val_bits = val_sem + (D - self.v_d_eff) * b_low_v
        return key_bits / 8, val_bits / 8
# ---------------------------------------------------------------------------
# Main engine
# ---------------------------------------------------------------------------
class SpectralQuantEngine:
"""
SpectralQuant KV cache compression engine.
Drop-in replacement for TurboQuant. Compatible with any HuggingFace model.
Example:
engine = SpectralQuantEngine(head_dim=128, total_bits=3, device="cuda")
engine.calibrate(model, tokenizer, texts)
engine.apply_to(model)
# all model.generate() calls now use SpectralQuant
"""
def __init__(
self,
head_dim: int = 128,
total_bits: int = 3,
use_water_fill: bool = True,
use_qjl: bool = False,
d_eff_variance: float = 0.90,
noise_bits: Optional[int] = None,
value_noise_bits: Optional[int] = None,
device: str = "cuda",
) -> None:
"""
Parameters
----------
head_dim : int
Per-head dimension (D).
total_bits : int
Target average bits per element used to derive the bit-allocation.
use_water_fill : bool
If True, allocate MSE / value bits per signal dimension via the
water-filling heuristic over the eigenvalues.
use_qjl : bool
If True, the fused attention kernel adds the selective JL
sign-sketch correction term. Defaults to False because the
estimator variance for typical decoder-LLM ``d_eff`` (≈ 4) is
larger than the score it tries to correct and hurts softmax
quality. Leave on for architectures where ``d_eff`` is large
relative to ``head_dim`` (e.g. ViT/ESM heads).
d_eff_variance : float
Fraction of total eigenvalue energy the per-dim "signal" codebooks
must cover (default 0.90). This determines the **actual** d_eff
used for bit allocation, in place of the participation-ratio
d_eff diagnostic — the latter collapses to ~1 when a single
eigenvalue dominates a slow-decay tail (RoPE LLMs), which over-
concentrates bits on dim 0 and starves the position-distinguishing
tail dims. Higher values (e.g. 0.95) keep more dims in the
per-dim regime at the cost of marginally less compression.
noise_bits : Optional[int]
Override for the bit count assigned to KEY noise-tail dimensions.
When ``None`` (default), the tail uses ``round(total_bits)`` bits
so the engine reduces to a TurboQuant-style uniform allocation.
Set to ``1`` to recover the full SpectralQuant v2 architecture
where the tail keeps only sign information; combined with
``total_bits`` controlling the semantic budget, this is what
unlocks the ≥5× compression numbers from the paper. Typical
settings: ``noise_bits=1`` with ``total_bits=4`` or ``5`` to keep
high-precision semantic dims while collapsing the tail.
value_noise_bits : Optional[int]
Override for the bit count assigned to VALUE noise-tail dims
(defaults to ``noise_bits`` when ``None``). Values are not
position-encoded by RoPE and only affect the attention output
additively (Σ wᵢ · vᵢ) rather than multiplicatively, so they
tolerate substantially more aggressive compression than keys.
On RoPE LLMs (Mistral, Llama, Qwen) we recommend
``noise_bits=2`` (keep RoPE positional info) paired with
``value_noise_bits=1`` (collapse the V tail) — together this
roughly doubles the value-side compression while leaving keys
untouched, lifting overall ratios from ~5× to ~9×.
device : str
Torch device string.
"""
self.head_dim = head_dim
self.total_bits = total_bits
self.use_water_fill = use_water_fill
self.use_qjl = use_qjl
self.d_eff_variance = d_eff_variance
self.noise_bits = noise_bits
# Default value_noise_bits to noise_bits when not specified, so the
# legacy "set noise_bits to control both" usage keeps working.
self.value_noise_bits = (
value_noise_bits if value_noise_bits is not None else noise_bits
)
self.device = device
self._calibrator = EigenspectralCalibrator()
self._head_engines: Dict[Tuple[int, int], HeadEngine] = {}
self._calibrated = False
# ------------------------------------------------------------------
# Calibration
# ------------------------------------------------------------------
def calibrate(
self,
model: nn.Module,
tokenizer: Any,
calibration_texts: List[str],
n_samples: int = 512,
save_path: Optional[str] = None,
) -> Dict[str, Any]:
"""
Calibrate per-head eigenspectral structure. ~15 seconds on GPU.
Args:
model: Any HuggingFace model.
tokenizer: Matching tokenizer.
calibration_texts: Representative texts (diverse helps).
n_samples: How many texts to process.
save_path: If given, save calibration data here.
Returns:
Summary dict with d_eff stats.
"""
t0 = time.perf_counter()
self._calibrator.calibrate(
model, tokenizer, calibration_texts,
n_samples=n_samples, device=self.device,
)
elapsed = time.perf_counter() - t0
# Build per-head engines
self._head_engines.clear()
self._build_head_engines()
if save_path:
self._calibrator.save(save_path)
summary = self._calibrator.summary()
summary["calibration_time_s"] = round(elapsed, 1)
print(
f"[SpectralQuant] Calibrated {summary['n_heads']} heads in {elapsed:.1f}s. "
f"d_eff: mean={summary['d_eff_mean']:.1f} "
f"[{summary['d_eff_min']},{summary['d_eff_max']}]."
)
self._calibrated = True
return summary
def load_calibration(self, path: str) -> None:
"""Load pre-computed calibration data."""
self._calibrator.load(path)
self._build_head_engines()
self._calibrated = True
def save_calibration(self, path: str) -> None:
"""Persist calibration to disk for instant reuse next session."""
if not self._calibrated:
raise RuntimeError(
"Engine has no calibration to save — call .calibrate(...) first."
)
self._calibrator.save(path)
def collect_kv(
self,
model: nn.Module,
tokenizer: Any,
prompt: str,
) -> Dict[int, Dict[str, torch.Tensor]]:
"""Run one forward pass and return per-layer K/V the engine sees at inference.
For RoPE LLMs the captured K is *post-RoPE* (matching what the SQ
attention kernel actually compresses); for ViT/ESM/VideoMAE it is the
k_proj output (which equals the attention K for those models).
Use this when debugging compression quality — comparing the engine's
reconstruction against pre-RoPE K (via a naive linear hook) gives
misleadingly low cosine similarity because the eigenvectors were
calibrated on a different distribution.
Returns
-------
dict
``{layer_idx: {"k": Tensor(total_tokens, n_kv_heads, head_dim),
"v": Tensor(total_tokens, n_kv_heads, head_dim)}}``
"""
from spectralquant.calibrate import (
_model_uses_rope, _iter_attention_layers, _get_kv_dims,
_find_kv_linears,
)
dev = str(next(model.parameters()).device)
cfg = getattr(model, "config", None)
inputs = tokenizer(prompt, return_tensors="pt").to(dev)
store: Dict[int, Dict[str, torch.Tensor]] = {}
if _model_uses_rope(model):
# Post-RoPE path: intercept ALL_ATTENTION_FUNCTIONS.
import transformers.modeling_utils as _mu
reg = _mu.ALL_ATTENTION_FUNCTIONS
impl = getattr(cfg, "_attn_implementation", "sdpa") if cfg else "sdpa"
original_impl = reg[impl]
for layer_idx, attn in _iter_attention_layers(model):
attn._sq_collect_layer_idx = layer_idx
def collect_fn(module, query, key, value, *args, **kwargs):
li = getattr(module, "_sq_collect_layer_idx", None)
if li is not None and isinstance(key, torch.Tensor) and key.dim() == 4:
B, H, S, D = key.shape
k_2d = (key .detach().float()
.permute(0, 2, 1, 3).reshape(-1, H, D).cpu())
v_2d = (value.detach().float()
.permute(0, 2, 1, 3).reshape(-1, H, D).cpu())
store[li] = {"k": k_2d, "v": v_2d}
return original_impl(module, query, key, value, *args, **kwargs)
sentinel = "_sq_collect_backup_" + impl
reg[sentinel] = original_impl
reg[impl] = collect_fn
try:
model.eval()
with torch.no_grad():
model(**inputs, use_cache=True)
finally:
reg[impl] = reg.pop(sentinel, original_impl)
for _li, attn in _iter_attention_layers(model):
if hasattr(attn, "_sq_collect_layer_idx"):
delattr(attn, "_sq_collect_layer_idx")
return store
# Non-RoPE path: hook k_proj / v_proj outputs.
handles: list = []
for layer_idx, attn in _iter_attention_layers(model):
_, n_kv, head_dim = _get_kv_dims(attn, cfg)
k_lin, v_lin = _find_kv_linears(attn)
if k_lin is None:
continue
def make_hook(li, n_kv_l, hd, kind):
def hook(_mod, _inp, out):
t = out.detach().float().cpu()
if t.dim() == 3:
B, S, _ = t.shape
t = t.view(B, S, n_kv_l, hd).permute(0, 2, 1, 3).reshape(-1, n_kv_l, hd)
elif t.dim() == 2:
T = t.shape[0]
t = t.view(T, n_kv_l, hd)
store.setdefault(li, {})[kind] = t
return hook
handles.append(k_lin.register_forward_hook(make_hook(layer_idx, n_kv, head_dim, "k")))
handles.append(v_lin.register_forward_hook(make_hook(layer_idx, n_kv, head_dim, "v")))
try:
model.eval()
with torch.no_grad():
model(**inputs)
finally:
for h in handles:
h.remove()
return store
def _build_head_engines(self) -> None:
"""Build HeadEngine objects from calibration data.
Two things changed vs the earlier implementation:
1. **Variance-coverage d_eff.** The participation ratio in
``hcd.d_eff`` collapses toward 1 whenever a single eigenvalue
dominates a slow-decay tail (the typical post-RoPE LLM regime).
Bit-allocation drove from that number under-resourced the
position-distinguishing tail dims and produced repetitive
generation. We instead compute ``d_eff`` as the smallest prefix
of the eigenvalue spectrum that captures
``self.d_eff_variance`` of the energy. This naturally collapses
to small values on sharp cliffs (ViT, ESM) and expands on flat
tails (Mistral / Llama / Qwen).
2. **Separate K / V eigenstructure.** Value codebooks and the value
rotation matrix are derived from the **value** calibration, not
the key calibration. K and V have different principal axes,
especially for RoPE LLMs where only K is rotated by RoPE.
"""
from spectralquant._water_fill import water_fill_allocate # local import
var_thresh = float(self.d_eff_variance)
def _variance_d_eff(eigvals: torch.Tensor, D: int) -> int:
"""Smallest prefix length whose cumulative variance ≥ var_thresh."""
ev = eigvals.float().clamp(min=0)
total = ev.sum()
if total.item() <= 1e-12:
return 2
cum = ev.cumsum(0) / total
n = int((cum < var_thresh).sum().item()) + 1
return max(2, min(n, D - 2))
# Progress bar — codebook construction can take a minute on large
# models. tqdm is optional; falls back to a quiet enumerate if not
# installed.
key_heads = [h for h in self._calibrator.iter_heads()
if h.head_type == "key"]
try:
from tqdm.auto import tqdm
iterator = tqdm(key_heads, desc="[SQ] building head engines",
unit="head", leave=False)
except ImportError:
iterator = key_heads
for hcd in iterator:
val_hcd = self._calibrator.get(hcd.layer_idx, hcd.head_idx, "value")
if val_hcd is None:
# No V data — degrade gracefully by reusing K's stats. Better
# than crashing, but the user should re-run calibration.
val_hcd = hcd
D = hcd.head_dim
d_eff = _variance_d_eff(hcd.eigenvalues, D)
v_d_eff = _variance_d_eff(val_hcd.eigenvalues, D)
# Per-dim sigmas for the tail dims. We rescale each tail dim by
# its own sigma before/after the shared codebook (see compress_*
# kernels) — equivalent to per-dim codebooks across the tail at
# zero extra storage. This eliminates the clipping bias that
# appears when the shared tail codebook is built for the mean
# tail variance but some dims live an order of magnitude above
# that mean (the post-RoPE LLM regime).
tail_floor = 1.0 / math.sqrt(D)
k_tail_sigmas = (
hcd.eigenvalues[d_eff:]
.float().clamp(min=tail_floor ** 2).sqrt()
)
v_tail_sigmas = (
val_hcd.eigenvalues[v_d_eff:]
.float().clamp(min=tail_floor ** 2).sqrt()
)
# The legacy non-water-fill sigmas (mean over the tail dims) are
# kept around mostly as a fallback when waterfill isn't used.
k_sigma_high = float(
hcd.eigenvalues[:d_eff].mean().clamp(min=1e-8).sqrt().item()
)
v_sigma_high = float(
val_hcd.eigenvalues[:v_d_eff].mean().clamp(min=1e-8).sqrt().item()
)
# Bit allocation.
#
# With `noise_bits` set, the user is opting into the asymmetric
# SQ v2 architecture:
# * semantic dims keep `total_bits` MSE bits each (→ b_high
# = total_bits + 1 because the storage layout reserves
# one slot for the QJL sign bit even when QJL is off);
# * key tail dims drop to `noise_bits` per dim;
# * value tail dims drop to `value_noise_bits` (defaults to
# `noise_bits`). Values are typically more compressible
# than keys because they're not RoPE-encoded — see the
# `value_noise_bits` docstring.
# Without `noise_bits`, the engine reduces to the legacy
# TurboQuant-style uniform allocation where every dim gets
# ~total_bits bits and `b_low ≈ total_bits`.
if self.noise_bits is not None:
b_low = max(1, int(self.noise_bits))
b_low_val = max(1, int(self.value_noise_bits))
b_high = int(self.total_bits) + 1
else:
b_low = max(1, round(self.total_bits - d_eff / D))
b_low_val = b_low
b_high = b_low + 1
mse_bits_high = max(b_high - 1, 1) # -1 for the QJL sign bit slot
# Shared per-dim codebooks: signal dims keep their per-dim sigma
# codebook (water-fill), tail dims use a UNIT-variance codebook
# together with the per-dim sigma rescaling above. Building the
# tail codebook at sigma=1 lets us reuse it across all heads.
cent_key_high = _get_codebook(k_sigma_high, mse_bits_high).to(self.device)
cent_key_low = _get_codebook(1.0, b_low).to(self.device)
cent_val_high = _get_codebook(v_sigma_high, b_high).to(self.device)
cent_val_low = _get_codebook(1.0, b_low_val).to(self.device)
# Water-fill per-dim codebooks (v2): allocate bits proportional to
# each axis's variance. Keys and values get independent
# allocations because their eigenvalue spectra differ.
mse_bits_per_dim: List[int] = []
cent_key_per_dim: List[torch.Tensor] = []
cent_val_per_dim: List[torch.Tensor] = []
if self.use_water_fill and d_eff > 1:
sem_budget = d_eff * mse_bits_high
mse_bits_per_dim = water_fill_allocate(
hcd.eigenvalues, d_eff, sem_budget,
min_bits=1, max_bits=mse_bits_high + 2,
)
for i in range(d_eff):
sigma_i = float(hcd.eigenvalues[i].clamp(min=1e-8).sqrt().item())
cent_key_per_dim.append(
_get_codebook(sigma_i, mse_bits_per_dim[i]).to(self.device)
)
if self.use_water_fill and v_d_eff > 1:
val_budget = v_d_eff * b_high
val_bits_per_dim = water_fill_allocate(
val_hcd.eigenvalues, v_d_eff, val_budget,
min_bits=1, max_bits=b_high + 2,
)
for i in range(v_d_eff):
sigma_i = float(
val_hcd.eigenvalues[i].clamp(min=1e-8).sqrt().item()
)
cent_val_per_dim.append(
_get_codebook(sigma_i, val_bits_per_dim[i]).to(self.device)
)
# QJL projection matrix: first d_eff rows of a random (D×D) matrix.
# Shape: (d_eff, D) so that both compress_keys (signs = residual @ S.T)
# and fused_attention (q_proj = Q @ S.T) are consistent.
gen = torch.Generator()
gen.manual_seed(42 + hcd.layer_idx * 1000 + hcd.head_idx)
S_full = torch.randn(D, D, generator=gen, dtype=torch.float16)
S_sem = S_full[:d_eff, :].to(self.device) # (d_eff, D)
he = HeadEngine(
layer_idx=hcd.layer_idx,
head_idx=hcd.head_idx,
head_dim=D,
d_eff=d_eff,
eigenvectors=hcd.eigenvectors.to(self.device),
eigenvalues=hcd.eigenvalues.to(self.device),
v_d_eff=v_d_eff,
v_eigenvectors=val_hcd.eigenvectors.to(self.device),
v_eigenvalues=val_hcd.eigenvalues.to(self.device),
b_high=b_high,
b_low=b_low,
b_low_val=b_low_val,
S_sem=S_sem,
centroids_key_high=cent_key_high,
centroids_key_low=cent_key_low,
centroids_val_high=cent_val_high,
centroids_val_low=cent_val_low,
k_tail_sigmas=k_tail_sigmas.to(self.device),
v_tail_sigmas=v_tail_sigmas.to(self.device),
use_water_fill=self.use_water_fill,
mse_bits_per_dim=mse_bits_per_dim,
centroids_key_per_dim=cent_key_per_dim,
centroids_val_per_dim=cent_val_per_dim,
)
self._head_engines[(hcd.layer_idx, hcd.head_idx)] = he
# ------------------------------------------------------------------
# Compression API (mirror of turboquant-gpu)
# ------------------------------------------------------------------
def compress_kv_cache(
self, past_key_values: Any
) -> List[Tuple[Dict, torch.Tensor]]:
"""
Compress an existing KV cache (output of model(..., use_cache=True)).
Returns a list of (compressed_k, v_decompressed) tuples, one per layer.
"""
result = []
kv_list = self._to_list(past_key_values)
for layer_idx, (K, V) in enumerate(kv_list):
layer_result = []
n_kv_heads = K.shape[1]
for head_idx in range(n_kv_heads):
he = self._head_engines.get((layer_idx, head_idx))
if he is None:
layer_result.append((None, V[:, head_idx]))
continue
k_h = K[:, head_idx].squeeze(0).half()
v_h = V[:, head_idx].squeeze(0).half()
ck = he.compress_keys(k_h)
vd = he.compress_values(v_h)
layer_result.append((ck, vd))
result.append(layer_result)
return result
def compression_stats(self, past_key_values: Any = None) -> Dict[str, float]:
"""Return compression ratio and bytes-per-token stats."""
if not self._head_engines:
return {"error": "not calibrated"}
sample = next(iter(self._head_engines.values()))
D = sample.head_dim
key_b, val_b = sample.bytes_per_token(use_qjl=self.use_qjl)
fp16_b = D * 2 # FP16 = 2 bytes per element per K or V
tq_b = D * 2 / 8 + D * 1 / 8 + 4 # TurboQuant: 2-bit MSE + 1-bit QJL + norms
return {
"fp16_bytes_per_token": fp16_b,
"sq_key_bytes": round(key_b, 2),
"sq_val_bytes": round(val_b, 2),
"sq_ratio": round(fp16_b / ((key_b + val_b) / 2), 2),
"tq_ratio": round(fp16_b / tq_b, 2),
"sq_vs_tq_ratio": round(tq_b / ((key_b + val_b) / 2), 3),
"d_eff": sample.d_eff,
}
# ------------------------------------------------------------------
# Generation API (mirror of turboquant-gpu)
# ------------------------------------------------------------------
def generate(
self,
model: nn.Module,
tokenizer: Any,
prompt: str,
max_new_tokens: int = 200,
**generation_kwargs,
) -> Dict[str, Any]:
"""
One-call compressed generation. Calibrate first.
Returns dict with "text", "tokens", "stats".
"""
if not self._calibrated:
raise RuntimeError("Call engine.calibrate() first.")
inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
# Monkey-patch this model temporarily
_original, _patched = self._patch(model)
t0 = time.perf_counter()
with torch.no_grad():
output_ids = model.generate(
**inputs, max_new_tokens=max_new_tokens, **generation_kwargs
)
elapsed = time.perf_counter() - t0
self._unpatch(model, _original, _patched)
n_tokens = output_ids.shape[-1] - inputs["input_ids"].shape[-1]
text = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[-1]:],
skip_special_tokens=True)
stats = self.compression_stats(None)
stats["tokens_per_second"] = round(n_tokens / elapsed, 1)
return {"text": text, "tokens": n_tokens, "stats": stats}
# ------------------------------------------------------------------
# Cache-level integration (preferred for live LLM inference).
#
# These methods compress the prefill KV cache once and let HuggingFace
# run stock SDPA on the dequantised reconstruction during decode.
# See `spectralquant.integrations.dynamic_cache` for the full design
# rationale. Compared to the attention-patch path:
# * decode runs at full FP16 SDPA speed (no O(T²) recompression);
# * RoPE is exact for every newly generated token;
# * prompts can be safely compressed at aggressive bit budgets
# (e.g. ``noise_bits=1``) without breaking greedy decoding,
# because the noise lives only on the static prefix.
# ------------------------------------------------------------------
def prefill_compress(
    self,
    model: nn.Module,
    tokenizer: Any,
    prompt: str,
    *,
    device: Optional[str] = None,
    add_special_tokens: bool = True,
) -> Dict[str, Any]:
    """Compress the KV cache produced by an FP16 prefill of ``prompt``.

    Pure delegate: forwards this engine and every argument to
    :func:`spectralquant.integrations.dynamic_cache.prefill_compress` and
    returns its result unchanged (a DynamicCache plus stats, per that
    function's documentation).
    """
    from spectralquant.integrations.dynamic_cache import (
        prefill_compress as _prefill_compress,
    )
    options = {"device": device, "add_special_tokens": add_special_tokens}
    return _prefill_compress(self, model, tokenizer, prompt, **options)
def generate_compressed(
    self,
    model: nn.Module,
    tokenizer: Any,
    prompt: str,
    *,
    max_new_tokens: int = 128,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    device: Optional[str] = None,
    add_special_tokens: bool = True,
) -> Dict[str, Any]:
    """Generate from ``prompt`` via the cache-level (compressed prefill) path.

    Pure delegate: prepends this engine and forwards every argument to
    :func:`spectralquant.integrations.dynamic_cache.generate_with_compressed_cache`,
    returning its result as-is.
    """
    from spectralquant.integrations.dynamic_cache import (
        generate_with_compressed_cache as _run_generation,
    )
    sampling = {
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
    }
    return _run_generation(
        self, model, tokenizer, prompt,
        max_new_tokens=max_new_tokens,
        device=device,
        add_special_tokens=add_special_tokens,
        **sampling,
    )
# ------------------------------------------------------------------
# Permanent monkey-patch (apply_to)
# ------------------------------------------------------------------
def apply_to(self, model: nn.Module) -> nn.Module:
    """
    Permanently patch ``model``'s attention to use SpectralQuant.

    Future calls to ``model.generate()`` or ``model.forward()`` use SQ.
    Reversible via :meth:`remove_from`.

    Modules already flagged ``_sq_patched`` are not wrapped again (the same
    guard :meth:`_patch` uses), so a repeated call does not double-wrap them.
    If any step fails part-way, the partial patch is rolled back with
    :meth:`remove_from` before the exception propagates.

    Args:
        model: model to patch in place.

    Returns:
        The same ``model`` object, tagged with ``_sq_engine``.

    Raises:
        RuntimeError: if the engine has not been calibrated.
    """
    if not self._calibrated:
        raise RuntimeError("Call engine.calibrate() first.")
    from spectralquant.integrations.huggingface import (
        build_sq_attention_fn, _install_sq_in_attention_registry,
    )
    from spectralquant.calibrate import _iter_attention_layers
    sq_fn = build_sq_attention_fn(self)
    try:
        _install_sq_in_attention_registry(sq_fn)
        # Tag each LLM attention module with its layer index so the wrapper
        # knows which calibration entry to use.
        for layer_idx, attn in _iter_attention_layers(model):
            attn._sq_layer_idx = layer_idx
        # Additionally patch ViT/BERT/ESM-style modules whose attention is
        # implemented as a custom Module rather than dispatching through
        # ALL_ATTENTION_FUNCTIONS. Skip anything already patched.
        for layer_idx, attn in self._iter_attn_layers(model):
            if not getattr(attn, "_sq_patched", False):
                self._patch_module(attn, layer_idx, sq_fn)
    except BaseException:
        # Do not leave the registry / some modules half-patched.
        self.remove_from(model)
        raise
    model._sq_engine = self
    return model
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _to_list(past_key_values: Any):
    """Normalise ``past_key_values`` to ``list[(K, V)]`` regardless of HF version.

    Formats are checked in this order:

    0. ``None`` (no cache yet): returns ``[]``.
    1. DynamicCache (>=4.36): ``key_cache`` is a non-empty list; paired
       index-wise with ``value_cache``.
    2. Any Cache subclass: ``to_legacy_cache()`` conversion, if it returns
       a non-empty result.
    3. Legacy tuple-of-tuples (< 4.36): direct ``list()`` cast.
    4. Any iterable of ``(K, V)`` pairs: same as (3).
    """
    if past_key_values is None:
        return []
    # Format 1: DynamicCache. Use isinstance(list) plus a truthiness check
    # rather than `is not None`, because some versions initialise the
    # attribute to None or [] before the first forward pass.
    key_cache = getattr(past_key_values, "key_cache", None)
    value_cache = getattr(past_key_values, "value_cache", None)
    if isinstance(key_cache, list) and key_cache:
        return [(k, value_cache[i]) for i, k in enumerate(key_cache)]
    # Format 2: any Cache subclass with to_legacy_cache().
    if hasattr(past_key_values, "to_legacy_cache"):
        legacy = past_key_values.to_legacy_cache()
        if legacy:
            return list(legacy)
    # Format 3/4: legacy tuple-of-tuples or any iterable of (K, V) pairs.
    return list(past_key_values)
def _iter_attn_layers(self, model):
    """Return the calibration module's attention-layer iterator for ``model``.

    Thin delegate to ``spectralquant.calibrate._iter_attention_layers``;
    callers in this class unpack its items as ``(layer_idx, attn)`` pairs.
    """
    from spectralquant.calibrate import _iter_attention_layers as iter_layers
    layers = iter_layers(model)
    return layers
# ------------------------------------------------------------------
# Context manager & permanent remove
# ------------------------------------------------------------------
@contextmanager
def applied_to(self, model: nn.Module) -> Iterator[nn.Module]:
    """
    Context manager: apply SpectralQuant for the duration, then remove.

    ``apply_to`` runs inside the ``try`` block, so a failure part-way
    through patching still triggers :meth:`remove_from` before the
    exception propagates out of the ``with`` statement. Note that
    ``remove_from`` runs unconditionally on exit. Any SQ patch that was
    present before entering is therefore removed as well.

    Example::

        with engine.applied_to(model):
            outputs = model.generate(**inputs, max_new_tokens=200)
        # model is fully restored here
    """
    try:
        self.apply_to(model)
        yield model
    finally:
        self.remove_from(model)
def remove_from(self, model: nn.Module) -> nn.Module:
    """
    Reverse the SpectralQuant patches installed by :meth:`apply_to`.

    For every attention layer:

    * a stashed ``_sq_original_forward`` is put back as ``forward`` and
      then deleted;
    * the ``_sq_patched`` and ``_sq_layer_idx`` bookkeeping attributes
      are deleted.

    Afterwards the SQ entry is uninstalled from the attention registry and
    the ``_sq_engine`` tag is dropped from ``model``. Absent attributes are
    skipped, so the module-side cleanup tolerates repeated calls.

    Returns:
        The same ``model`` object.
    """
    from spectralquant.calibrate import _iter_attention_layers
    from spectralquant.integrations.huggingface import (
        uninstall_sq_from_attention_registry,
    )
    bookkeeping = ("_sq_patched", "_sq_layer_idx")
    for _, module in _iter_attention_layers(model):
        if hasattr(module, "_sq_original_forward"):
            stashed_forward = module._sq_original_forward
            module.forward = stashed_forward
            delattr(module, "_sq_original_forward")
        for name in bookkeeping:
            if not hasattr(module, name):
                continue
            delattr(module, name)
    uninstall_sq_from_attention_registry()
    if hasattr(model, "_sq_engine"):
        delattr(model, "_sq_engine")
    return model
# ------------------------------------------------------------------
# Internal: temporary patch used by generate()
# ------------------------------------------------------------------
def _patch(self, model: nn.Module):
    """Temporarily install SpectralQuant on ``model`` (used by :meth:`generate`).

    Runs the HuggingFace integration's ``apply_spectralquant_to_model``.
    It then calls ``_patch_module`` on any attention module not yet
    flagged ``_sq_patched`` (e.g. vision/protein-style modules).

    Returns ``(originals, {})``:

    * ``originals`` maps ``id(module)`` to
      ``(module, module._sq_original_forward)`` for every attention layer
      that carries a stashed forward;
    * the second element is always an empty dict.
    """
    from spectralquant.integrations.huggingface import apply_spectralquant_to_model
    apply_spectralquant_to_model(model, self)
    for idx, module in self._iter_attn_layers(model):
        if getattr(module, "_sq_patched", False):
            continue
        self._patch_module(module, idx)
    from spectralquant.calibrate import _iter_attention_layers
    originals = {
        id(module): (module, module._sq_original_forward)
        for _, module in _iter_attention_layers(model)
        if hasattr(module, "_sq_original_forward")
    }
    return originals, {}
def _unpatch(self, model: nn.Module, orig, patched) -> None:
    """Reverse the temporary patch applied by :meth:`_patch`.

    ``orig`` and ``patched`` are accepted for signature compatibility but
    ignored. Cleanup is delegated entirely to ``remove_from``, which
    walks the model's attention layers itself.
    """
    self.remove_from(model)
    return None