From 99910fbd3a59624fd70d331aeb34a8b05147e782 Mon Sep 17 00:00:00 2001 From: Kunal Mansukhani Date: Thu, 26 Mar 2026 19:23:51 -0400 Subject: [PATCH 1/5] Add turbo quant attention --- .../83_turboquant_attention/challenge.html | 101 ++++++ .../83_turboquant_attention/challenge.py | 308 ++++++++++++++++++ .../starter/starter.cu | 5 + .../starter/starter.cute.py | 18 + .../starter/starter.jax.py | 18 + .../starter/starter.mojo | 9 + .../starter/starter.pytorch.py | 16 + .../starter/starter.triton.py | 18 + 8 files changed, 493 insertions(+) create mode 100644 challenges/medium/83_turboquant_attention/challenge.html create mode 100644 challenges/medium/83_turboquant_attention/challenge.py create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.cu create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.cute.py create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.jax.py create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.mojo create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.pytorch.py create mode 100644 challenges/medium/83_turboquant_attention/starter/starter.triton.py diff --git a/challenges/medium/83_turboquant_attention/challenge.html b/challenges/medium/83_turboquant_attention/challenge.html new file mode 100644 index 00000000..37433e10 --- /dev/null +++ b/challenges/medium/83_turboquant_attention/challenge.html @@ -0,0 +1,101 @@ +

+ Implement attention score computation against a TurboQuant-compressed KV cache. In modern LLM inference, key vectors are compressed using TurboQuant: each key is first rotated by an orthogonal matrix \(\Pi\), then each coordinate is scalar-quantized to the nearest centroid in a codebook, storing only the uint8 index. Given B query vectors of dimension D, S quantized key vectors (stored as codebook indices), the rotation matrix \(\Pi\), and a sorted codebook of C centroids, compute the attention scores matrix. +

+

+ For each query \(Q_i\) and quantized key \(K_j\), compute: + \[ + \text{scores}_{i,j} = Q_i \cdot \text{dequant}(K_j) + \] + where dequantization reconstructs the key by looking up centroids and rotating back: + \[ + \text{dequant}(K_j) = [\text{codebook}[K_j[0]],\; \text{codebook}[K_j[1]],\; \ldots,\; \text{codebook}[K_j[D{-}1]]] \;\cdot\; \Pi + \] +

+ + + + Dequantization + Dot Product Pipeline + + K_idx + S x D (uint8) + + + Lookup + codebook[idx] + + + Rotate back + K_deq = ... * Pi + + + Dot product + Q * K_deq^T + + scores + Hint: rotating queries once avoids per-key rotation + + + + + + + +

Implementation Requirements

+ + +

Example

+

+Input:
+Query matrix \(Q\) (\(B=2, D=2\)): +\[ +\begin{bmatrix} +1.0 & 0.0 \\ +0.0 & 1.0 +\end{bmatrix} +\] +Quantized key indices \(K_\text{idx}\) (\(S=3, D=2\), uint8): +\[ +\begin{bmatrix} +0 & 3 \\ +1 & 2 \\ +3 & 0 +\end{bmatrix} +\] +Rotation matrix \(\Pi = I_{2 \times 2}\) (identity)
+Codebook (\(C=4\)): \([-0.75,\; -0.25,\; 0.25,\; 0.75]\) +

+

+Dequantized keys: look up centroids for each index, then multiply by \(\Pi\) (identity here): +\[ +K_\text{deq} = \begin{bmatrix} +-0.75 & 0.75 \\ +-0.25 & 0.25 \\ +0.75 & -0.75 +\end{bmatrix} +\] +Output scores (\(B=2, S=3\)): +\[ +\text{scores} = Q \cdot K_\text{deq}^T = \begin{bmatrix} +-0.75 & -0.25 & 0.75 \\ +0.75 & 0.25 & -0.75 +\end{bmatrix} +\] +

+ +

Constraints

