diff --git a/challenges/hard/83_turboquant_attention/challenge.html b/challenges/hard/83_turboquant_attention/challenge.html
new file mode 100644
index 00000000..e6397aef
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/challenge.html
@@ -0,0 +1,90 @@
+
+ Implement attention score computation against a
+ TurboQuant-compressed
+ KV cache. TurboQuant compresses each key vector to uint8 codebook indices plus a 1-bit
+ residual correction (QJL), reducing KV cache memory by up to 6x. Your task: dequantize the
+ compressed keys and compute dot-product attention scores against full-precision queries.
+
+
+
+ Background - how the keys were compressed (already done for you, not part of the challenge):
+
+
+ - Rotate: multiply key by orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each
+ coordinate follow a Beta distribution, so a single fixed codebook works for all coordinates.
+ - Scalar quantize: replace each coordinate of \(y\) with the index of its nearest
+ codebook centroid \(\rightarrow K_\text{idx}\) (uint8).
+ - Residual correction: MSE quantization loses information. Compute the residual
+ \(r = K - \tilde{K}_\text{mse}\), then store:
+
+ - \(\sigma = \text{sign}(S_\text{mat} \cdot r) \in \{-1,+1\}^D\) - direction (int8)
+ - \(\gamma = \|r\|_2\) - magnitude (float32 scalar per key)
+
+ where \(S_\text{mat} \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix.
+
+
+
+
+ What you compute - dequantize and score:
+
+
+ - MSE dequantize: look up centroids, undo the rotation:
+ \[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]
+ - Residual dequantize: reconstruct the residual correction:
+ \[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot S_\text{mat}\]
+ The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.
+ - Combine:
+ \(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)
+ - Dot product:
+ \(\text{scores}_{b,s} = Q_b \cdot \tilde{K}_s\)
+
+
+ The residual correction makes the inner product unbiased:
+ \(\mathbb{E}[\langle Q, \tilde{K} \rangle] = \langle Q, K \rangle\).
+
+
+Implementation Requirements
+
+ - The solve function signature must remain unchanged.
+ - Use only native features (no external libraries).
+ - Store the result in scores as float32.
+
+
+Example
+
+ Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(S_\text{mat} = I\), \(\gamma = \mathbf{0}\) (residual correction disabled),
+ \(\sigma = \mathbf{1}\) (all +1):
+
+
+ \(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\),
+ \(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\),
+ codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\)
+
+
+ Step 1 - MSE lookup and rotate back (\(\Pi = I\)):
+ \[
+ \tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix}
+ \]
+ Step 2 - Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).
+
+
+ Output:
+ \[
+ \text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix} -0.75 & -0.25 & 0.75 \\ 0.75 & 0.25 & -0.75 \end{bmatrix}
+ \]
+
+
+Constraints
+
+ - 1 ≤ B ≤ 32
+ - 1 ≤ S ≤ 65,536
+ - 1 ≤ D ≤ 256
+ - 2 ≤ C ≤ 256
+ - \(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))
+ - S_mat has i.i.d. \(\mathcal{N}(0,1)\) entries
+ - gamma has shape \([S]\) (one \(\ell_2\) norm per key vector, float32)
+ - qjl_signs (\(\sigma\)) values are in \(\{-1, +1\}\) (int8)
+ - K_idx values are in \([0, C)\) (uint8)
+ - All floating-point inputs are float32
+ - Performance is measured with B = 32, S = 32,768, D = 128, C = 16
+
diff --git a/challenges/hard/83_turboquant_attention/challenge.py b/challenges/hard/83_turboquant_attention/challenge.py
new file mode 100644
index 00000000..6a7e07b8
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/challenge.py
@@ -0,0 +1,221 @@
+import ctypes
+import math
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+ def __init__(self):
+ super().__init__(
+ name="TurboQuant KV Cache Attention",
+ atol=1e-3,
+ rtol=1e-3,
+ num_gpus=1,
+ access_tier="free",
+ )
+
+ def reference_impl(
+ self,
+ Q: torch.Tensor,
+ K_idx: torch.Tensor,
+ qjl_signs: torch.Tensor,
+ gamma: torch.Tensor,
+ Pi: torch.Tensor,
+ S_mat: torch.Tensor,
+ codebook: torch.Tensor,
+ scores: torch.Tensor,
+ B: int,
+ S: int,
+ D: int,
+ C: int,
+ ):
+ assert Q.shape == (B, D)
+ assert K_idx.shape == (S, D)
+ assert qjl_signs.shape == (S, D)
+ assert gamma.shape == (S,)
+ assert Pi.shape == (D, D)
+ assert S_mat.shape == (D, D)
+ assert codebook.shape == (C,)
+ assert scores.shape == (B, S)
+ assert Q.dtype == torch.float32
+ assert K_idx.dtype == torch.uint8
+ assert qjl_signs.dtype == torch.int8
+ assert gamma.dtype == torch.float32
+ assert Pi.dtype == torch.float32
+ assert S_mat.dtype == torch.float32
+ assert codebook.dtype == torch.float32
+ assert scores.dtype == torch.float32
+ assert Q.device.type == "cuda"
+ assert K_idx.device.type == "cuda"
+ assert qjl_signs.device.type == "cuda"
+ assert gamma.device.type == "cuda"
+ assert Pi.device.type == "cuda"
+ assert S_mat.device.type == "cuda"
+ assert codebook.device.type == "cuda"
+ assert scores.device.type == "cuda"
+
+ # Stage 1: MSE dequantization — lookup centroids, rotate back
+ K_centroids = codebook[K_idx.long()] # [S, D]
+ K_mse = K_centroids @ Pi # [S, D] (row convention: ỹ @ Π = Π^T · ỹ)
+
+ # Stage 2: QJL dequantization — reconstruct residual correction
+ scale = math.sqrt(math.pi / 2.0) / D
+ K_qjl = scale * gamma.unsqueeze(1) * (qjl_signs.float() @ S_mat) # [S, D]
+
+ # Combined dequantization
+ K_deq = K_mse + K_qjl # [S, D]
+
+ # Attention scores
+ scores.copy_(Q @ K_deq.T) # [B, S]
+
+ def get_solve_signature(self) -> Dict[str, tuple]:
+ return {
+ "Q": (ctypes.POINTER(ctypes.c_float), "in"),
+ "K_idx": (ctypes.POINTER(ctypes.c_uint8), "in"),
+ "qjl_signs": (ctypes.POINTER(ctypes.c_int8), "in"),
+ "gamma": (ctypes.POINTER(ctypes.c_float), "in"),
+ "Pi": (ctypes.POINTER(ctypes.c_float), "in"),
+ "S_mat": (ctypes.POINTER(ctypes.c_float), "in"),
+ "codebook": (ctypes.POINTER(ctypes.c_float), "in"),
+ "scores": (ctypes.POINTER(ctypes.c_float), "out"),
+ "B": (ctypes.c_int, "in"),
+ "S": (ctypes.c_int, "in"),
+ "D": (ctypes.c_int, "in"),
+ "C": (ctypes.c_int, "in"),
+ }
+
+ def _make_rotation(self, D):
+ G = torch.randn(D, D, device="cuda")
+ Q, _ = torch.linalg.qr(G)
+ return Q
+
+ def _make_codebook(self, C, scale=1.0):
+ return torch.linspace(-scale, scale, C, device="cuda", dtype=torch.float32)
+
+ def _encode_keys(self, K, Pi, S_mat, codebook):
+ """Simulate TurboQuant_prod encoding: rotate, quantize, compute QJL on residual."""
+ S, D = K.shape
+
+ # Stage 1: MSE encoding
+ Y = K @ Pi.T # rotate into quantization space
+ # Scalar quantize each coordinate to nearest centroid
+ diffs = Y.unsqueeze(-1) - codebook.unsqueeze(0).unsqueeze(0) # [S, D, C]
+ K_idx = diffs.abs().argmin(dim=-1).to(torch.uint8) # [S, D]
+
+ # MSE dequantization (to compute residual)
+ K_centroids = codebook[K_idx.long()] # [S, D]
+ K_mse = K_centroids @ Pi # [S, D]
+
+ # Stage 2: QJL encoding of residual
+ residual = K - K_mse # [S, D]
+ gamma = residual.norm(dim=1) # [S]
+ proj = residual @ S_mat.T # [S, D] (row convention for S · r)
+ qjl_signs = torch.sign(proj).to(torch.int8) # [S, D]
+ # Ensure no zeros (sign(0)=0 → map to +1)
+ qjl_signs[qjl_signs == 0] = 1
+
+ return K_idx, qjl_signs, gamma
+
+ def _make_test_case(self, B, S_seq, D, C, zero_q=False, seed=42):
+ torch.manual_seed(seed)
+ device = "cuda"
+
+ Pi = self._make_rotation(D)
+ S_mat = torch.randn(D, D, device=device, dtype=torch.float32)
+ codebook = self._make_codebook(C)
+
+ if zero_q:
+ Q = torch.zeros(B, D, device=device, dtype=torch.float32)
+ else:
+ Q = torch.randn(B, D, device=device, dtype=torch.float32) * 0.5
+
+ # Generate realistic keys and encode them
+ K = torch.randn(S_seq, D, device=device, dtype=torch.float32) * 0.3
+ K_idx, qjl_signs, gamma = self._encode_keys(K, Pi, S_mat, codebook)
+
+ scores = torch.zeros(B, S_seq, device=device, dtype=torch.float32)
+
+ return {
+ "Q": Q,
+ "K_idx": K_idx,
+ "qjl_signs": qjl_signs,
+ "gamma": gamma,
+ "Pi": Pi,
+ "S_mat": S_mat,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S_seq,
+ "D": D,
+ "C": C,
+ }
+
+ def generate_example_test(self) -> Dict[str, Any]:
+ device = "cuda"
+ B, S, D, C = 2, 3, 2, 4
+
+ Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device=device, dtype=torch.float32)
+ K_idx = torch.tensor([[0, 3], [1, 2], [3, 0]], device=device, dtype=torch.uint8)
+ # QJL signs: all +1 for simplicity
+ qjl_signs = torch.ones(S, D, device=device, dtype=torch.int8)
+ # gamma = 0: no QJL correction (reduces to MSE-only for this example)
+ gamma = torch.zeros(S, device=device, dtype=torch.float32)
+ Pi = torch.eye(D, device=device, dtype=torch.float32)
+ S_mat = torch.eye(D, device=device, dtype=torch.float32)
+ codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75], device=device, dtype=torch.float32)
+ scores = torch.zeros(B, S, device=device, dtype=torch.float32)
+
+ return {
+ "Q": Q,
+ "K_idx": K_idx,
+ "qjl_signs": qjl_signs,
+ "gamma": gamma,
+ "Pi": Pi,
+ "S_mat": S_mat,
+ "codebook": codebook,
+ "scores": scores,
+ "B": B,
+ "S": S,
+ "D": D,
+ "C": C,
+ }
+
+ def generate_functional_test(self) -> List[Dict[str, Any]]:
+ tests = []
+
+ # Edge: single query, single key, D=1
+ tests.append(self._make_test_case(1, 1, 1, 2, seed=1))
+
+ # Edge: zero query
+ tests.append(self._make_test_case(2, 3, 4, 4, zero_q=True, seed=2))
+
+ # Edge: small with negative queries
+ tests.append(self._make_test_case(2, 4, 4, 4, seed=3))
+
+ # Power-of-2: B=4, S=16, D=32, C=8
+ tests.append(self._make_test_case(4, 16, 32, 8, seed=10))
+
+ # Power-of-2: B=8, S=64, D=64, C=16
+ tests.append(self._make_test_case(8, 64, 64, 16, seed=20))
+
+ # Power-of-2: B=16, S=128, D=128, C=16
+ tests.append(self._make_test_case(16, 128, 128, 16, seed=30))
+
+ # Non-power-of-2: B=3, S=30, D=50, C=8
+ tests.append(self._make_test_case(3, 30, 50, 8, seed=40))
+
+ # Non-power-of-2: B=7, S=255, D=100, C=16
+ tests.append(self._make_test_case(7, 255, 100, 16, seed=50))
+
+ # Realistic: B=16, S=4096, D=128, C=16
+ tests.append(self._make_test_case(16, 4096, 128, 16, seed=60))
+
+ # Realistic: B=32, S=8192, D=128, C=8
+ tests.append(self._make_test_case(32, 8192, 128, 8, seed=70))
+
+ return tests
+
+ def generate_performance_test(self) -> Dict[str, Any]:
+ return self._make_test_case(32, 32768, 128, 16, seed=0)
diff --git a/challenges/hard/83_turboquant_attention/starter/starter.cu b/challenges/hard/83_turboquant_attention/starter/starter.cu
new file mode 100644
index 00000000..18478b0c
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/starter/starter.cu
@@ -0,0 +1,6 @@
+#include <cuda_runtime.h>
+
+// Starter stub: implement TurboQuant dequantization + attention scoring here.
+// Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers.
+// scores must receive the B x S float32 result.
+extern "C" void solve(const float* Q, const unsigned char* K_idx, const signed char* qjl_signs,
+                      const float* gamma, const float* Pi, const float* S_mat,
+                      const float* codebook, float* scores, int B, int S, int D, int C) {}
diff --git a/challenges/hard/83_turboquant_attention/starter/starter.cute.py b/challenges/hard/83_turboquant_attention/starter/starter.cute.py
new file mode 100644
index 00000000..27a31a80
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/starter/starter.cute.py
@@ -0,0 +1,21 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
+@cute.jit
+def solve(
+    Q: cute.Tensor,
+    K_idx: cute.Tensor,
+    qjl_signs: cute.Tensor,
+    gamma: cute.Tensor,
+    Pi: cute.Tensor,
+    S_mat: cute.Tensor,
+    codebook: cute.Tensor,
+    scores: cute.Tensor,
+    B: cute.Int32,
+    S: cute.Int32,
+    D: cute.Int32,
+    C: cute.Int32,
+):
+    # TODO: dequantize the compressed keys (codebook lookup + rotation, plus
+    # the QJL residual term) and write the B x S float32 scores into `scores`.
+    pass
diff --git a/challenges/hard/83_turboquant_attention/starter/starter.jax.py b/challenges/hard/83_turboquant_attention/starter/starter.jax.py
new file mode 100644
index 00000000..b487a049
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/starter/starter.jax.py
@@ -0,0 +1,21 @@
+import jax
+import jax.numpy as jnp
+
+
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook are tensors on GPU
+@jax.jit
+def solve(
+    Q: jax.Array,
+    K_idx: jax.Array,
+    qjl_signs: jax.Array,
+    gamma: jax.Array,
+    Pi: jax.Array,
+    S_mat: jax.Array,
+    codebook: jax.Array,
+    B: int,
+    S: int,
+    D: int,
+    C: int,
+) -> jax.Array:
+    # return output tensor directly
+    # TODO: dequantize the compressed keys (codebook lookup + rotation, plus
+    # the QJL residual term) and return the B x S float32 score matrix.
+    pass
diff --git a/challenges/hard/83_turboquant_attention/starter/starter.mojo b/challenges/hard/83_turboquant_attention/starter/starter.mojo
new file mode 100644
index 00000000..43d74b27
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/starter/starter.mojo
@@ -0,0 +1,9 @@
+from gpu.host import DeviceContext
+from gpu.id import block_dim, block_idx, thread_idx
+from memory import UnsafePointer
+from math import ceildiv
+
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers
+@export
+def solve(Q: UnsafePointer[Float32], K_idx: UnsafePointer[UInt8], qjl_signs: UnsafePointer[Int8], gamma: UnsafePointer[Float32], Pi: UnsafePointer[Float32], S_mat: UnsafePointer[Float32], codebook: UnsafePointer[Float32], scores: UnsafePointer[Float32], B: Int32, S: Int32, D: Int32, C: Int32):
+    # TODO: dequantize the compressed keys (codebook lookup + rotation, plus
+    # the QJL residual term) and write the B x S float32 scores into `scores`.
+    pass
diff --git a/challenges/hard/83_turboquant_attention/starter/starter.pytorch.py b/challenges/hard/83_turboquant_attention/starter/starter.pytorch.py
new file mode 100644
index 00000000..8b706f88
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/starter/starter.pytorch.py
@@ -0,0 +1,19 @@
+import torch
+
+
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K_idx: torch.Tensor,
+    qjl_signs: torch.Tensor,
+    gamma: torch.Tensor,
+    Pi: torch.Tensor,
+    S_mat: torch.Tensor,
+    codebook: torch.Tensor,
+    scores: torch.Tensor,
+    B: int,
+    S: int,
+    D: int,
+    C: int,
+):
+    # TODO: dequantize the compressed keys (codebook lookup + rotation, plus
+    # the QJL residual term) and write the B x S float32 scores into `scores`.
+    pass
diff --git a/challenges/hard/83_turboquant_attention/starter/starter.triton.py b/challenges/hard/83_turboquant_attention/starter/starter.triton.py
new file mode 100644
index 00000000..ccf5176c
--- /dev/null
+++ b/challenges/hard/83_turboquant_attention/starter/starter.triton.py
@@ -0,0 +1,21 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU
+def solve(
+    Q: torch.Tensor,
+    K_idx: torch.Tensor,
+    qjl_signs: torch.Tensor,
+    gamma: torch.Tensor,
+    Pi: torch.Tensor,
+    S_mat: torch.Tensor,
+    codebook: torch.Tensor,
+    scores: torch.Tensor,
+    B: int,
+    S: int,
+    D: int,
+    C: int,
+):
+    # TODO: launch Triton kernel(s) that dequantize the compressed keys
+    # (codebook lookup + rotation, plus the QJL residual term) and write the
+    # B x S float32 scores into `scores`.
+    pass