From 99910fbd3a59624fd70d331aeb34a8b05147e782 Mon Sep 17 00:00:00 2001
From: Kunal Mansukhani
Date: Thu, 26 Mar 2026 19:23:51 -0400
Subject: [PATCH 1/5] Add turbo quant attention
---
.../83_turboquant_attention/challenge.html | 101 ++++++
.../83_turboquant_attention/challenge.py | 308 ++++++++++++++++++
.../starter/starter.cu | 5 +
.../starter/starter.cute.py | 18 +
.../starter/starter.jax.py | 18 +
.../starter/starter.mojo | 9 +
.../starter/starter.pytorch.py | 16 +
.../starter/starter.triton.py | 18 +
8 files changed, 493 insertions(+)
create mode 100644 challenges/medium/83_turboquant_attention/challenge.html
create mode 100644 challenges/medium/83_turboquant_attention/challenge.py
create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.cu
create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.cute.py
create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.jax.py
create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.mojo
create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.pytorch.py
create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.triton.py
diff --git a/challenges/medium/83_turboquant_attention/challenge.html b/challenges/medium/83_turboquant_attention/challenge.html
new file mode 100644
index 00000000..37433e10
--- /dev/null
+++ b/challenges/medium/83_turboquant_attention/challenge.html
@@ -0,0 +1,101 @@
+
+ Implement attention score computation against a TurboQuant-compressed KV cache. In modern LLM inference, key vectors are compressed using TurboQuant: each key is first rotated by an orthogonal matrix \(\Pi\), then each coordinate is scalar-quantized to the nearest centroid in a codebook, storing only the uint8 index. Given B query vectors of dimension D, S quantized key vectors (stored as codebook indices), the rotation matrix \(\Pi\), and a sorted codebook of C centroids, compute the attention scores matrix.
+
+
+ For each query \(Q_i\) and quantized key \(K_j\), compute:
+ \[
+ \text{scores}_{i,j} = Q_i \cdot \text{dequant}(K_j)
+ \]
+ where dequantization reconstructs the key by looking up centroids and rotating back:
+ \[
+ \text{dequant}(K_j) = [\text{codebook}[K_j[0]],\; \text{codebook}[K_j[1]],\; \ldots,\; \text{codebook}[K_j[D{-}1]]] \;\cdot\; \Pi
+ \]
+
+
+
+
+ Dequantization + Dot Product Pipeline
+
+ K_idx
+ S x D (uint8)
+
+
+ Lookup
+ codebook[idx]
+
+
+ Rotate back
+ K_deq = ... * Pi
+
+
+ Dot product
+ Q * K_deq^T
+
+ scores
+ Hint: rotating queries once avoids per-key rotation
+
+
+
+
+
+
+
+Implementation Requirements
+
+ The solve function signature must remain unchanged.
+ Use only native features (no external libraries).
+ \(\Pi\) is an orthogonal matrix (\(\Pi^T \Pi = I\)), stored row-major as float32.
+ The codebook is sorted in ascending order (float32).
+ K_idx contains uint8 indices in \([0, C)\).
+ Store the result in scores as float32.
+
+
+Example
+
+Input:
+Query matrix \(Q\) (\(B=2, D=2\)):
+\[
+\begin{bmatrix}
+1.0 & 0.0 \\
+0.0 & 1.0
+\end{bmatrix}
+\]
+Quantized key indices \(K_\text{idx}\) (\(S=3, D=2\), uint8):
+\[
+\begin{bmatrix}
+0 & 3 \\
+1 & 2 \\
+3 & 0
+\end{bmatrix}
+\]
+Rotation matrix \(\Pi = I_{2 \times 2}\) (identity)
+Codebook (\(C=4\)): \([-0.75,\; -0.25,\; 0.25,\; 0.75]\)
+
+
+Dequantized keys: look up centroids for each index, then multiply by \(\Pi\) (identity here):
+\[
+K_\text{deq} = \begin{bmatrix}
+-0.75 & 0.75 \\
+-0.25 & 0.25 \\
+0.75 & -0.75
+\end{bmatrix}
+\]
+Output scores (\(B=2, S=3\)):
+\[
+\text{scores} = Q \cdot K_\text{deq}^T = \begin{bmatrix}
+-0.75 & -0.25 & 0.75 \\
+0.75 & 0.25 & -0.75
+\end{bmatrix}
+\]
+
+
+Constraints
+
+ 1 ≤ B ≤ 32
+ 1 ≤ S ≤ 65,536
+ 1 ≤ D ≤ 256
+ 2 ≤ C ≤ 256
+ \(\Pi\) is an orthogonal matrix, codebook is sorted ascending
+ All floating-point values are float32; key indices are uint8
+ Performance is measured with B = 32, S = 32,768, D = 128, C = 16
+
diff --git a/challenges/medium/83_turboquant_attention/challenge.py b/challenges/medium/83_turboquant_attention/challenge.py
new file mode 100644
index 00000000..f6188e6c
--- /dev/null
+++ b/challenges/medium/83_turboquant_attention/challenge.py
@@ -0,0 +1,308 @@
+import ctypes
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+ def __init__(self):
+ super().__init__(
+ name="TurboQuant KV Cache Attention",
+ atol=1e-3,
+ rtol=1e-3,
+ num_gpus=1,
+ access_tier="free",
+ )
+
+ def reference_impl(
+ self,
+ Q: torch.Tensor,
+ K_idx: torch.Tensor,
+ Pi: torch.Tensor,
+ codebook: torch.Tensor,
+ scores: torch.Tensor,
+ B: int,
+ S: int,
+ D: int,
+ C: int,
+ ):
+ assert Q.shape == (B, D)
+ assert K_idx.shape == (S, D)
+ assert Pi.shape == (D, D)
+ assert codebook.shape == (C,)
+ assert scores.shape == (B, S)
+ assert Q.dtype == torch.float32
+ assert K_idx.dtype == torch.uint8
+ assert Pi.dtype == torch.float32
+ assert codebook.dtype == torch.float32
+ assert scores.dtype == torch.float32
+ assert Q.device.type == "cuda"
+ assert K_idx.device.type == "cuda"
+ assert Pi.device.type == "cuda"
+ assert codebook.device.type == "cuda"
+ assert scores.device.type == "cuda"
+
+ # Dequantize keys: lookup centroids then rotate back
+ K_centroids = codebook[K_idx.long()] # S x D
+ K_deq = K_centroids @ Pi # S x D
+
+ # Compute attention scores
+ scores.copy_(Q @ K_deq.T) # B x S
+
+ def get_solve_signature(self) -> Dict[str, tuple]:
+ return {
+ "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+ "K_idx": (ctypes.POINTER(ctypes.c_uint8), "in"),
+ "Pi": (ctypes.POINTER(ctypes.c_float), "in"),
+ "codebook": (ctypes.POINTER(ctypes.c_float), "in"),
+ "scores": (ctypes.POINTER(ctypes.c_float), "out"),
+ "B": (ctypes.c_int, "in"),
+ "S": (ctypes.c_int, "in"),
+ "D": (ctypes.c_int, "in"),
+ "C": (ctypes.c_int, "in"),
+ }
+
+ def _make_rotation(self, D):
+ G = torch.randn(D, D, device="cuda")
+ Q, _ = torch.linalg.qr(G)
+ return Q
+
+ def _make_codebook(self, C, scale=1.0):
+ return torch.linspace(-scale, scale, C, device="cuda", dtype=torch.float32)
+
+ def generate_example_test(self) -> Dict[str, Any]:
+ B, S, D, C = 2, 3, 2, 4
+ Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device="cuda", dtype=torch.float32)
+ K_idx = torch.tensor([[0, 3], [1, 2], [3, 0]], device="cuda", dtype=torch.uint8)
+ Pi = torch.eye(D, device="cuda", dtype=torch.float32)
+ codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75], device="cuda", dtype=torch.float32)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ return {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+
+ def generate_functional_test(self) -> List[Dict[str, Any]]:
+ tests = []
+
+ # Edge case: single query, single key, D=1, C=2
+ B, S, D, C = 1, 1, 1, 2
+ Q = torch.tensor([[0.5]], device="cuda", dtype=torch.float32)
+ K_idx = torch.tensor([[1]], device="cuda", dtype=torch.uint8)
+ Pi = torch.eye(D, device="cuda", dtype=torch.float32)
+ codebook = self._make_codebook(C)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ tests.append(
+ {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+ )
+
+ # Edge case: zeros query
+ B, S, D, C = 2, 3, 4, 4
+ Q = torch.zeros(B, D, device="cuda", dtype=torch.float32)
+ K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
+ Pi = self._make_rotation(D)
+ codebook = self._make_codebook(C)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ tests.append(
+ {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+ )
+
+ # Edge case: negative query values
+ B, S, D, C = 2, 4, 4, 4
+ Q = torch.tensor(
+ [[-0.5, -0.3, -0.8, -0.1], [-1.0, -0.5, -0.2, -0.9]],
+ device="cuda",
+ dtype=torch.float32,
+ )
+ K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
+ Pi = self._make_rotation(D)
+ codebook = self._make_codebook(C)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ tests.append(
+ {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+ )
+
+ # Power-of-2: B=4, S=16, D=32, C=8
+ B, S, D, C = 4, 16, 32, 8
+ Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.5
+ K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
+ Pi = self._make_rotation(D)
+ codebook = self._make_codebook(C, scale=1.5)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ tests.append(
+ {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+ )
+
+ # Power-of-2: B=8, S=64, D=64, C=16
+ B, S, D, C = 8, 64, 64, 16
+ Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3
+ K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
+ Pi = self._make_rotation(D)
+ codebook = self._make_codebook(C)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ tests.append(
+ {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+ )
+
+ # Power-of-2: B=16, S=128, D=128, C=16
+ B, S, D, C = 16, 128, 128, 16
+ Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3
+ K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
+ Pi = self._make_rotation(D)
+ codebook = self._make_codebook(C)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ tests.append(
+ {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+ )
+
+ # Non-power-of-2: B=3, S=30, D=50, C=8
+ B, S, D, C = 3, 30, 50, 8
+ Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.4
+ K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
+ Pi = self._make_rotation(D)
+ codebook = self._make_codebook(C)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ tests.append(
+ {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+ )
+
+ # Non-power-of-2: B=7, S=255, D=100, C=16
+ B, S, D, C = 7, 255, 100, 16
+ Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.6
+ K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
+ Pi = self._make_rotation(D)
+ codebook = self._make_codebook(C, scale=1.5)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ tests.append(
+ {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+ )
+
+ # Realistic: B=16, S=4096, D=128, C=16
+ B, S, D, C = 16, 4096, 128, 16
+ Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3
+ K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
+ Pi = self._make_rotation(D)
+ codebook = self._make_codebook(C)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ tests.append(
+ {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+ )
+
+ return tests
+
+ def generate_performance_test(self) -> Dict[str, Any]:
+ B, S, D, C = 32, 32768, 128, 16
+ Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3
+ K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
+ Pi = self._make_rotation(D)
+ codebook = self._make_codebook(C)
+ scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
+ return {
+ "Q": Q,
+ "K_idx": K_idx,
+ "Pi": Pi,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.cu b/challenges/medium/83_turboquant_attention/starter/starter.cu
new file mode 100644
index 00000000..be318f3c
--- /dev/null
+++ b/challenges/medium/83_turboquant_attention/starter/starter.cu
@@ -0,0 +1,5 @@
+#include <cuda_runtime.h>
+
+// Q, K_idx, Pi, codebook, scores are device pointers
+extern "C" void solve(const float* Q, const unsigned char* K_idx, const float* Pi,
+ const float* codebook, float* scores, int B, int S, int D, int C) {}
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.cute.py b/challenges/medium/83_turboquant_attention/starter/starter.cute.py
new file mode 100644
index 00000000..ebd728fa
--- /dev/null
+++ b/challenges/medium/83_turboquant_attention/starter/starter.cute.py
@@ -0,0 +1,18 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# Q, K_idx, Pi, codebook, scores are tensors on the GPU
+@cute.jit
+def solve(
+ Q: cute.Tensor,
+ K_idx: cute.Tensor,
+ Pi: cute.Tensor,
+ codebook: cute.Tensor,
+ scores: cute.Tensor,
+ B: cute.Int32,
+ S: cute.Int32,
+ D: cute.Int32,
+ C: cute.Int32,
+):
+ pass
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.jax.py b/challenges/medium/83_turboquant_attention/starter/starter.jax.py
new file mode 100644
index 00000000..3c38dbb5
--- /dev/null
+++ b/challenges/medium/83_turboquant_attention/starter/starter.jax.py
@@ -0,0 +1,18 @@
+import jax
+import jax.numpy as jnp
+
+
+# Q, K_idx, Pi, codebook are tensors on GPU
+@jax.jit
+def solve(
+ Q: jax.Array,
+ K_idx: jax.Array,
+ Pi: jax.Array,
+ codebook: jax.Array,
+ B: int,
+ S: int,
+ D: int,
+ C: int,
+) -> jax.Array:
+ # return output tensor directly
+ pass
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.mojo b/challenges/medium/83_turboquant_attention/starter/starter.mojo
new file mode 100644
index 00000000..353883cd
--- /dev/null
+++ b/challenges/medium/83_turboquant_attention/starter/starter.mojo
@@ -0,0 +1,9 @@
+from gpu.host import DeviceContext
+from gpu.id import block_dim, block_idx, thread_idx
+from memory import UnsafePointer
+from math import ceildiv
+
+# Q, K_idx, Pi, codebook, scores are device pointers
+@export
+def solve(Q: UnsafePointer[Float32], K_idx: UnsafePointer[UInt8], Pi: UnsafePointer[Float32], codebook: UnsafePointer[Float32], scores: UnsafePointer[Float32], B: Int32, S: Int32, D: Int32, C: Int32):
+ pass
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py b/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py
new file mode 100644
index 00000000..a78ee78c
--- /dev/null
+++ b/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py
@@ -0,0 +1,16 @@
+import torch
+
+
+# Q, K_idx, Pi, codebook, scores are tensors on the GPU
+def solve(
+ Q: torch.Tensor,
+ K_idx: torch.Tensor,
+ Pi: torch.Tensor,
+ codebook: torch.Tensor,
+ scores: torch.Tensor,
+ B: int,
+ S: int,
+ D: int,
+ C: int,
+):
+ pass
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.triton.py b/challenges/medium/83_turboquant_attention/starter/starter.triton.py
new file mode 100644
index 00000000..f21c5eea
--- /dev/null
+++ b/challenges/medium/83_turboquant_attention/starter/starter.triton.py
@@ -0,0 +1,18 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# q, k_idx, pi, codebook, scores are tensors on the GPU
+def solve(
+ q: torch.Tensor,
+ k_idx: torch.Tensor,
+ pi: torch.Tensor,
+ codebook: torch.Tensor,
+ scores: torch.Tensor,
+ B: int,
+ S: int,
+ D: int,
+ C: int,
+):
+ pass
From 89536ca7e866414da2104beae01b6ed357872853 Mon Sep 17 00:00:00 2001
From: Kunal Mansukhani
Date: Thu, 26 Mar 2026 20:37:29 -0400
Subject: [PATCH 2/5] Make spec better
---
.../83_turboquant_attention/challenge.html | 65 ++++++-------------
1 file changed, 20 insertions(+), 45 deletions(-)
diff --git a/challenges/medium/83_turboquant_attention/challenge.html b/challenges/medium/83_turboquant_attention/challenge.html
index 37433e10..91f88bd3 100644
--- a/challenges/medium/83_turboquant_attention/challenge.html
+++ b/challenges/medium/83_turboquant_attention/challenge.html
@@ -1,52 +1,24 @@
- Implement attention score computation against a TurboQuant-compressed KV cache. In modern LLM inference, key vectors are compressed using TurboQuant: each key is first rotated by an orthogonal matrix \(\Pi\), then each coordinate is scalar-quantized to the nearest centroid in a codebook, storing only the uint8 index. Given B query vectors of dimension D, S quantized key vectors (stored as codebook indices), the rotation matrix \(\Pi\), and a sorted codebook of C centroids, compute the attention scores matrix.
+ Implement attention score computation against a quantized KV cache. During LLM inference, the KV cache can dominate memory. TurboQuant addresses this by compressing each key vector down to uint8 codebook indices, reducing memory by up to 4×. Your task is to compute attention scores between full-precision queries and these compressed keys.
+
- For each query \(Q_i\) and quantized key \(K_j\), compute:
- \[
- \text{scores}_{i,j} = Q_i \cdot \text{dequant}(K_j)
- \]
- where dequantization reconstructs the key by looking up centroids and rotating back:
- \[
- \text{dequant}(K_j) = [\text{codebook}[K_j[0]],\; \text{codebook}[K_j[1]],\; \ldots,\; \text{codebook}[K_j[D{-}1]]] \;\cdot\; \Pi
- \]
+ How TurboQuant encodes keys (already done for you): each key vector is multiplied by a random orthogonal matrix \(\Pi\), then each coordinate is replaced by the index of its nearest centroid in a codebook.
-
-
- Dequantization + Dot Product Pipeline
-
- K_idx
- S x D (uint8)
-
-
- Lookup
- codebook[idx]
-
-
- Rotate back
- K_deq = ... * Pi
-
-
- Dot product
- Q * K_deq^T
-
- scores
- Hint: rotating queries once avoids per-key rotation
-
-
-
-
-
-
+
+ What you need to compute : given queries \(Q\) and quantized key indices \(K_\text{idx}\), dequantize each key and compute dot products:
+
+
+ Lookup : replace each index with its codebook value, producing a centroid vector \(\tilde{Y}_j\)
+ Rotate back : recover the approximate key \(\tilde{K}_j = \tilde{Y}_j \cdot \Pi\)
+ Dot product : \(\text{scores}_{i,j} = Q_i \cdot \tilde{K}_j\)
+
Implementation Requirements
The solve function signature must remain unchanged.
Use only native features (no external libraries).
- \(\Pi\) is an orthogonal matrix (\(\Pi^T \Pi = I\)), stored row-major as float32.
- The codebook is sorted in ascending order (float32).
- K_idx contains uint8 indices in \([0, C)\).
Store the result in scores as float32.
@@ -72,17 +44,18 @@ Example
Codebook (\(C=4\)): \([-0.75,\; -0.25,\; 0.25,\; 0.75]\)
-Dequantized keys: look up centroids for each index, then multiply by \(\Pi\) (identity here):
+Step 1: Lookup centroids for each index:
\[
-K_\text{deq} = \begin{bmatrix}
+\tilde{Y} = \begin{bmatrix}
-0.75 & 0.75 \\
-0.25 & 0.25 \\
0.75 & -0.75
\end{bmatrix}
\]
-Output scores (\(B=2, S=3\)):
+Step 2: Rotate back (\(\Pi\) is identity here, so \(\tilde{K} = \tilde{Y}\)).
+Step 3: Dot products:
\[
-\text{scores} = Q \cdot K_\text{deq}^T = \begin{bmatrix}
+\text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix}
-0.75 & -0.25 & 0.75 \\
0.75 & 0.25 & -0.75
\end{bmatrix}
@@ -95,7 +68,9 @@
Constraints
1 ≤ S ≤ 65,536
1 ≤ D ≤ 256
2 ≤ C ≤ 256
- \(\Pi\) is an orthogonal matrix, codebook is sorted ascending
- All floating-point values are float32; key indices are uint8
+ \(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))
+ Codebook values are sorted in ascending order
+ K_idx values are in \([0, C)\)
+ All floating-point inputs are float32; key indices are uint8
Performance is measured with B = 32, S = 32,768, D = 128, C = 16
From 33db15a785e2d0dfbf419d487bdc01f815760ca9 Mon Sep 17 00:00:00 2001
From: James Song
Date: Fri, 27 Mar 2026 18:19:17 -0400
Subject: [PATCH 3/5] update turboquant to include residual correction
---
.../83_turboquant_attention/challenge.html | 202 ++++++++++++
.../hard/83_turboquant_attention/challenge.py | 222 +++++++++++++
.../starter/starter.cu | 6 +
.../starter/starter.cute.py | 5 +-
.../starter/starter.jax.py | 5 +-
.../starter/starter.mojo | 9 +
.../starter/starter.pytorch.py | 5 +-
.../starter/starter.triton.py | 21 ++
.../83_turboquant_attention/challenge.html | 76 -----
.../83_turboquant_attention/challenge.py | 308 ------------------
.../starter/starter.cu | 5 -
.../starter/starter.mojo | 9 -
.../starter/starter.triton.py | 18 -
13 files changed, 472 insertions(+), 419 deletions(-)
create mode 100644 challenges/hard/83_turboquant_attention/challenge.html
create mode 100644 challenges/hard/83_turboquant_attention/challenge.py
create mode 100644 challenges/hard/83_turboquant_attention/starter/starter.cu
rename challenges/{medium => hard}/83_turboquant_attention/starter/starter.cute.py (63%)
rename challenges/{medium => hard}/83_turboquant_attention/starter/starter.jax.py (64%)
create mode 100644 challenges/hard/83_turboquant_attention/starter/starter.mojo
rename challenges/{medium => hard}/83_turboquant_attention/starter/starter.pytorch.py (56%)
create mode 100644 challenges/hard/83_turboquant_attention/starter/starter.triton.py
delete mode 100644 challenges/medium/83_turboquant_attention/challenge.html
delete mode 100644 challenges/medium/83_turboquant_attention/challenge.py
delete mode 100644 challenges/medium/83_turboquant_attention/starter/starter.cu
delete mode 100644 challenges/medium/83_turboquant_attention/starter/starter.mojo
delete mode 100644 challenges/medium/83_turboquant_attention/starter/starter.triton.py
diff --git a/challenges/hard/83_turboquant_attention/challenge.html b/challenges/hard/83_turboquant_attention/challenge.html
new file mode 100644
index 00000000..ff34a028
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/challenge.html
@@ -0,0 +1,202 @@
+
+ Implement attention score computation against a
+ TurboQuant -compressed
+ KV cache. TurboQuant compresses each key vector to uint8 codebook indices plus a 1-bit
+ residual correction (QJL), reducing KV cache memory by up to 6x. Your task: dequantize the
+ compressed keys and compute dot-product attention scores against full-precision queries.
+
+
+
+
+
+
+ TurboQuant Dequantization Pipeline (per key vector)
+
+
+
+
+ Stage 1: MSE
+
+
+
+ K_idx
+ [S, D] uint8
+
+
+ →
+
+
+
+ codebook[ K_idx ]
+ centroid lookup
+
+
+ →
+
+
+
+ × Π
+ rotate back
+
+
+ =
+
+
+
+ K̃_mse
+ [S, D] float32
+
+
+
+
+ Stage 2: QJL residual
+
+
+
+ σ
+ [S,D] ±1
+
+
+ →
+
+
+ σ · M
+ project
+
+
+ →
+
+
+ × √(π/2)/D × γ
+ scale by norm
+
+
+ =
+
+
+
+ K̃_res
+ [S, D] float32
+
+
+
+
+ Combine + Score
+
+
+
+ K̃_mse
+
+ +
+
+
+ K̃_res
+
+ =
+
+
+ K̃
+
+
+ then:
+
+
+ Q
+
+ ·
+
+
+ K̃ᵀ
+
+ =
+
+
+ scores
+
+
+ Π = orthogonal rotation [D×D]
+ M = Gaussian projection [D×D]
+ σ = sign bits ±1, γ = ‖residual‖₂
+
+
+
+ Background — how the keys were compressed (already done for you, not part of the challenge):
+
+
+ Rotate : multiply key by orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each
+ coordinate follow a Beta distribution, so a single fixed codebook works for all coordinates.
+ Scalar quantize : replace each coordinate of \(y\) with the index of its nearest
+ codebook centroid \(\rightarrow K_\text{idx}\) (uint8).
+ Residual correction : MSE quantization loses information. Compute the residual
+ \(r = K - \tilde{K}_\text{mse}\), then store:
+
+ \(\sigma = \text{sign}(M \cdot r) \in \{-1,+1\}^D\) — direction (int8)
+ \(\gamma = \|r\|_2\) — magnitude (float32 scalar per key)
+
+ where \(M \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix (S_mat in code).
+
+
+
+
+ What you compute — dequantize and score:
+
+
+ MSE dequantize : look up centroids, undo the rotation:
+ \[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]
+ Residual dequantize : reconstruct the residual correction:
+ \[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot M\]
+ The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.
+ Combine :
+ \(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)
+ Dot product :
+ \(\text{scores}_{b,s} = Q_b \cdot \tilde{K}_s\)
+
+
+ The residual correction makes the inner product unbiased :
+ \(\mathbb{E}[\langle Q, \tilde{K} \rangle] = \langle Q, K \rangle\).
+
+
+Implementation Requirements
+
+ The solve function signature must remain unchanged.
+ Use only native features (no external libraries).
+ Store the result in scores as float32.
+
+
+Example
+
+ Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(M = I\), \(\gamma = \mathbf{0}\) (residual correction disabled):
+
+
+    \(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\),
+    \(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\),
+    codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\)
+
+
+ Step 1 — MSE lookup and rotate back (\(\Pi = I\)):
+ \[
+ \tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix}
+ \]
+ Step 2 — Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
+
+
+ Output:
+ \[
+ \text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix} -0.75 & -0.25 & 0.75 \\ 0.75 & 0.25 & -0.75 \end{bmatrix}
+ \]
+
+
+Constraints
+
+ 1 ≤ B ≤ 32
+ 1 ≤ S ≤ 65,536
+ 1 ≤ D ≤ 256
+ 2 ≤ C ≤ 256
+ \(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))
+ \(M\) (S_mat) has i.i.d. \(\mathcal{N}(0,1)\) entries
+ qjl_signs (\(\sigma\)) values are in \(\{-1, +1\}\) (int8)
+ K_idx values are in \([0, C)\) (uint8)
+ All floating-point inputs are float32
+ Performance is measured with B = 32, S = 32,768, D = 128, C = 16
+
diff --git a/challenges/hard/83_turboquant_attention/challenge.py b/challenges/hard/83_turboquant_attention/challenge.py
new file mode 100644
index 00000000..e6723af1
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/challenge.py
@@ -0,0 +1,222 @@
+import ctypes
+import math
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+ def __init__(self):
+ super().__init__(
+ name="TurboQuant KV Cache Attention",
+ atol=1e-3,
+ rtol=1e-3,
+ num_gpus=1,
+ access_tier="free",
+ )
+
+ def reference_impl(
+ self,
+ Q: torch.Tensor,
+ K_idx: torch.Tensor,
+ qjl_signs: torch.Tensor,
+ gamma: torch.Tensor,
+ Pi: torch.Tensor,
+ S_mat: torch.Tensor,
+ codebook: torch.Tensor,
+ scores: torch.Tensor,
+ B: int,
+ S: int,
+ D: int,
+ C: int,
+ ):
+ assert Q.shape == (B, D)
+ assert K_idx.shape == (S, D)
+ assert qjl_signs.shape == (S, D)
+ assert gamma.shape == (S,)
+ assert Pi.shape == (D, D)
+ assert S_mat.shape == (D, D)
+ assert codebook.shape == (C,)
+ assert scores.shape == (B, S)
+ assert Q.dtype == torch.float32
+ assert K_idx.dtype == torch.uint8
+ assert qjl_signs.dtype == torch.int8
+ assert gamma.dtype == torch.float32
+ assert Pi.dtype == torch.float32
+ assert S_mat.dtype == torch.float32
+ assert codebook.dtype == torch.float32
+ assert scores.dtype == torch.float32
+ assert Q.device.type == "cuda"
+ assert K_idx.device.type == "cuda"
+ assert qjl_signs.device.type == "cuda"
+ assert gamma.device.type == "cuda"
+ assert Pi.device.type == "cuda"
+ assert S_mat.device.type == "cuda"
+ assert codebook.device.type == "cuda"
+ assert scores.device.type == "cuda"
+
+ # Stage 1: MSE dequantization — lookup centroids, rotate back
+ K_centroids = codebook[K_idx.long()] # [S, D]
+ K_mse = K_centroids @ Pi # [S, D] (row convention: ỹ @ Π = Π^T · ỹ)
+
+ # Stage 2: QJL dequantization — reconstruct residual correction
+ scale = math.sqrt(math.pi / 2.0) / D
+ K_qjl = scale * gamma.unsqueeze(1) * (qjl_signs.float() @ S_mat) # [S, D]
+
+ # Combined dequantization
+ K_deq = K_mse + K_qjl # [S, D]
+
+ # Attention scores
+ scores.copy_(Q @ K_deq.T) # [B, S]
+
+ def get_solve_signature(self) -> Dict[str, tuple]:
+ return {
+ "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+ "K_idx": (ctypes.POINTER(ctypes.c_uint8), "in"),
+ "qjl_signs": (ctypes.POINTER(ctypes.c_int8), "in"),
+ "gamma": (ctypes.POINTER(ctypes.c_float), "in"),
+ "Pi": (ctypes.POINTER(ctypes.c_float), "in"),
+ "S_mat": (ctypes.POINTER(ctypes.c_float), "in"),
+ "codebook": (ctypes.POINTER(ctypes.c_float), "in"),
+ "scores": (ctypes.POINTER(ctypes.c_float), "out"),
+ "B": (ctypes.c_int, "in"),
+ "S": (ctypes.c_int, "in"),
+ "D": (ctypes.c_int, "in"),
+ "C": (ctypes.c_int, "in"),
+ }
+
+ def _make_rotation(self, D):
+ G = torch.randn(D, D, device="cuda")
+ Q, _ = torch.linalg.qr(G)
+ return Q
+
+ def _make_codebook(self, C, scale=1.0):
+ return torch.linspace(-scale, scale, C, device="cuda", dtype=torch.float32)
+
+ def _encode_keys(self, K, Pi, S_mat, codebook):
+ """Simulate TurboQuant_prod encoding: rotate, quantize, compute QJL on residual."""
+ S, D = K.shape
+ C = codebook.shape[0]
+
+ # Stage 1: MSE encoding
+ Y = K @ Pi.T # rotate into quantization space
+ # Scalar quantize each coordinate to nearest centroid
+ diffs = Y.unsqueeze(-1) - codebook.unsqueeze(0).unsqueeze(0) # [S, D, C]
+ K_idx = diffs.abs().argmin(dim=-1).to(torch.uint8) # [S, D]
+
+ # MSE dequantization (to compute residual)
+ K_centroids = codebook[K_idx.long()] # [S, D]
+ K_mse = K_centroids @ Pi # [S, D]
+
+ # Stage 2: QJL encoding of residual
+ residual = K - K_mse # [S, D]
+ gamma = residual.norm(dim=1) # [S]
+ proj = residual @ S_mat.T # [S, D] (row convention for S · r)
+ qjl_signs = torch.sign(proj).to(torch.int8) # [S, D]
+ # Ensure no zeros (sign(0)=0 → map to +1)
+ qjl_signs[qjl_signs == 0] = 1
+
+ return K_idx, qjl_signs, gamma
+
+ def _make_test_case(self, B, S_seq, D, C, zero_q=False, seed=42):
+ torch.manual_seed(seed)
+ device = "cuda"
+
+ Pi = self._make_rotation(D)
+ S_mat = torch.randn(D, D, device=device, dtype=torch.float32)
+ codebook = self._make_codebook(C)
+
+ if zero_q:
+ Q = torch.zeros(B, D, device=device, dtype=torch.float32)
+ else:
+ Q = torch.randn(B, D, device=device, dtype=torch.float32) * 0.5
+
+ # Generate realistic keys and encode them
+ K = torch.randn(S_seq, D, device=device, dtype=torch.float32) * 0.3
+ K_idx, qjl_signs, gamma = self._encode_keys(K, Pi, S_mat, codebook)
+
+ scores = torch.zeros(B, S_seq, device=device, dtype=torch.float32)
+
+ return {
+ "Q": Q,
+ "K_idx": K_idx,
+ "qjl_signs": qjl_signs,
+ "gamma": gamma,
+ "Pi": Pi,
+ "S_mat": S_mat,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S_seq,
+ "D": D,
+ "C": C,
+ }
+
+ def generate_example_test(self) -> Dict[str, Any]:
+ device = "cuda"
+ B, S, D, C = 2, 3, 2, 4
+
+ Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device=device, dtype=torch.float32)
+ K_idx = torch.tensor([[0, 3], [1, 2], [3, 0]], device=device, dtype=torch.uint8)
+ # QJL signs: all +1 for simplicity
+ qjl_signs = torch.ones(S, D, device=device, dtype=torch.int8)
+ # gamma = 0: no QJL correction (reduces to MSE-only for this example)
+ gamma = torch.zeros(S, device=device, dtype=torch.float32)
+ Pi = torch.eye(D, device=device, dtype=torch.float32)
+ S_mat = torch.eye(D, device=device, dtype=torch.float32)
+ codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75], device=device, dtype=torch.float32)
+ scores = torch.zeros(B, S, device=device, dtype=torch.float32)
+
+ return {
+ "Q": Q,
+ "K_idx": K_idx,
+ "qjl_signs": qjl_signs,
+ "gamma": gamma,
+ "Pi": Pi,
+ "S_mat": S_mat,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+
+ def generate_functional_test(self) -> List[Dict[str, Any]]:
+ tests = []
+
+ # Edge: single query, single key, D=1
+ tests.append(self._make_test_case(1, 1, 1, 2, seed=1))
+
+ # Edge: zero query
+ tests.append(self._make_test_case(2, 3, 4, 4, zero_q=True, seed=2))
+
+ # Edge: small with negative queries
+ tests.append(self._make_test_case(2, 4, 4, 4, seed=3))
+
+ # Power-of-2: B=4, S=16, D=32, C=8
+ tests.append(self._make_test_case(4, 16, 32, 8, seed=10))
+
+ # Power-of-2: B=8, S=64, D=64, C=16
+ tests.append(self._make_test_case(8, 64, 64, 16, seed=20))
+
+ # Power-of-2: B=16, S=128, D=128, C=16
+ tests.append(self._make_test_case(16, 128, 128, 16, seed=30))
+
+ # Non-power-of-2: B=3, S=30, D=50, C=8
+ tests.append(self._make_test_case(3, 30, 50, 8, seed=40))
+
+ # Non-power-of-2: B=7, S=255, D=100, C=16
+ tests.append(self._make_test_case(7, 255, 100, 16, seed=50))
+
+ # Realistic: B=16, S=4096, D=128, C=16
+ tests.append(self._make_test_case(16, 4096, 128, 16, seed=60))
+
+ # Realistic: B=32, S=8192, D=128, C=8
+ tests.append(self._make_test_case(32, 8192, 128, 8, seed=70))
+
+ return tests
+
+ def generate_performance_test(self) -> Dict[str, Any]:
+ return self._make_test_case(32, 32768, 128, 16, seed=0)
diff --git a/challenges/hard/83_turboquant_attention/starter/starter.cu b/challenges/hard/83_turboquant_attention/starter/starter.cu
new file mode 100644
index 00000000..18478b0c
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/starter/starter.cu
@@ -0,0 +1,6 @@
+#include <cuda_runtime.h>
+
+// Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers
+extern "C" void solve(const float* Q, const unsigned char* K_idx, const signed char* qjl_signs,
+ const float* gamma, const float* Pi, const float* S_mat,
+ const float* codebook, float* scores, int B, int S, int D, int C) {}
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.cute.py b/challenges/hard/83_turboquant_attention/starter/starter.cute.py
similarity index 63%
rename from challenges/medium/83_turboquant_attention/starter/starter.cute.py
rename to challenges/hard/83_turboquant_attention/starter/starter.cute.py
index ebd728fa..27a31a80 100644
--- a/challenges/medium/83_turboquant_attention/starter/starter.cute.py
+++ b/challenges/hard/83_turboquant_attention/starter/starter.cute.py
@@ -2,12 +2,15 @@
import cutlass.cute as cute
-# Q, K_idx, Pi, codebook, scores are tensors on the GPU
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
@cute.jit
def solve(
Q: cute.Tensor,
K_idx: cute.Tensor,
+ qjl_signs: cute.Tensor,
+ gamma: cute.Tensor,
Pi: cute.Tensor,
+ S_mat: cute.Tensor,
codebook: cute.Tensor,
scores: cute.Tensor,
B: cute.Int32,
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.jax.py b/challenges/hard/83_turboquant_attention/starter/starter.jax.py
similarity index 64%
rename from challenges/medium/83_turboquant_attention/starter/starter.jax.py
rename to challenges/hard/83_turboquant_attention/starter/starter.jax.py
index 3c38dbb5..b487a049 100644
--- a/challenges/medium/83_turboquant_attention/starter/starter.jax.py
+++ b/challenges/hard/83_turboquant_attention/starter/starter.jax.py
@@ -2,12 +2,15 @@
import jax.numpy as jnp
-# Q, K_idx, Pi, codebook are tensors on GPU
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook are tensors on GPU
@jax.jit
def solve(
Q: jax.Array,
K_idx: jax.Array,
+ qjl_signs: jax.Array,
+ gamma: jax.Array,
Pi: jax.Array,
+ S_mat: jax.Array,
codebook: jax.Array,
B: int,
S: int,
diff --git a/challenges/hard/83_turboquant_attention/starter/starter.mojo b/challenges/hard/83_turboquant_attention/starter/starter.mojo
new file mode 100644
index 00000000..43d74b27
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/starter/starter.mojo
@@ -0,0 +1,9 @@
+from gpu.host import DeviceContext
+from gpu.id import block_dim, block_idx, thread_idx
+from memory import UnsafePointer
+from math import ceildiv
+
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers
+@export
+def solve(Q: UnsafePointer[Float32], K_idx: UnsafePointer[UInt8], qjl_signs: UnsafePointer[Int8], gamma: UnsafePointer[Float32], Pi: UnsafePointer[Float32], S_mat: UnsafePointer[Float32], codebook: UnsafePointer[Float32], scores: UnsafePointer[Float32], B: Int32, S: Int32, D: Int32, C: Int32):
+ pass
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py b/challenges/hard/83_turboquant_attention/starter/starter.pytorch.py
similarity index 56%
rename from challenges/medium/83_turboquant_attention/starter/starter.pytorch.py
rename to challenges/hard/83_turboquant_attention/starter/starter.pytorch.py
index a78ee78c..8b706f88 100644
--- a/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py
+++ b/challenges/hard/83_turboquant_attention/starter/starter.pytorch.py
@@ -1,11 +1,14 @@
import torch
-# Q, K_idx, Pi, codebook, scores are tensors on the GPU
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
def solve(
Q: torch.Tensor,
K_idx: torch.Tensor,
+ qjl_signs: torch.Tensor,
+ gamma: torch.Tensor,
Pi: torch.Tensor,
+ S_mat: torch.Tensor,
codebook: torch.Tensor,
scores: torch.Tensor,
B: int,
diff --git a/challenges/hard/83_turboquant_attention/starter/starter.triton.py b/challenges/hard/83_turboquant_attention/starter/starter.triton.py
new file mode 100644
index 00000000..ccf5176c
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/starter/starter.triton.py
@@ -0,0 +1,21 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
+def solve(
+ Q: torch.Tensor,
+ K_idx: torch.Tensor,
+ qjl_signs: torch.Tensor,
+ gamma: torch.Tensor,
+ Pi: torch.Tensor,
+ S_mat: torch.Tensor,
+ codebook: torch.Tensor,
+ scores: torch.Tensor,
+ B: int,
+ S: int,
+ D: int,
+ C: int,
+):
+ pass
diff --git a/challenges/medium/83_turboquant_attention/challenge.html b/challenges/medium/83_turboquant_attention/challenge.html
deleted file mode 100644
index 91f88bd3..00000000
--- a/challenges/medium/83_turboquant_attention/challenge.html
+++ /dev/null
@@ -1,76 +0,0 @@
-
- Implement attention score computation against a quantized KV cache. During LLM inference, the KV cache can dominate memory. TurboQuant addresses this by compressing each key vector down to uint8 codebook indices, reducing memory by up to 4×. Your task is to compute attention scores between full-precision queries and these compressed keys.
-
-
-
- How TurboQuant encodes keys (already done for you): each key vector is multiplied by a random orthogonal matrix \(\Pi\), then each coordinate is replaced by the index of its nearest centroid in a codebook.
-
-
-
- What you need to compute : given queries \(Q\) and quantized key indices \(K_\text{idx}\), dequantize each key and compute dot products:
-
-
- Lookup : replace each index with its codebook value, producing a centroid vector \(\tilde{Y}_j\)
- Rotate back : recover the approximate key \(\tilde{K}_j = \tilde{Y}_j \cdot \Pi\)
- Dot product : \(\text{scores}_{i,j} = Q_i \cdot \tilde{K}_j\)
-
-
-Implementation Requirements
-
- The solve function signature must remain unchanged.
- Use only native features (no external libraries).
- Store the result in scores as float32.
-
-
-Example
-
-Input:
-Query matrix \(Q\) (\(B=2, D=2\)):
-\[
-\begin{bmatrix}
-1.0 & 0.0 \\
-0.0 & 1.0
-\end{bmatrix}
-\]
-Quantized key indices \(K_\text{idx}\) (\(S=3, D=2\), uint8):
-\[
-\begin{bmatrix}
-0 & 3 \\
-1 & 2 \\
-3 & 0
-\end{bmatrix}
-\]
-Rotation matrix \(\Pi = I_{2 \times 2}\) (identity)
-Codebook (\(C=4\)): \([-0.75,\; -0.25,\; 0.25,\; 0.75]\)
-
-
-Step 1: Lookup centroids for each index:
-\[
-\tilde{Y} = \begin{bmatrix}
--0.75 & 0.75 \\
--0.25 & 0.25 \\
-0.75 & -0.75
-\end{bmatrix}
-\]
-Step 2: Rotate back (\(\Pi\) is identity here, so \(\tilde{K} = \tilde{Y}\)).
-Step 3: Dot products:
-\[
-\text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix}
--0.75 & -0.25 & 0.75 \\
-0.75 & 0.25 & -0.75
-\end{bmatrix}
-\]
-
-
-Constraints
-
- 1 ≤ B ≤ 32
- 1 ≤ S ≤ 65,536
- 1 ≤ D ≤ 256
- 2 ≤ C ≤ 256
- \(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))
- Codebook values are sorted in ascending order
- K_idx values are in \([0, C)\)
- All floating-point inputs are float32; key indices are uint8
- Performance is measured with B = 32, S = 32,768, D = 128, C = 16
-
diff --git a/challenges/medium/83_turboquant_attention/challenge.py b/challenges/medium/83_turboquant_attention/challenge.py
deleted file mode 100644
index f6188e6c..00000000
--- a/challenges/medium/83_turboquant_attention/challenge.py
+++ /dev/null
@@ -1,308 +0,0 @@
-import ctypes
-from typing import Any, Dict, List
-
-import torch
-from core.challenge_base import ChallengeBase
-
-
-class Challenge(ChallengeBase):
- def __init__(self):
- super().__init__(
- name="TurboQuant KV Cache Attention",
- atol=1e-3,
- rtol=1e-3,
- num_gpus=1,
- access_tier="free",
- )
-
- def reference_impl(
- self,
- Q: torch.Tensor,
- K_idx: torch.Tensor,
- Pi: torch.Tensor,
- codebook: torch.Tensor,
- scores: torch.Tensor,
- B: int,
- S: int,
- D: int,
- C: int,
- ):
- assert Q.shape == (B, D)
- assert K_idx.shape == (S, D)
- assert Pi.shape == (D, D)
- assert codebook.shape == (C,)
- assert scores.shape == (B, S)
- assert Q.dtype == torch.float32
- assert K_idx.dtype == torch.uint8
- assert Pi.dtype == torch.float32
- assert codebook.dtype == torch.float32
- assert scores.dtype == torch.float32
- assert Q.device.type == "cuda"
- assert K_idx.device.type == "cuda"
- assert Pi.device.type == "cuda"
- assert codebook.device.type == "cuda"
- assert scores.device.type == "cuda"
-
- # Dequantize keys: lookup centroids then rotate back
- K_centroids = codebook[K_idx.long()] # S x D
- K_deq = K_centroids @ Pi # S x D
-
- # Compute attention scores
- scores.copy_(Q @ K_deq.T) # B x S
-
- def get_solve_signature(self) -> Dict[str, tuple]:
- return {
- "Q": (ctypes.POINTER(ctypes.c_float), "in"),
- "K_idx": (ctypes.POINTER(ctypes.c_uint8), "in"),
- "Pi": (ctypes.POINTER(ctypes.c_float), "in"),
- "codebook": (ctypes.POINTER(ctypes.c_float), "in"),
- "scores": (ctypes.POINTER(ctypes.c_float), "out"),
- "B": (ctypes.c_int, "in"),
- "S": (ctypes.c_int, "in"),
- "D": (ctypes.c_int, "in"),
- "C": (ctypes.c_int, "in"),
- }
-
- def _make_rotation(self, D):
- G = torch.randn(D, D, device="cuda")
- Q, _ = torch.linalg.qr(G)
- return Q
-
- def _make_codebook(self, C, scale=1.0):
- return torch.linspace(-scale, scale, C, device="cuda", dtype=torch.float32)
-
- def generate_example_test(self) -> Dict[str, Any]:
- B, S, D, C = 2, 3, 2, 4
- Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device="cuda", dtype=torch.float32)
- K_idx = torch.tensor([[0, 3], [1, 2], [3, 0]], device="cuda", dtype=torch.uint8)
- Pi = torch.eye(D, device="cuda", dtype=torch.float32)
- codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75], device="cuda", dtype=torch.float32)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- return {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
-
- def generate_functional_test(self) -> List[Dict[str, Any]]:
- tests = []
-
- # Edge case: single query, single key, D=1, C=2
- B, S, D, C = 1, 1, 1, 2
- Q = torch.tensor([[0.5]], device="cuda", dtype=torch.float32)
- K_idx = torch.tensor([[1]], device="cuda", dtype=torch.uint8)
- Pi = torch.eye(D, device="cuda", dtype=torch.float32)
- codebook = self._make_codebook(C)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- tests.append(
- {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
- )
-
- # Edge case: zeros query
- B, S, D, C = 2, 3, 4, 4
- Q = torch.zeros(B, D, device="cuda", dtype=torch.float32)
- K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
- Pi = self._make_rotation(D)
- codebook = self._make_codebook(C)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- tests.append(
- {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
- )
-
- # Edge case: negative query values
- B, S, D, C = 2, 4, 4, 4
- Q = torch.tensor(
- [[-0.5, -0.3, -0.8, -0.1], [-1.0, -0.5, -0.2, -0.9]],
- device="cuda",
- dtype=torch.float32,
- )
- K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
- Pi = self._make_rotation(D)
- codebook = self._make_codebook(C)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- tests.append(
- {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
- )
-
- # Power-of-2: B=4, S=16, D=32, C=8
- B, S, D, C = 4, 16, 32, 8
- Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.5
- K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
- Pi = self._make_rotation(D)
- codebook = self._make_codebook(C, scale=1.5)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- tests.append(
- {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
- )
-
- # Power-of-2: B=8, S=64, D=64, C=16
- B, S, D, C = 8, 64, 64, 16
- Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3
- K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
- Pi = self._make_rotation(D)
- codebook = self._make_codebook(C)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- tests.append(
- {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
- )
-
- # Power-of-2: B=16, S=128, D=128, C=16
- B, S, D, C = 16, 128, 128, 16
- Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3
- K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
- Pi = self._make_rotation(D)
- codebook = self._make_codebook(C)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- tests.append(
- {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
- )
-
- # Non-power-of-2: B=3, S=30, D=50, C=8
- B, S, D, C = 3, 30, 50, 8
- Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.4
- K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
- Pi = self._make_rotation(D)
- codebook = self._make_codebook(C)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- tests.append(
- {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
- )
-
- # Non-power-of-2: B=7, S=255, D=100, C=16
- B, S, D, C = 7, 255, 100, 16
- Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.6
- K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
- Pi = self._make_rotation(D)
- codebook = self._make_codebook(C, scale=1.5)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- tests.append(
- {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
- )
-
- # Realistic: B=16, S=4096, D=128, C=16
- B, S, D, C = 16, 4096, 128, 16
- Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3
- K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
- Pi = self._make_rotation(D)
- codebook = self._make_codebook(C)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- tests.append(
- {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
- )
-
- return tests
-
- def generate_performance_test(self) -> Dict[str, Any]:
- B, S, D, C = 32, 32768, 128, 16
- Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3
- K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8)
- Pi = self._make_rotation(D)
- codebook = self._make_codebook(C)
- scores = torch.zeros(B, S, device="cuda", dtype=torch.float32)
- return {
- "Q": Q,
- "K_idx": K_idx,
- "Pi": Pi,
- "codebook": codebook,
- "scores": scores,
- "B": B,
- "S": S,
- "D": D,
- "C": C,
- }
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.cu b/challenges/medium/83_turboquant_attention/starter/starter.cu
deleted file mode 100644
index be318f3c..00000000
--- a/challenges/medium/83_turboquant_attention/starter/starter.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-#include <cuda_runtime.h>
-
-// Q, K_idx, Pi, codebook, scores are device pointers
-extern "C" void solve(const float* Q, const unsigned char* K_idx, const float* Pi,
- const float* codebook, float* scores, int B, int S, int D, int C) {}
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.mojo b/challenges/medium/83_turboquant_attention/starter/starter.mojo
deleted file mode 100644
index 353883cd..00000000
--- a/challenges/medium/83_turboquant_attention/starter/starter.mojo
+++ /dev/null
@@ -1,9 +0,0 @@
-from gpu.host import DeviceContext
-from gpu.id import block_dim, block_idx, thread_idx
-from memory import UnsafePointer
-from math import ceildiv
-
-# Q, K_idx, Pi, codebook, scores are device pointers
-@export
-def solve(Q: UnsafePointer[Float32], K_idx: UnsafePointer[UInt8], Pi: UnsafePointer[Float32], codebook: UnsafePointer[Float32], scores: UnsafePointer[Float32], B: Int32, S: Int32, D: Int32, C: Int32):
- pass
diff --git a/challenges/medium/83_turboquant_attention/starter/starter.triton.py b/challenges/medium/83_turboquant_attention/starter/starter.triton.py
deleted file mode 100644
index f21c5eea..00000000
--- a/challenges/medium/83_turboquant_attention/starter/starter.triton.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import torch
-import triton
-import triton.language as tl
-
-
-# q, k_idx, pi, codebook, scores are tensors on the GPU
-def solve(
- q: torch.Tensor,
- k_idx: torch.Tensor,
- pi: torch.Tensor,
- codebook: torch.Tensor,
- scores: torch.Tensor,
- B: int,
- S: int,
- D: int,
- C: int,
-):
- pass
From bad839a975d39191a8023c269ec37e032a078159 Mon Sep 17 00:00:00 2001
From: James Song
Date: Fri, 27 Mar 2026 18:22:30 -0400
Subject: [PATCH 4/5] fix lint
---
challenges/hard/83_turboquant_attention/challenge.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/challenges/hard/83_turboquant_attention/challenge.py b/challenges/hard/83_turboquant_attention/challenge.py
index e6723af1..6a7e07b8 100644
--- a/challenges/hard/83_turboquant_attention/challenge.py
+++ b/challenges/hard/83_turboquant_attention/challenge.py
@@ -97,7 +97,6 @@ def _make_codebook(self, C, scale=1.0):
def _encode_keys(self, K, Pi, S_mat, codebook):
"""Simulate TurboQuant_prod encoding: rotate, quantize, compute QJL on residual."""
S, D = K.shape
- C = codebook.shape[0]
# Stage 1: MSE encoding
Y = K @ Pi.T # rotate into quantization space
From 8b125d89c26bb305a50828153aec54bfb2d117ac Mon Sep 17 00:00:00 2001
From: Kunal Mansukhani
Date: Fri, 27 Mar 2026 23:34:03 -0400
Subject: [PATCH 5/5] Clean up spec
---
.../83_turboquant_attention/challenge.html | 140 ++----------------
1 file changed, 14 insertions(+), 126 deletions(-)
diff --git a/challenges/hard/83_turboquant_attention/challenge.html b/challenges/hard/83_turboquant_attention/challenge.html
index ff34a028..e6397aef 100644
--- a/challenges/hard/83_turboquant_attention/challenge.html
+++ b/challenges/hard/83_turboquant_attention/challenge.html
@@ -6,122 +6,8 @@
compressed keys and compute dot-product attention scores against full-precision queries.
-
-
-
-
- TurboQuant Dequantization Pipeline (per key vector)
-
-
-
-
- Stage 1: MSE
-
-
-
- K_idx
- [S, D] uint8
-
-
- →
-
-
-
- codebook[ K_idx ]
- centroid lookup
-
-
- →
-
-
-
- × Π
- rotate back
-
-
- =
-
-
-
- K̃_mse
- [S, D] float32
-
-
-
-
- Stage 2: QJL residual
-
-
-
- σ
- [S,D] ±1
-
-
- →
-
-
- σ · M
- project
-
-
- →
-
-
- × √(π/2)/D × γ
- scale by norm
-
-
- =
-
-
-
- K̃_res
- [S, D] float32
-
-
-
-
- Combine + Score
-
-
-
- K̃_mse
-
- +
-
-
- K̃_res
-
- =
-
-
- K̃
-
-
- then:
-
-
- Q
-
- ·
-
-
- K̃ᵀ
-
- =
-
-
- scores
-
-
- Π = orthogonal rotation [D×D]
- M = Gaussian projection [D×D]
- σ = sign bits ±1, γ = ‖residual‖₂
-
-
- Background — how the keys were compressed (already done for you, not part of the challenge):
+ Background - how the keys were compressed (already done for you, not part of the challenge):
Rotate : multiply key by orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each
@@ -131,21 +17,21 @@
Residual correction : MSE quantization loses information. Compute the residual
\(r = K - \tilde{K}_\text{mse}\), then store:
- \(\sigma = \text{sign}(M \cdot r) \in \{-1,+1\}^D\) — direction (int8)
- \(\gamma = \|r\|_2\) — magnitude (float32 scalar per key)
+ \(\sigma = \text{sign}(S_\text{mat} \cdot r) \in \{-1,+1\}^D\) - direction (int8)
+ \(\gamma = \|r\|_2\) - magnitude (float32 scalar per key)
- where \(M \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix (S_mat in code).
+ where \(S_\text{mat} \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix.
- What you compute — dequantize and score:
+ What you compute - dequantize and score:
MSE dequantize : look up centroids, undo the rotation:
\[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]
Residual dequantize : reconstruct the residual correction:
- \[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot M\]
+ \[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot S_\text{mat}\]
The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.
Combine :
\(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)
@@ -166,19 +52,20 @@ Implementation Requirements
Example
- Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(M = I\), \(\gamma = \mathbf{0}\) (residual correction disabled):
+ Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(S_\text{mat} = I\), \(\gamma = \mathbf{0}\) (residual correction disabled),
+ \(\sigma = \mathbf{1}\) (all +1):
- \(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\), \quad
- \(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\), \quad
+ \(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\),
+ \(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\),
codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\)
- Step 1 — MSE lookup and rotate back (\(\Pi = I\)):
+ Step 1 - MSE lookup and rotate back (\(\Pi = I\)):
\[
\tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix}
\]
- Step 2 — Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
+ Step 2 - Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
Output:
@@ -194,7 +81,8 @@
Constraints
1 ≤ D ≤ 256
2 ≤ C ≤ 256
\(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))
- \(M\) (S_mat) has i.i.d. \(\mathcal{N}(0,1)\) entries
+ S_mat has i.i.d. \(\mathcal{N}(0,1)\) entries
+ gamma has shape \([S]\) (one \(\ell_2\) norm per key vector, float32)
qjl_signs (\(\sigma\)) values are in \(\{-1, +1\}\) (int8)
K_idx values are in \([0, C)\) (uint8)
All floating-point inputs are float32