+ diff --git a/challenges/medium/83_turboquant_attention/challenge.py b/challenges/medium/83_turboquant_attention/challenge.py new file mode 100644 index 00000000..f6188e6c --- /dev/null +++ b/challenges/medium/83_turboquant_attention/challenge.py @@ -0,0 +1,308 @@ +import ctypes +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="TurboQuant KV Cache Attention", + atol=1e-3, + rtol=1e-3, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + Q: torch.Tensor, + K_idx: torch.Tensor, + Pi: torch.Tensor, + codebook: torch.Tensor, + scores: torch.Tensor, + B: int, + S: int, + D: int, + C: int, + ): + assert Q.shape == (B, D) + assert K_idx.shape == (S, D) + assert Pi.shape == (D, D) + assert codebook.shape == (C,) + assert scores.shape == (B, S) + assert Q.dtype == torch.float32 + assert K_idx.dtype == torch.uint8 + assert Pi.dtype == torch.float32 + assert codebook.dtype == torch.float32 + assert scores.dtype == torch.float32 + assert Q.device.type == "cuda" + assert K_idx.device.type == "cuda" + assert Pi.device.type == "cuda" + assert codebook.device.type == "cuda" + assert scores.device.type == "cuda" + + # Dequantize keys: lookup centroids then rotate back + K_centroids = codebook[K_idx.long()] # S x D + K_deq = K_centroids @ Pi # S x D + + # Compute attention scores + scores.copy_(Q @ K_deq.T) # B x S + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "Q": (ctypes.POINTER(ctypes.c_float), "in"), + "K_idx": (ctypes.POINTER(ctypes.c_uint8), "in"), + "Pi": (ctypes.POINTER(ctypes.c_float), "in"), + "codebook": (ctypes.POINTER(ctypes.c_float), "in"), + "scores": (ctypes.POINTER(ctypes.c_float), "out"), + "B": (ctypes.c_int, "in"), + "S": (ctypes.c_int, "in"), + "D": (ctypes.c_int, "in"), + "C": (ctypes.c_int, "in"), + } + + def _make_rotation(self, D): + G = torch.randn(D, D, device="cuda") + Q, _ = 
torch.linalg.qr(G) + return Q + + def _make_codebook(self, C, scale=1.0): + return torch.linspace(-scale, scale, C, device="cuda", dtype=torch.float32) + + def generate_example_test(self) -> Dict[str, Any]: + B, S, D, C = 2, 3, 2, 4 + Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device="cuda", dtype=torch.float32) + K_idx = torch.tensor([[0, 3], [1, 2], [3, 0]], device="cuda", dtype=torch.uint8) + Pi = torch.eye(D, device="cuda", dtype=torch.float32) + codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75], device="cuda", dtype=torch.float32) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + return { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + tests = [] + + # Edge case: single query, single key, D=1, C=2 + B, S, D, C = 1, 1, 1, 2 + Q = torch.tensor([[0.5]], device="cuda", dtype=torch.float32) + K_idx = torch.tensor([[1]], device="cuda", dtype=torch.uint8) + Pi = torch.eye(D, device="cuda", dtype=torch.float32) + codebook = self._make_codebook(C) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + tests.append( + { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + ) + + # Edge case: zeros query + B, S, D, C = 2, 3, 4, 4 + Q = torch.zeros(B, D, device="cuda", dtype=torch.float32) + K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) + Pi = self._make_rotation(D) + codebook = self._make_codebook(C) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + tests.append( + { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + ) + + # Edge case: negative query values + B, S, D, C = 2, 4, 4, 4 + Q = torch.tensor( + [[-0.5, -0.3, -0.8, -0.1], [-1.0, -0.5, -0.2, -0.9]], + device="cuda", + dtype=torch.float32, + ) + 
K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) + Pi = self._make_rotation(D) + codebook = self._make_codebook(C) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + tests.append( + { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + ) + + # Power-of-2: B=4, S=16, D=32, C=8 + B, S, D, C = 4, 16, 32, 8 + Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.5 + K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) + Pi = self._make_rotation(D) + codebook = self._make_codebook(C, scale=1.5) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + tests.append( + { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + ) + + # Power-of-2: B=8, S=64, D=64, C=16 + B, S, D, C = 8, 64, 64, 16 + Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3 + K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) + Pi = self._make_rotation(D) + codebook = self._make_codebook(C) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + tests.append( + { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + ) + + # Power-of-2: B=16, S=128, D=128, C=16 + B, S, D, C = 16, 128, 128, 16 + Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3 + K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) + Pi = self._make_rotation(D) + codebook = self._make_codebook(C) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + tests.append( + { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + ) + + # Non-power-of-2: B=3, S=30, D=50, C=8 + B, S, D, C = 3, 30, 50, 8 + Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.4 + K_idx = 
torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) + Pi = self._make_rotation(D) + codebook = self._make_codebook(C) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + tests.append( + { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + ) + + # Non-power-of-2: B=7, S=255, D=100, C=16 + B, S, D, C = 7, 255, 100, 16 + Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.6 + K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) + Pi = self._make_rotation(D) + codebook = self._make_codebook(C, scale=1.5) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + tests.append( + { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + ) + + # Realistic: B=16, S=4096, D=128, C=16 + B, S, D, C = 16, 4096, 128, 16 + Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3 + K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) + Pi = self._make_rotation(D) + codebook = self._make_codebook(C) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + tests.append( + { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + ) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + B, S, D, C = 32, 32768, 128, 16 + Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3 + K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) + Pi = self._make_rotation(D) + codebook = self._make_codebook(C) + scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) + return { + "Q": Q, + "K_idx": K_idx, + "Pi": Pi, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } diff --git a/challenges/medium/83_turboquant_attention/starter/starter.cu 
b/challenges/medium/83_turboquant_attention/starter/starter.cu new file mode 100644 index 00000000..be318f3c --- /dev/null +++ b/challenges/medium/83_turboquant_attention/starter/starter.cu @@ -0,0 +1,5 @@ +#include + +// Q, K_idx, Pi, codebook, scores are device pointers +extern "C" void solve(const float* Q, const unsigned char* K_idx, const float* Pi, + const float* codebook, float* scores, int B, int S, int D, int C) {} diff --git a/challenges/medium/83_turboquant_attention/starter/starter.cute.py b/challenges/medium/83_turboquant_attention/starter/starter.cute.py new file mode 100644 index 00000000..ebd728fa --- /dev/null +++ b/challenges/medium/83_turboquant_attention/starter/starter.cute.py @@ -0,0 +1,18 @@ +import cutlass +import cutlass.cute as cute + + +# Q, K_idx, Pi, codebook, scores are tensors on the GPU +@cute.jit +def solve( + Q: cute.Tensor, + K_idx: cute.Tensor, + Pi: cute.Tensor, + codebook: cute.Tensor, + scores: cute.Tensor, + B: cute.Int32, + S: cute.Int32, + D: cute.Int32, + C: cute.Int32, +): + pass diff --git a/challenges/medium/83_turboquant_attention/starter/starter.jax.py b/challenges/medium/83_turboquant_attention/starter/starter.jax.py new file mode 100644 index 00000000..3c38dbb5 --- /dev/null +++ b/challenges/medium/83_turboquant_attention/starter/starter.jax.py @@ -0,0 +1,18 @@ +import jax +import jax.numpy as jnp + + +# Q, K_idx, Pi, codebook are tensors on GPU +@jax.jit +def solve( + Q: jax.Array, + K_idx: jax.Array, + Pi: jax.Array, + codebook: jax.Array, + B: int, + S: int, + D: int, + C: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/83_turboquant_attention/starter/starter.mojo b/challenges/medium/83_turboquant_attention/starter/starter.mojo new file mode 100644 index 00000000..353883cd --- /dev/null +++ b/challenges/medium/83_turboquant_attention/starter/starter.mojo @@ -0,0 +1,9 @@ +from gpu.host import DeviceContext +from gpu.id import block_dim, block_idx, thread_idx +from 
memory import UnsafePointer +from math import ceildiv + +# Q, K_idx, Pi, codebook, scores are device pointers +@export +def solve(Q: UnsafePointer[Float32], K_idx: UnsafePointer[UInt8], Pi: UnsafePointer[Float32], codebook: UnsafePointer[Float32], scores: UnsafePointer[Float32], B: Int32, S: Int32, D: Int32, C: Int32): + pass diff --git a/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py b/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py new file mode 100644 index 00000000..a78ee78c --- /dev/null +++ b/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py @@ -0,0 +1,16 @@ +import torch + + +# Q, K_idx, Pi, codebook, scores are tensors on the GPU +def solve( + Q: torch.Tensor, + K_idx: torch.Tensor, + Pi: torch.Tensor, + codebook: torch.Tensor, + scores: torch.Tensor, + B: int, + S: int, + D: int, + C: int, +): + pass diff --git a/challenges/medium/83_turboquant_attention/starter/starter.triton.py b/challenges/medium/83_turboquant_attention/starter/starter.triton.py new file mode 100644 index 00000000..f21c5eea --- /dev/null +++ b/challenges/medium/83_turboquant_attention/starter/starter.triton.py @@ -0,0 +1,18 @@ +import torch +import triton +import triton.language as tl + + +# q, k_idx, pi, codebook, scores are tensors on the GPU +def solve( + q: torch.Tensor, + k_idx: torch.Tensor, + pi: torch.Tensor, + codebook: torch.Tensor, + scores: torch.Tensor, + B: int, + S: int, + D: int, + C: int, +): + pass From 89536ca7e866414da2104beae01b6ed357872853 Mon Sep 17 00:00:00 2001 From: Kunal Mansukhani Date: Thu, 26 Mar 2026 20:37:29 -0400 Subject: [PATCH 2/5] Make spec better --- .../83_turboquant_attention/challenge.html | 65 ++++++------------- 1 file changed, 20 insertions(+), 45 deletions(-) diff --git a/challenges/medium/83_turboquant_attention/challenge.html b/challenges/medium/83_turboquant_attention/challenge.html index 37433e10..91f88bd3 100644 --- a/challenges/medium/83_turboquant_attention/challenge.html +++ 
b/challenges/medium/83_turboquant_attention/challenge.html @@ -1,52 +1,24 @@

- Implement attention score computation against a TurboQuant-compressed KV cache. In modern LLM inference, key vectors are compressed using TurboQuant: each key is first rotated by an orthogonal matrix \(\Pi\), then each coordinate is scalar-quantized to the nearest centroid in a codebook, storing only the uint8 index. Given B query vectors of dimension D, S quantized key vectors (stored as codebook indices), the rotation matrix \(\Pi\), and a sorted codebook of C centroids, compute the attention scores matrix. + Implement attention score computation against a quantized KV cache. During LLM inference, the KV cache can dominate memory. TurboQuant addresses this by compressing each key vector down to uint8 codebook indices, reducing memory by up to 4×. Your task is to compute attention scores between full-precision queries and these compressed keys.

+

- For each query \(Q_i\) and quantized key \(K_j\), compute: - \[ - \text{scores}_{i,j} = Q_i \cdot \text{dequant}(K_j) - \] - where dequantization reconstructs the key by looking up centroids and rotating back: - \[ - \text{dequant}(K_j) = [\text{codebook}[K_j[0]],\; \text{codebook}[K_j[1]],\; \ldots,\; \text{codebook}[K_j[D{-}1]]] \;\cdot\; \Pi - \] + How TurboQuant encodes keys (already done for you): each key vector is multiplied by a random orthogonal matrix \(\Pi\), then each coordinate is replaced by the index of its nearest centroid in a codebook.

- - - Dequantization + Dot Product Pipeline - - K_idx - S x D (uint8) - - - Lookup - codebook[idx] - - - Rotate back - K_deq = ... * Pi - - - Dot product - Q * K_deq^T - - scores - Hint: rotating queries once avoids per-key rotation - - - - - - +

+ What you need to compute: given queries \(Q\) and quantized key indices \(K_\text{idx}\), dequantize each key and compute dot products: +

+
    +
  1. Lookup: replace each index with its codebook value, producing a centroid vector \(\tilde{Y}_j\)
  2. +
  3. Rotate back: recover the approximate key \(\tilde{K}_j = \tilde{Y}_j \cdot \Pi\)
  4. +
  5. Dot product: \(\text{scores}_{i,j} = Q_i \cdot \tilde{K}_j\)
  6. +

Implementation Requirements

  • The solve function signature must remain unchanged.
  • Use only native features (no external libraries).
  • -
  • \(\Pi\) is an orthogonal matrix (\(\Pi^T \Pi = I\)), stored row-major as float32.
  • -
  • The codebook is sorted in ascending order (float32).
  • -
  • K_idx contains uint8 indices in \([0, C)\).
  • Store the result in scores as float32.
@@ -72,17 +44,18 @@

Example

Codebook (\(C=4\)): \([-0.75,\; -0.25,\; 0.25,\; 0.75]\)

-Dequantized keys: look up centroids for each index, then multiply by \(\Pi\) (identity here): +Step 1: Lookup centroids for each index: \[ -K_\text{deq} = \begin{bmatrix} +\tilde{Y} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix} \] -Output scores (\(B=2, S=3\)): +Step 2: Rotate back (\(\Pi\) is identity here, so \(\tilde{K} = \tilde{Y}\)).
+Step 3: Dot products: \[ -\text{scores} = Q \cdot K_\text{deq}^T = \begin{bmatrix} +\text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix} -0.75 & -0.25 & 0.75 \\ 0.75 & 0.25 & -0.75 \end{bmatrix} @@ -95,7 +68,9 @@

Constraints

  • 1 ≤ S ≤ 65,536
  • 1 ≤ D ≤ 256
  • 2 ≤ C ≤ 256
  • -
  • \(\Pi\) is an orthogonal matrix, codebook is sorted ascending
  • -
  • All floating-point values are float32; key indices are uint8
  • +
  • \(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))
  • +
  • Codebook values are sorted in ascending order
  • +
  • K_idx values are in \([0, C)\)
  • +
  • All floating-point inputs are float32; key indices are uint8
  • Performance is measured with B = 32, S = 32,768, D = 128, C = 16
  • From 33db15a785e2d0dfbf419d487bdc01f815760ca9 Mon Sep 17 00:00:00 2001 From: James Song Date: Fri, 27 Mar 2026 18:19:17 -0400 Subject: [PATCH 3/5] update turboquant to include residual correction --- .../83_turboquant_attention/challenge.html | 202 ++++++++++++ .../hard/83_turboquant_attention/challenge.py | 222 +++++++++++++ .../starter/starter.cu | 6 + .../starter/starter.cute.py | 5 +- .../starter/starter.jax.py | 5 +- .../starter/starter.mojo | 9 + .../starter/starter.pytorch.py | 5 +- .../starter/starter.triton.py | 21 ++ .../83_turboquant_attention/challenge.html | 76 ----- .../83_turboquant_attention/challenge.py | 308 ------------------ .../starter/starter.cu | 5 - .../starter/starter.mojo | 9 - .../starter/starter.triton.py | 18 - 13 files changed, 472 insertions(+), 419 deletions(-) create mode 100644 challenges/hard/83_turboquant_attention/challenge.html create mode 100644 challenges/hard/83_turboquant_attention/challenge.py create mode 100644 challenges/hard/83_turboquant_attention/starter/starter.cu rename challenges/{medium => hard}/83_turboquant_attention/starter/starter.cute.py (63%) rename challenges/{medium => hard}/83_turboquant_attention/starter/starter.jax.py (64%) create mode 100644 challenges/hard/83_turboquant_attention/starter/starter.mojo rename challenges/{medium => hard}/83_turboquant_attention/starter/starter.pytorch.py (56%) create mode 100644 challenges/hard/83_turboquant_attention/starter/starter.triton.py delete mode 100644 challenges/medium/83_turboquant_attention/challenge.html delete mode 100644 challenges/medium/83_turboquant_attention/challenge.py delete mode 100644 challenges/medium/83_turboquant_attention/starter/starter.cu delete mode 100644 challenges/medium/83_turboquant_attention/starter/starter.mojo delete mode 100644 challenges/medium/83_turboquant_attention/starter/starter.triton.py diff --git a/challenges/hard/83_turboquant_attention/challenge.html b/challenges/hard/83_turboquant_attention/challenge.html new file 
mode 100644 index 00000000..ff34a028 --- /dev/null +++ b/challenges/hard/83_turboquant_attention/challenge.html @@ -0,0 +1,202 @@ +

    + Implement attention score computation against a + TurboQuant-compressed + KV cache. TurboQuant compresses each key vector to uint8 codebook indices plus a 1-bit + residual correction (QJL), reducing KV cache memory by up to 6x. Your task: dequantize the + compressed keys and compute dot-product attention scores against full-precision queries. +

    + + + + + + TurboQuant Dequantization Pipeline (per key vector) + + + + + Stage 1: MSE + + + + K_idx + [S, D] uint8 + + + + + + + codebook[ K_idx ] + centroid lookup + + + + + + + × Π + rotate back + + + = + + + + K̃_mse + [S, D] float32 + + + + + Stage 2: QJL residual + + + + σ + [S,D] ±1 + + + + + + σ · M + project + + + + + + × √(π/2)/D × γ + scale by norm + + + = + + + + K̃_res + [S, D] float32 + + + + + Combine + Score + + + + K̃_mse + + + + + + K̃_res + + = + + + + + + then: + + + Q + + · + + + K̃ᵀ + + = + + + scores + + + Π = orthogonal rotation [D×D] + M = Gaussian projection [D×D] + σ = sign bits ±1, γ = ‖residual‖₂ + + +

    + Background — how the keys were compressed (already done for you, not part of the challenge): +

    +
      +
    1. Rotate: multiply key by orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each + coordinate follow a Beta distribution, so a single fixed codebook works for all coordinates.
    2. +
    3. Scalar quantize: replace each coordinate of \(y\) with the index of its nearest + codebook centroid \(\rightarrow K_\text{idx}\) (uint8).
    4. +
    5. Residual correction: MSE quantization loses information. Compute the residual + \(r = K - \tilde{K}_\text{mse}\), then store: +
        +
      • \(\sigma = \text{sign}(M \cdot r) \in \{-1,+1\}^D\) — direction (int8)
      • +
      • \(\gamma = \|r\|_2\) — magnitude (float32 scalar per key)
      • +
      + where \(M \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix (S_mat in code). +
    6. +
    + +

    + What you compute — dequantize and score: +

    +
      +
    1. MSE dequantize: look up centroids, undo the rotation: + \[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]
    2. +
    3. Residual dequantize: reconstruct the residual correction: + \[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot M\] + The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.
    4. +
    5. Combine: + \(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)
    6. +
    7. Dot product: + \(\text{scores}_{b,s} = Q_b \cdot \tilde{K}_s\)
    8. +
    +

    + The residual correction makes the inner product unbiased: + \(\mathbb{E}[\langle Q, \tilde{K} \rangle] = \langle Q, K \rangle\). +

    + +

    Implementation Requirements

    +
      +
    • The solve function signature must remain unchanged.
    • +
    • Use only native features (no external libraries).
    • +
    • Store the result in scores as float32.
    • +
    + +

    Example

    +

    + Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(M = I\), \(\gamma = \mathbf{0}\) (residual correction disabled): +

    +

+ \(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\), + \(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\), + codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\) +

    +

    + Step 1 — MSE lookup and rotate back (\(\Pi = I\)): + \[ + \tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix} + \] + Step 2 — Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\). +

    +

    + Output: + \[ + \text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix} -0.75 & -0.25 & 0.75 \\ 0.75 & 0.25 & -0.75 \end{bmatrix} + \] +

    + +

    Constraints

    +
      +
    • 1 ≤ B ≤ 32
    • +
    • 1 ≤ S ≤ 65,536
    • +
    • 1 ≤ D ≤ 256
    • +
    • 2 ≤ C ≤ 256
    • +
    • \(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))
    • +
    • \(M\) (S_mat) has i.i.d. \(\mathcal{N}(0,1)\) entries
    • +
    • qjl_signs (\(\sigma\)) values are in \(\{-1, +1\}\) (int8)
    • +
    • K_idx values are in \([0, C)\) (uint8)
    • +
    • All floating-point inputs are float32
    • +
    • Performance is measured with B = 32, S = 32,768, D = 128, C = 16
    • +
    diff --git a/challenges/hard/83_turboquant_attention/challenge.py b/challenges/hard/83_turboquant_attention/challenge.py new file mode 100644 index 00000000..e6723af1 --- /dev/null +++ b/challenges/hard/83_turboquant_attention/challenge.py @@ -0,0 +1,222 @@ +import ctypes +import math +from typing import Any, Dict, List + +import torch +from core.challenge_base import ChallengeBase + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="TurboQuant KV Cache Attention", + atol=1e-3, + rtol=1e-3, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + Q: torch.Tensor, + K_idx: torch.Tensor, + qjl_signs: torch.Tensor, + gamma: torch.Tensor, + Pi: torch.Tensor, + S_mat: torch.Tensor, + codebook: torch.Tensor, + scores: torch.Tensor, + B: int, + S: int, + D: int, + C: int, + ): + assert Q.shape == (B, D) + assert K_idx.shape == (S, D) + assert qjl_signs.shape == (S, D) + assert gamma.shape == (S,) + assert Pi.shape == (D, D) + assert S_mat.shape == (D, D) + assert codebook.shape == (C,) + assert scores.shape == (B, S) + assert Q.dtype == torch.float32 + assert K_idx.dtype == torch.uint8 + assert qjl_signs.dtype == torch.int8 + assert gamma.dtype == torch.float32 + assert Pi.dtype == torch.float32 + assert S_mat.dtype == torch.float32 + assert codebook.dtype == torch.float32 + assert scores.dtype == torch.float32 + assert Q.device.type == "cuda" + assert K_idx.device.type == "cuda" + assert qjl_signs.device.type == "cuda" + assert gamma.device.type == "cuda" + assert Pi.device.type == "cuda" + assert S_mat.device.type == "cuda" + assert codebook.device.type == "cuda" + assert scores.device.type == "cuda" + + # Stage 1: MSE dequantization — lookup centroids, rotate back + K_centroids = codebook[K_idx.long()] # [S, D] + K_mse = K_centroids @ Pi # [S, D] (row convention: ỹ @ Π = Π^T · ỹ) + + # Stage 2: QJL dequantization — reconstruct residual correction + scale = math.sqrt(math.pi / 2.0) / D + K_qjl = scale * 
gamma.unsqueeze(1) * (qjl_signs.float() @ S_mat) # [S, D] + + # Combined dequantization + K_deq = K_mse + K_qjl # [S, D] + + # Attention scores + scores.copy_(Q @ K_deq.T) # [B, S] + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "Q": (ctypes.POINTER(ctypes.c_float), "in"), + "K_idx": (ctypes.POINTER(ctypes.c_uint8), "in"), + "qjl_signs": (ctypes.POINTER(ctypes.c_int8), "in"), + "gamma": (ctypes.POINTER(ctypes.c_float), "in"), + "Pi": (ctypes.POINTER(ctypes.c_float), "in"), + "S_mat": (ctypes.POINTER(ctypes.c_float), "in"), + "codebook": (ctypes.POINTER(ctypes.c_float), "in"), + "scores": (ctypes.POINTER(ctypes.c_float), "out"), + "B": (ctypes.c_int, "in"), + "S": (ctypes.c_int, "in"), + "D": (ctypes.c_int, "in"), + "C": (ctypes.c_int, "in"), + } + + def _make_rotation(self, D): + G = torch.randn(D, D, device="cuda") + Q, _ = torch.linalg.qr(G) + return Q + + def _make_codebook(self, C, scale=1.0): + return torch.linspace(-scale, scale, C, device="cuda", dtype=torch.float32) + + def _encode_keys(self, K, Pi, S_mat, codebook): + """Simulate TurboQuant_prod encoding: rotate, quantize, compute QJL on residual.""" + S, D = K.shape + C = codebook.shape[0] + + # Stage 1: MSE encoding + Y = K @ Pi.T # rotate into quantization space + # Scalar quantize each coordinate to nearest centroid + diffs = Y.unsqueeze(-1) - codebook.unsqueeze(0).unsqueeze(0) # [S, D, C] + K_idx = diffs.abs().argmin(dim=-1).to(torch.uint8) # [S, D] + + # MSE dequantization (to compute residual) + K_centroids = codebook[K_idx.long()] # [S, D] + K_mse = K_centroids @ Pi # [S, D] + + # Stage 2: QJL encoding of residual + residual = K - K_mse # [S, D] + gamma = residual.norm(dim=1) # [S] + proj = residual @ S_mat.T # [S, D] (row convention for S · r) + qjl_signs = torch.sign(proj).to(torch.int8) # [S, D] + # Ensure no zeros (sign(0)=0 → map to +1) + qjl_signs[qjl_signs == 0] = 1 + + return K_idx, qjl_signs, gamma + + def _make_test_case(self, B, S_seq, D, C, zero_q=False, seed=42): + 
torch.manual_seed(seed) + device = "cuda" + + Pi = self._make_rotation(D) + S_mat = torch.randn(D, D, device=device, dtype=torch.float32) + codebook = self._make_codebook(C) + + if zero_q: + Q = torch.zeros(B, D, device=device, dtype=torch.float32) + else: + Q = torch.randn(B, D, device=device, dtype=torch.float32) * 0.5 + + # Generate realistic keys and encode them + K = torch.randn(S_seq, D, device=device, dtype=torch.float32) * 0.3 + K_idx, qjl_signs, gamma = self._encode_keys(K, Pi, S_mat, codebook) + + scores = torch.zeros(B, S_seq, device=device, dtype=torch.float32) + + return { + "Q": Q, + "K_idx": K_idx, + "qjl_signs": qjl_signs, + "gamma": gamma, + "Pi": Pi, + "S_mat": S_mat, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S_seq, + "D": D, + "C": C, + } + + def generate_example_test(self) -> Dict[str, Any]: + device = "cuda" + B, S, D, C = 2, 3, 2, 4 + + Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device=device, dtype=torch.float32) + K_idx = torch.tensor([[0, 3], [1, 2], [3, 0]], device=device, dtype=torch.uint8) + # QJL signs: all +1 for simplicity + qjl_signs = torch.ones(S, D, device=device, dtype=torch.int8) + # gamma = 0: no QJL correction (reduces to MSE-only for this example) + gamma = torch.zeros(S, device=device, dtype=torch.float32) + Pi = torch.eye(D, device=device, dtype=torch.float32) + S_mat = torch.eye(D, device=device, dtype=torch.float32) + codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75], device=device, dtype=torch.float32) + scores = torch.zeros(B, S, device=device, dtype=torch.float32) + + return { + "Q": Q, + "K_idx": K_idx, + "qjl_signs": qjl_signs, + "gamma": gamma, + "Pi": Pi, + "S_mat": S_mat, + "codebook": codebook, + "scores": scores, + "B": B, + "S": S, + "D": D, + "C": C, + } + + def generate_functional_test(self) -> List[Dict[str, Any]]: + tests = [] + + # Edge: single query, single key, D=1 + tests.append(self._make_test_case(1, 1, 1, 2, seed=1)) + + # Edge: zero query + tests.append(self._make_test_case(2, 3, 
4, 4, zero_q=True, seed=2)) + + # Edge: small with negative queries + tests.append(self._make_test_case(2, 4, 4, 4, seed=3)) + + # Power-of-2: B=4, S=16, D=32, C=8 + tests.append(self._make_test_case(4, 16, 32, 8, seed=10)) + + # Power-of-2: B=8, S=64, D=64, C=16 + tests.append(self._make_test_case(8, 64, 64, 16, seed=20)) + + # Power-of-2: B=16, S=128, D=128, C=16 + tests.append(self._make_test_case(16, 128, 128, 16, seed=30)) + + # Non-power-of-2: B=3, S=30, D=50, C=8 + tests.append(self._make_test_case(3, 30, 50, 8, seed=40)) + + # Non-power-of-2: B=7, S=255, D=100, C=16 + tests.append(self._make_test_case(7, 255, 100, 16, seed=50)) + + # Realistic: B=16, S=4096, D=128, C=16 + tests.append(self._make_test_case(16, 4096, 128, 16, seed=60)) + + # Realistic: B=32, S=8192, D=128, C=8 + tests.append(self._make_test_case(32, 8192, 128, 8, seed=70)) + + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + return self._make_test_case(32, 32768, 128, 16, seed=0) diff --git a/challenges/hard/83_turboquant_attention/starter/starter.cu b/challenges/hard/83_turboquant_attention/starter/starter.cu new file mode 100644 index 00000000..18478b0c --- /dev/null +++ b/challenges/hard/83_turboquant_attention/starter/starter.cu @@ -0,0 +1,6 @@ +#include + +// Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers +extern "C" void solve(const float* Q, const unsigned char* K_idx, const signed char* qjl_signs, + const float* gamma, const float* Pi, const float* S_mat, + const float* codebook, float* scores, int B, int S, int D, int C) {} diff --git a/challenges/medium/83_turboquant_attention/starter/starter.cute.py b/challenges/hard/83_turboquant_attention/starter/starter.cute.py similarity index 63% rename from challenges/medium/83_turboquant_attention/starter/starter.cute.py rename to challenges/hard/83_turboquant_attention/starter/starter.cute.py index ebd728fa..27a31a80 100644 --- 
a/challenges/medium/83_turboquant_attention/starter/starter.cute.py +++ b/challenges/hard/83_turboquant_attention/starter/starter.cute.py @@ -2,12 +2,15 @@ import cutlass.cute as cute -# Q, K_idx, Pi, codebook, scores are tensors on the GPU +# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU @cute.jit def solve( Q: cute.Tensor, K_idx: cute.Tensor, + qjl_signs: cute.Tensor, + gamma: cute.Tensor, Pi: cute.Tensor, + S_mat: cute.Tensor, codebook: cute.Tensor, scores: cute.Tensor, B: cute.Int32, diff --git a/challenges/medium/83_turboquant_attention/starter/starter.jax.py b/challenges/hard/83_turboquant_attention/starter/starter.jax.py similarity index 64% rename from challenges/medium/83_turboquant_attention/starter/starter.jax.py rename to challenges/hard/83_turboquant_attention/starter/starter.jax.py index 3c38dbb5..b487a049 100644 --- a/challenges/medium/83_turboquant_attention/starter/starter.jax.py +++ b/challenges/hard/83_turboquant_attention/starter/starter.jax.py @@ -2,12 +2,15 @@ import jax.numpy as jnp -# Q, K_idx, Pi, codebook are tensors on GPU +# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook are tensors on GPU @jax.jit def solve( Q: jax.Array, K_idx: jax.Array, + qjl_signs: jax.Array, + gamma: jax.Array, Pi: jax.Array, + S_mat: jax.Array, codebook: jax.Array, B: int, S: int, diff --git a/challenges/hard/83_turboquant_attention/starter/starter.mojo b/challenges/hard/83_turboquant_attention/starter/starter.mojo new file mode 100644 index 00000000..43d74b27 --- /dev/null +++ b/challenges/hard/83_turboquant_attention/starter/starter.mojo @@ -0,0 +1,9 @@ +from gpu.host import DeviceContext +from gpu.id import block_dim, block_idx, thread_idx +from memory import UnsafePointer +from math import ceildiv + +# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are device pointers +@export +def solve(Q: UnsafePointer[Float32], K_idx: UnsafePointer[UInt8], qjl_signs: UnsafePointer[Int8], gamma: UnsafePointer[Float32], Pi: 
UnsafePointer[Float32], S_mat: UnsafePointer[Float32], codebook: UnsafePointer[Float32], scores: UnsafePointer[Float32], B: Int32, S: Int32, D: Int32, C: Int32): + pass diff --git a/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py b/challenges/hard/83_turboquant_attention/starter/starter.pytorch.py similarity index 56% rename from challenges/medium/83_turboquant_attention/starter/starter.pytorch.py rename to challenges/hard/83_turboquant_attention/starter/starter.pytorch.py index a78ee78c..8b706f88 100644 --- a/challenges/medium/83_turboquant_attention/starter/starter.pytorch.py +++ b/challenges/hard/83_turboquant_attention/starter/starter.pytorch.py @@ -1,11 +1,14 @@ import torch -# Q, K_idx, Pi, codebook, scores are tensors on the GPU +# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU def solve( Q: torch.Tensor, K_idx: torch.Tensor, + qjl_signs: torch.Tensor, + gamma: torch.Tensor, Pi: torch.Tensor, + S_mat: torch.Tensor, codebook: torch.Tensor, scores: torch.Tensor, B: int, diff --git a/challenges/hard/83_turboquant_attention/starter/starter.triton.py b/challenges/hard/83_turboquant_attention/starter/starter.triton.py new file mode 100644 index 00000000..ccf5176c --- /dev/null +++ b/challenges/hard/83_turboquant_attention/starter/starter.triton.py @@ -0,0 +1,21 @@ +import torch +import triton +import triton.language as tl + + +# Q, K_idx, qjl_signs, gamma, Pi, S_mat, codebook, scores are tensors on the GPU +def solve( + Q: torch.Tensor, + K_idx: torch.Tensor, + qjl_signs: torch.Tensor, + gamma: torch.Tensor, + Pi: torch.Tensor, + S_mat: torch.Tensor, + codebook: torch.Tensor, + scores: torch.Tensor, + B: int, + S: int, + D: int, + C: int, +): + pass diff --git a/challenges/medium/83_turboquant_attention/challenge.html b/challenges/medium/83_turboquant_attention/challenge.html deleted file mode 100644 index 91f88bd3..00000000 --- a/challenges/medium/83_turboquant_attention/challenge.html +++ /dev/null @@ -1,76 
+0,0 @@ -

    - Implement attention score computation against a quantized KV cache. During LLM inference, the KV cache can dominate memory. TurboQuant addresses this by compressing each key vector down to uint8 codebook indices, reducing memory by up to 4×. Your task is to compute attention scores between full-precision queries and these compressed keys. -

    - -

    - How TurboQuant encodes keys (already done for you): each key vector is multiplied by a random orthogonal matrix \(\Pi\), then each coordinate is replaced by the index of its nearest centroid in a codebook. -

    - -

    - What you need to compute: given queries \(Q\) and quantized key indices \(K_\text{idx}\), dequantize each key and compute dot products: -

    -
      -
    1. Lookup: replace each index with its codebook value, producing a centroid vector \(\tilde{Y}_j\)
    2. -
    3. Rotate back: recover the approximate key \(\tilde{K}_j = \tilde{Y}_j \cdot \Pi\)
    4. -
    5. Dot product: \(\text{scores}_{i,j} = Q_i \cdot \tilde{K}_j\)
    6. -
    - -

    Implementation Requirements

    -
      -
    • The solve function signature must remain unchanged.
    • -
    • Use only native features (no external libraries).
    • -
    • Store the result in scores as float32.
    • -
    - -

    Example

    -

    -Input:
    -Query matrix \(Q\) (\(B=2, D=2\)): -\[ -\begin{bmatrix} -1.0 & 0.0 \\ -0.0 & 1.0 -\end{bmatrix} -\] -Quantized key indices \(K_\text{idx}\) (\(S=3, D=2\), uint8): -\[ -\begin{bmatrix} -0 & 3 \\ -1 & 2 \\ -3 & 0 -\end{bmatrix} -\] -Rotation matrix \(\Pi = I_{2 \times 2}\) (identity)
    -Codebook (\(C=4\)): \([-0.75,\; -0.25,\; 0.25,\; 0.75]\) -

    -

    -Step 1: Lookup centroids for each index: -\[ -\tilde{Y} = \begin{bmatrix} --0.75 & 0.75 \\ --0.25 & 0.25 \\ -0.75 & -0.75 -\end{bmatrix} -\] -Step 2: Rotate back (\(\Pi\) is identity here, so \(\tilde{K} = \tilde{Y}\)).
    -Step 3: Dot products: -\[ -\text{scores} = Q \cdot \tilde{K}^T = \begin{bmatrix} --0.75 & -0.25 & 0.75 \\ -0.75 & 0.25 & -0.75 -\end{bmatrix} -\] -

    - -

    Constraints

    -
      -
    • 1 ≤ B ≤ 32
    • -
    • 1 ≤ S ≤ 65,536
    • -
    • 1 ≤ D ≤ 256
    • -
    • 2 ≤ C ≤ 256
    • -
    • \(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))
    • -
    • Codebook values are sorted in ascending order
    • -
    • K_idx values are in \([0, C)\)
    • -
    • All floating-point inputs are float32; key indices are uint8
    • -
    • Performance is measured with B = 32, S = 32,768, D = 128, C = 16
    • -
    diff --git a/challenges/medium/83_turboquant_attention/challenge.py b/challenges/medium/83_turboquant_attention/challenge.py deleted file mode 100644 index f6188e6c..00000000 --- a/challenges/medium/83_turboquant_attention/challenge.py +++ /dev/null @@ -1,308 +0,0 @@ -import ctypes -from typing import Any, Dict, List - -import torch -from core.challenge_base import ChallengeBase - - -class Challenge(ChallengeBase): - def __init__(self): - super().__init__( - name="TurboQuant KV Cache Attention", - atol=1e-3, - rtol=1e-3, - num_gpus=1, - access_tier="free", - ) - - def reference_impl( - self, - Q: torch.Tensor, - K_idx: torch.Tensor, - Pi: torch.Tensor, - codebook: torch.Tensor, - scores: torch.Tensor, - B: int, - S: int, - D: int, - C: int, - ): - assert Q.shape == (B, D) - assert K_idx.shape == (S, D) - assert Pi.shape == (D, D) - assert codebook.shape == (C,) - assert scores.shape == (B, S) - assert Q.dtype == torch.float32 - assert K_idx.dtype == torch.uint8 - assert Pi.dtype == torch.float32 - assert codebook.dtype == torch.float32 - assert scores.dtype == torch.float32 - assert Q.device.type == "cuda" - assert K_idx.device.type == "cuda" - assert Pi.device.type == "cuda" - assert codebook.device.type == "cuda" - assert scores.device.type == "cuda" - - # Dequantize keys: lookup centroids then rotate back - K_centroids = codebook[K_idx.long()] # S x D - K_deq = K_centroids @ Pi # S x D - - # Compute attention scores - scores.copy_(Q @ K_deq.T) # B x S - - def get_solve_signature(self) -> Dict[str, tuple]: - return { - "Q": (ctypes.POINTER(ctypes.c_float), "in"), - "K_idx": (ctypes.POINTER(ctypes.c_uint8), "in"), - "Pi": (ctypes.POINTER(ctypes.c_float), "in"), - "codebook": (ctypes.POINTER(ctypes.c_float), "in"), - "scores": (ctypes.POINTER(ctypes.c_float), "out"), - "B": (ctypes.c_int, "in"), - "S": (ctypes.c_int, "in"), - "D": (ctypes.c_int, "in"), - "C": (ctypes.c_int, "in"), - } - - def _make_rotation(self, D): - G = torch.randn(D, D, device="cuda") - Q, 
_ = torch.linalg.qr(G) - return Q - - def _make_codebook(self, C, scale=1.0): - return torch.linspace(-scale, scale, C, device="cuda", dtype=torch.float32) - - def generate_example_test(self) -> Dict[str, Any]: - B, S, D, C = 2, 3, 2, 4 - Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]], device="cuda", dtype=torch.float32) - K_idx = torch.tensor([[0, 3], [1, 2], [3, 0]], device="cuda", dtype=torch.uint8) - Pi = torch.eye(D, device="cuda", dtype=torch.float32) - codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75], device="cuda", dtype=torch.float32) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - return { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - - def generate_functional_test(self) -> List[Dict[str, Any]]: - tests = [] - - # Edge case: single query, single key, D=1, C=2 - B, S, D, C = 1, 1, 1, 2 - Q = torch.tensor([[0.5]], device="cuda", dtype=torch.float32) - K_idx = torch.tensor([[1]], device="cuda", dtype=torch.uint8) - Pi = torch.eye(D, device="cuda", dtype=torch.float32) - codebook = self._make_codebook(C) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - tests.append( - { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - ) - - # Edge case: zeros query - B, S, D, C = 2, 3, 4, 4 - Q = torch.zeros(B, D, device="cuda", dtype=torch.float32) - K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) - Pi = self._make_rotation(D) - codebook = self._make_codebook(C) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - tests.append( - { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - ) - - # Edge case: negative query values - B, S, D, C = 2, 4, 4, 4 - Q = torch.tensor( - [[-0.5, -0.3, -0.8, -0.1], [-1.0, -0.5, -0.2, -0.9]], - device="cuda", - dtype=torch.float32, - ) - 
K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) - Pi = self._make_rotation(D) - codebook = self._make_codebook(C) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - tests.append( - { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - ) - - # Power-of-2: B=4, S=16, D=32, C=8 - B, S, D, C = 4, 16, 32, 8 - Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.5 - K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) - Pi = self._make_rotation(D) - codebook = self._make_codebook(C, scale=1.5) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - tests.append( - { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - ) - - # Power-of-2: B=8, S=64, D=64, C=16 - B, S, D, C = 8, 64, 64, 16 - Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3 - K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) - Pi = self._make_rotation(D) - codebook = self._make_codebook(C) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - tests.append( - { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - ) - - # Power-of-2: B=16, S=128, D=128, C=16 - B, S, D, C = 16, 128, 128, 16 - Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3 - K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) - Pi = self._make_rotation(D) - codebook = self._make_codebook(C) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - tests.append( - { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - ) - - # Non-power-of-2: B=3, S=30, D=50, C=8 - B, S, D, C = 3, 30, 50, 8 - Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.4 - K_idx = 
torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) - Pi = self._make_rotation(D) - codebook = self._make_codebook(C) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - tests.append( - { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - ) - - # Non-power-of-2: B=7, S=255, D=100, C=16 - B, S, D, C = 7, 255, 100, 16 - Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.6 - K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) - Pi = self._make_rotation(D) - codebook = self._make_codebook(C, scale=1.5) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - tests.append( - { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - ) - - # Realistic: B=16, S=4096, D=128, C=16 - B, S, D, C = 16, 4096, 128, 16 - Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3 - K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) - Pi = self._make_rotation(D) - codebook = self._make_codebook(C) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - tests.append( - { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } - ) - - return tests - - def generate_performance_test(self) -> Dict[str, Any]: - B, S, D, C = 32, 32768, 128, 16 - Q = torch.randn(B, D, device="cuda", dtype=torch.float32) * 0.3 - K_idx = torch.randint(0, C, (S, D), device="cuda", dtype=torch.uint8) - Pi = self._make_rotation(D) - codebook = self._make_codebook(C) - scores = torch.zeros(B, S, device="cuda", dtype=torch.float32) - return { - "Q": Q, - "K_idx": K_idx, - "Pi": Pi, - "codebook": codebook, - "scores": scores, - "B": B, - "S": S, - "D": D, - "C": C, - } diff --git a/challenges/medium/83_turboquant_attention/starter/starter.cu 
b/challenges/medium/83_turboquant_attention/starter/starter.cu deleted file mode 100644 index be318f3c..00000000 --- a/challenges/medium/83_turboquant_attention/starter/starter.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include - -// Q, K_idx, Pi, codebook, scores are device pointers -extern "C" void solve(const float* Q, const unsigned char* K_idx, const float* Pi, - const float* codebook, float* scores, int B, int S, int D, int C) {} diff --git a/challenges/medium/83_turboquant_attention/starter/starter.mojo b/challenges/medium/83_turboquant_attention/starter/starter.mojo deleted file mode 100644 index 353883cd..00000000 --- a/challenges/medium/83_turboquant_attention/starter/starter.mojo +++ /dev/null @@ -1,9 +0,0 @@ -from gpu.host import DeviceContext -from gpu.id import block_dim, block_idx, thread_idx -from memory import UnsafePointer -from math import ceildiv - -# Q, K_idx, Pi, codebook, scores are device pointers -@export -def solve(Q: UnsafePointer[Float32], K_idx: UnsafePointer[UInt8], Pi: UnsafePointer[Float32], codebook: UnsafePointer[Float32], scores: UnsafePointer[Float32], B: Int32, S: Int32, D: Int32, C: Int32): - pass diff --git a/challenges/medium/83_turboquant_attention/starter/starter.triton.py b/challenges/medium/83_turboquant_attention/starter/starter.triton.py deleted file mode 100644 index f21c5eea..00000000 --- a/challenges/medium/83_turboquant_attention/starter/starter.triton.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch -import triton -import triton.language as tl - - -# q, k_idx, pi, codebook, scores are tensors on the GPU -def solve( - q: torch.Tensor, - k_idx: torch.Tensor, - pi: torch.Tensor, - codebook: torch.Tensor, - scores: torch.Tensor, - B: int, - S: int, - D: int, - C: int, -): - pass From bad839a975d39191a8023c269ec37e032a078159 Mon Sep 17 00:00:00 2001 From: James Song Date: Fri, 27 Mar 2026 18:22:30 -0400 Subject: [PATCH 4/5] fix lint --- challenges/hard/83_turboquant_attention/challenge.py | 1 - 1 file changed, 1 deletion(-) diff 
--git a/challenges/hard/83_turboquant_attention/challenge.py b/challenges/hard/83_turboquant_attention/challenge.py index e6723af1..6a7e07b8 100644 --- a/challenges/hard/83_turboquant_attention/challenge.py +++ b/challenges/hard/83_turboquant_attention/challenge.py @@ -97,7 +97,6 @@ def _make_codebook(self, C, scale=1.0): def _encode_keys(self, K, Pi, S_mat, codebook): """Simulate TurboQuant_prod encoding: rotate, quantize, compute QJL on residual.""" S, D = K.shape - C = codebook.shape[0] # Stage 1: MSE encoding Y = K @ Pi.T # rotate into quantization space From 8b125d89c26bb305a50828153aec54bfb2d117ac Mon Sep 17 00:00:00 2001 From: Kunal Mansukhani Date: Fri, 27 Mar 2026 23:34:03 -0400 Subject: [PATCH 5/5] Clean up spec --- .../83_turboquant_attention/challenge.html | 140 ++---------------- 1 file changed, 14 insertions(+), 126 deletions(-) diff --git a/challenges/hard/83_turboquant_attention/challenge.html b/challenges/hard/83_turboquant_attention/challenge.html index ff34a028..e6397aef 100644 --- a/challenges/hard/83_turboquant_attention/challenge.html +++ b/challenges/hard/83_turboquant_attention/challenge.html @@ -6,122 +6,8 @@ compressed keys and compute dot-product attention scores against full-precision queries.

    - - - - - TurboQuant Dequantization Pipeline (per key vector) - - - - - Stage 1: MSE - - - - K_idx - [S, D] uint8 - - - - - - - codebook[ K_idx ] - centroid lookup - - - - - - - × Π - rotate back - - - = - - - - K̃_mse - [S, D] float32 - - - - - Stage 2: QJL residual - - - - σ - [S,D] ±1 - - - - - - σ · M - project - - - - - - × √(π/2)/D × γ - scale by norm - - - = - - - - K̃_res - [S, D] float32 - - - - - Combine + Score - - - - K̃_mse - - + - - - K̃_res - - = - - - - - - then: - - - Q - - · - - - K̃ᵀ - - = - - - scores - - - Π = orthogonal rotation [D×D] - M = Gaussian projection [D×D] - σ = sign bits ±1, γ = ‖residual‖₂ - -

    - Background — how the keys were compressed (already done for you, not part of the challenge): + Background - how the keys were compressed (already done for you, not part of the challenge):

    1. Rotate: multiply key by orthogonal matrix \(\Pi\): \(\;y = \Pi \cdot K\). This makes each @@ -131,21 +17,21 @@
    2. Residual correction: MSE quantization loses information. Compute the residual \(r = K - \tilde{K}_\text{mse}\), then store:
        -
      • \(\sigma = \text{sign}(M \cdot r) \in \{-1,+1\}^D\) — direction (int8)
      • -
      • \(\gamma = \|r\|_2\) — magnitude (float32 scalar per key)
      • +
      • \(\sigma = \text{sign}(S_\text{mat} \cdot r) \in \{-1,+1\}^D\) - direction (int8)
      • +
      • \(\gamma = \|r\|_2\) - magnitude (float32 scalar per key)
      - where \(M \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix (S_mat in code). + where \(S_\text{mat} \in \mathbb{R}^{D \times D}\) is a random Gaussian projection matrix.

    - What you compute — dequantize and score: + What you compute - dequantize and score:

    1. MSE dequantize: look up centroids, undo the rotation: \[\tilde{K}_\text{mse} = \text{codebook}[K_\text{idx}] \cdot \Pi\]
    2. Residual dequantize: reconstruct the residual correction: - \[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot M\] + \[\tilde{K}_\text{res} = \frac{\sqrt{\pi/2}}{D} \cdot \gamma \cdot \sigma \cdot S_\text{mat}\] The \(\sqrt{\pi/2}/D\) constant corrects for the distortion introduced by taking signs.
    3. Combine: \(\tilde{K} = \tilde{K}_\text{mse} + \tilde{K}_\text{res}\)
    4. @@ -166,19 +52,20 @@

      Implementation Requirements

      Example

      - Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(M = I\), \(\gamma = \mathbf{0}\) (residual correction disabled): + Input: \(B=2,\; S=3,\; D=2,\; C=4\), with \(\Pi = I\), \(S_\text{mat} = I\), \(\gamma = \mathbf{0}\) (residual correction disabled), + \(\sigma = \mathbf{1}\) (all +1):

      - \(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\), \quad - \(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\), \quad + \(Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\), + \(K_\text{idx} = \begin{bmatrix} 0 & 3 \\ 1 & 2 \\ 3 & 0 \end{bmatrix}\), codebook \(= [-0.75,\; -0.25,\; 0.25,\; 0.75]\)

      - Step 1 — MSE lookup and rotate back (\(\Pi = I\)): + Step 1 - MSE lookup and rotate back (\(\Pi = I\)): \[ \tilde{K}_\text{mse} = \begin{bmatrix} -0.75 & 0.75 \\ -0.25 & 0.25 \\ 0.75 & -0.75 \end{bmatrix} \] - Step 2 — Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\). + Step 2 - Residual correction is zero (\(\gamma = 0\)), so \(\tilde{K} = \tilde{K}_\text{mse}\).

      Output: @@ -194,7 +81,8 @@

      Constraints

    5. 1 ≤ D ≤ 256
    6. 2 ≤ C ≤ 256
    7. \(\Pi\) is orthogonal (\(\Pi^T \Pi = I\))
    8. -
    9. \(M\) (S_mat) has i.i.d. \(\mathcal{N}(0,1)\) entries
    10. +
    11. S_mat has i.i.d. \(\mathcal{N}(0,1)\) entries
    12. +
    13. gamma has shape \([S]\) (one \(\ell_2\) norm per key vector, float32)
    14. qjl_signs (\(\sigma\)) values are in \(\{-1, +1\}\) (int8)
    15. K_idx values are in \([0, C)\) (uint8)
    16. All floating-point inputs are float32