From 87064b3a5efbbbd4baad03f826b52aae65c09554 Mon Sep 17 00:00:00 2001 From: Zhang Jian Date: Mon, 22 Jun 2026 15:17:46 +0000 Subject: [PATCH 1/2] add unit tests to ensure kv cache consistency Signed-off-by: Zhang Jian --- rl_engine/testing/kv_consistency.py | 417 ++++++++++++++++++++++++++++ tests/test_kv_cache_consistency.py | 184 ++++++++++++ 2 files changed, 601 insertions(+) create mode 100644 rl_engine/testing/kv_consistency.py create mode 100644 tests/test_kv_cache_consistency.py diff --git a/rl_engine/testing/kv_consistency.py b/rl_engine/testing/kv_consistency.py new file mode 100644 index 0000000..9ad078a --- /dev/null +++ b/rl_engine/testing/kv_consistency.py @@ -0,0 +1,417 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +""" +Prefill / decode KV-cache path-consistency harness (WS1, issue #152). + +Rollout generates token-by-token through a *decode* path (one query against the +cached KV of the preceding tokens). Training re-scores the same sequence through +a *prefill* path (the whole sequence in one forward). If the two paths reduce in +a different order the same token gets a different logprob in rollout vs training +-- a high-impact rollout-vs-training drift source. + +This module defines a single fixed reduction-order contract and drives both +paths through it, so prefill and decode are bitwise-identical by construction. +A naive batched SDPA path is also provided to demonstrate the ~1e-7 drift that +appears when the reduction order is *not* shared -- i.e. the bug this guards +against. + +The core contract: attention for query position ``t`` is always computed by +:func:`attend_single_query` over keys ``0..t`` in ascending index order. Prefill +loops that op over every position; decode calls the same op once per generated +token, reading the keys back from a :class:`KVCache`. Same op + same inputs + +same order => bitwise-equal output. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + + +@dataclass(frozen=True) +class AttentionSpec: + """Shapes for a single attention layer (FlashAttention head layout).""" + + num_heads: int + num_kv_heads: int + head_dim: int + causal: bool = True + + def __post_init__(self) -> None: + if self.num_heads <= 0 or self.num_kv_heads <= 0 or self.head_dim <= 0: + raise ValueError("num_heads, num_kv_heads, head_dim must be positive") + if self.num_heads % self.num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads (GQA/MQA)") + + @property + def scale(self) -> float: + return 1.0 / math.sqrt(self.head_dim) + + @property + def gqa_group(self) -> int: + return self.num_heads // self.num_kv_heads + + +def expand_kv_heads(x: torch.Tensor, spec: AttentionSpec) -> torch.Tensor: + """Expand ``[..., Hkv, D]`` to ``[..., H, D]`` for GQA/MQA (ascending repeat).""" + + if spec.num_kv_heads == spec.num_heads: + return x + return x.repeat_interleave(spec.gqa_group, dim=-2) + + +def attend_single_query( + q_t: torch.Tensor, + k_ctx: torch.Tensor, + v_ctx: torch.Tensor, + *, + scale: float, + spec: AttentionSpec, + key_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Canonical fixed-order attention of one query against a key/value context. + + This is the single op shared by prefill and decode -- the reduction-order + contract lives here and nowhere else. + + Args: + q_t: ``[B, H, D]`` query for one position. + k_ctx, v_ctx: ``[B, T, Hkv, D]`` cached context (keys ``0..T-1``). + scale: softmax scale (``1/sqrt(D)``). + spec: attention shapes (drives GQA expansion). + key_mask: optional ``[B, T]`` bool, ``True`` = valid key. Invalid keys + are scored ``-inf`` (their post-softmax weight is exactly 0, so they + do not perturb the ascending-order sum -- adding ``0.0`` is an IEEE + identity regardless of reduction grouping). + + Returns: + ``[B, H, D]`` attention output in fp32. + """ + + q = q_t.float() + k = expand_kv_heads(k_ctx.float(), spec) # [B, T, H, D] + v = expand_kv_heads(v_ctx.float(), spec) + + # scores[b, h, t] = scale * sum_d q[b,h,d] * k[b,t,h,d] + scores = torch.einsum("bhd,bthd->bht", q, k) * scale + if key_mask is not None: + invalid = ~key_mask.to(device=scores.device, dtype=torch.bool) # [B, T] + scores = scores.masked_fill(invalid.unsqueeze(1), float("-inf")) + + weights = torch.softmax(scores, dim=-1) # over T, ascending key order + # Guard fully-masked rows (no valid key) -> zero output instead of NaN. + weights = torch.nan_to_num(weights, nan=0.0) + out = torch.einsum("bht,bthd->bhd", weights, v) # [B, H, D] + return out + + +class KVCache: + """Pre-allocated paged-free KV buffer; writer dtype is the stored dtype.""" + + def __init__( + self, + batch: int, + spec: AttentionSpec, + max_len: int, + *, + dtype: torch.dtype, + device: torch.device, + ): + self._spec = spec + self._len = 0 + self.key = torch.zeros( + (batch, max_len, spec.num_kv_heads, spec.head_dim), dtype=dtype, device=device + ) + self.value = torch.zeros_like(self.key) + + @property + def length(self) -> int: + return self._len + + def append(self, k_t: torch.Tensor, v_t: torch.Tensor) -> None: + """Store one timestep ``[B, Hkv, D]`` (cast to the cache's stored dtype).""" + + t = self._len + self.key[:, t] = k_t.to(self.key.dtype) + self.value[:, t] = v_t.to(self.value.dtype) + self._len += 1 + + def context(self) -> tuple[torch.Tensor, torch.Tensor]: + return self.key[:, : self._len], self.value[:, : self._len] + + +def fixed_order_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + spec: AttentionSpec, + key_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Prefill via the contract: loop every query position through the single-query + op against keys ``0..t``. This is the golden reduction order. + + Args: + q: ``[B, S, H, D]``; k, v: ``[B, S, Hkv, D]``. + key_mask: optional ``[B, S]`` bool of valid (non-pad) positions. + + Returns: + ``[B, S, H, D]`` fp32 attention output. + """ + + _check_qkv(q, k, v, spec) + b, s = q.shape[0], q.shape[1] + out = torch.zeros((b, s, spec.num_heads, spec.head_dim), dtype=torch.float32, device=q.device) + for t in range(s): + ctx_end = t + 1 if spec.causal else s + mask_t = key_mask[:, :ctx_end] if key_mask is not None else None + out[:, t] = attend_single_query( + q[:, t], + k[:, :ctx_end], + v[:, :ctx_end], + scale=spec.scale, + spec=spec, + key_mask=mask_t, + ) + return out + + +def replay_decode( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + spec: AttentionSpec, + key_mask: Optional[torch.Tensor] = None, + kv_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Decode via the contract: step token-by-token, append each ``k/v`` to a + :class:`KVCache`, attend the single query against the cache. + + With ``kv_dtype`` equal to ``q.dtype`` (default) the stored context is a + bitwise copy of the prefill keys, so the output equals + :func:`fixed_order_attention` bitwise. ``kv_dtype`` lets PR3 exercise + lower-precision stored KV (writer-vs-reader drift). + """ + + _check_qkv(q, k, v, spec) + if not spec.causal: + raise ValueError("decode replay requires a causal spec") + b, s = q.shape[0], q.shape[1] + cache = KVCache( + b, spec, s, dtype=kv_dtype or q.dtype, device=q.device + ) + out = torch.zeros((b, s, spec.num_heads, spec.head_dim), dtype=torch.float32, device=q.device) + for t in range(s): + cache.append(k[:, t], v[:, t]) + k_ctx, v_ctx = cache.context() + mask_t = key_mask[:, : t + 1] if key_mask is not None else None + out[:, t] = attend_single_query( + q[:, t], k_ctx, v_ctx, scale=spec.scale, spec=spec, key_mask=mask_t + ) + return out + + +# --------------------------------------------------------------------------- # +# Parity assertion helper (aligns with the #108 tolerance conventions). +# --------------------------------------------------------------------------- # + + +@dataclass(frozen=True) +class ParityReport: + bitwise: bool + max_abs_error: float + mean_abs_error: float + + +def parity_report(candidate: torch.Tensor, reference: torch.Tensor) -> ParityReport: + if candidate.shape != reference.shape: + raise ValueError( + f"shape mismatch: {tuple(candidate.shape)} vs {tuple(reference.shape)}" + ) + bitwise = bool(torch.equal(candidate, reference)) + diff = (candidate.float() - reference.float()).abs() + return ParityReport( + bitwise=bitwise, + max_abs_error=float(diff.max().item()) if diff.numel() else 0.0, + mean_abs_error=float(diff.mean().item()) if diff.numel() else 0.0, + ) + + +def assert_path_parity( + candidate: torch.Tensor, + reference: torch.Tensor, + *, + require_bitwise: bool = False, + atol: float = 1e-5, + rtol: float = 1e-5, + msg: str = "", +) -> ParityReport: + """Assert two paths agree; bitwise when required, else within tolerance.""" + + report = parity_report(candidate, reference) + prefix = f"{msg}: " if msg else "" + if require_bitwise: + assert report.bitwise, ( + f"{prefix}expected bitwise-equal paths but max_abs_error=" + f"{report.max_abs_error:.3e} (mean={report.mean_abs_error:.3e})" + ) + else: + torch.testing.assert_close( + candidate.float(), + reference.float(), + atol=atol, + rtol=rtol, + msg=lambda m: f"{prefix}{m}", + ) + return report + + +# --------------------------------------------------------------------------- # +# Tiny end-to-end causal LM for logprob-level (generate vs re-score) checks. +# --------------------------------------------------------------------------- # + + +class TinyCausalLM(nn.Module): + """ + Minimal single-layer causal LM sharing one attention contract across prefill + and decode. Deterministic init; fp32 throughout. Not a real model -- just + enough of a forward chain (embed -> qkv -> attention -> o_proj -> lm_head) + to produce logits/logprobs for path-parity tests. + """ + + def __init__( + self, + vocab_size: int, + d_model: int, + spec: AttentionSpec, + *, + seed: int = 0, + ): + super().__init__() + self.spec = spec + self.vocab_size = vocab_size + qdim = spec.num_heads * spec.head_dim + kvdim = spec.num_kv_heads * spec.head_dim + self.embed = nn.Embedding(vocab_size, d_model) + self.q_proj = nn.Linear(d_model, qdim, bias=False) + self.k_proj = nn.Linear(d_model, kvdim, bias=False) + self.v_proj = nn.Linear(d_model, kvdim, bias=False) + self.o_proj = nn.Linear(qdim, d_model, bias=False) + self.lm_head = nn.Linear(d_model, vocab_size, bias=False) + + gen = torch.Generator().manual_seed(seed) + for param in self.parameters(): + with torch.no_grad(): + param.copy_(torch.empty_like(param).normal_(0.0, 0.02, generator=gen)) + self.eval() + + def _project( + self, input_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + b, s = input_ids.shape + spec = self.spec + h = self.embed(input_ids) + q = self.q_proj(h).view(b, s, spec.num_heads, spec.head_dim) + k = self.k_proj(h).view(b, s, spec.num_kv_heads, spec.head_dim) + v = self.v_proj(h).view(b, s, spec.num_kv_heads, spec.head_dim) + return q, k, v + + def _to_logits(self, attn_out: torch.Tensor) -> torch.Tensor: + b, s = attn_out.shape[0], attn_out.shape[1] + merged = attn_out.reshape(b, s, self.spec.num_heads * self.spec.head_dim) + return self.lm_head(self.o_proj(merged)) + + @torch.no_grad() + def prefill_logits( + self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + q, k, v = self._project(input_ids) + attn = fixed_order_attention(q, k, v, spec=self.spec, key_mask=attention_mask) + return self._to_logits(attn) + + @torch.no_grad() + def decode_logits( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + *, + kv_dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + q, k, v = self._project(input_ids) + attn = replay_decode( + q, k, v, spec=self.spec, key_mask=attention_mask, kv_dtype=kv_dtype + ) + return self._to_logits(attn) + + @torch.no_grad() + def generate( + self, prompt_ids: torch.Tensor, max_new_tokens: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Greedy decode against a live KV cache. + + Returns ``(full_ids, step_logprobs)`` where ``full_ids`` is + ``[B, prompt + max_new_tokens]`` and ``step_logprobs[:, i]`` is the + logprob the decode path assigned to the token chosen at generation + step ``i`` (token at full position ``prompt + i``). + """ + + spec = self.spec + b, prompt_len = prompt_ids.shape + total = prompt_len + max_new_tokens + cache = KVCache(b, spec, total, dtype=torch.float32, device=prompt_ids.device) + + def step(token_col: torch.Tensor) -> torch.Tensor: + # token_col: [B] -> logits [B, V] for the next position. + h = self.embed(token_col) # [B, d_model] + q = self.q_proj(h).view(b, spec.num_heads, spec.head_dim) + k = self.k_proj(h).view(b, spec.num_kv_heads, spec.head_dim) + vv = self.v_proj(h).view(b, spec.num_kv_heads, spec.head_dim) + cache.append(k, vv) + k_ctx, v_ctx = cache.context() + attn = attend_single_query(q, k_ctx, v_ctx, scale=spec.scale, spec=spec) + merged = attn.reshape(b, spec.num_heads * spec.head_dim) + return self.lm_head(self.o_proj(merged)) + + # Consume the prompt; keep the logits produced at the final prompt token. + logits = None + for t in range(prompt_len): + logits = step(prompt_ids[:, t]) + + gen_ids = [] + step_logprobs = [] + for _ in range(max_new_tokens): + logprobs = torch.log_softmax(logits.float(), dim=-1) + next_token = torch.argmax(logprobs, dim=-1) # [B] + step_logprobs.append(logprobs.gather(1, next_token.unsqueeze(1)).squeeze(1)) + gen_ids.append(next_token) + logits = step(next_token) + + generated = torch.stack(gen_ids, dim=1) # [B, max_new_tokens] + full_ids = torch.cat([prompt_ids, generated], dim=1) + return full_ids, torch.stack(step_logprobs, dim=1) + + +# --------------------------------------------------------------------------- # +# internal +# --------------------------------------------------------------------------- # + + +def _check_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, spec: AttentionSpec) -> None: + if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: + raise ValueError("q, k, v must be [B, S, H, D]") + if q.shape[2] != spec.num_heads or k.shape[2] != spec.num_kv_heads: + raise ValueError("head counts must match the AttentionSpec") + if k.shape != v.shape: + raise ValueError("k and v must share shape") + if q.shape[0] != k.shape[0] or q.shape[1] != k.shape[1]: + raise ValueError("q and k/v must share batch and sequence length") diff --git a/tests/test_kv_cache_consistency.py b/tests/test_kv_cache_consistency.py new file mode 100644 index 0000000..3eac79b --- /dev/null +++ b/tests/test_kv_cache_consistency.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +""" +Prefill / decode KV-cache path-consistency tests (WS1, issue #152). + +One representative test per claim across the five planned PRs: + + 1. Full-prefill == chunked-prefill, and naive batched SDPA is NOT bitwise. + 2. Decode (one query vs N cached KV) == prefill (incl. padding, GQA). + 3. Stored-KV: matched dtype is bitwise; low precision stays bounded. + 4. "Generate then re-score" equivalence. + 5. CI smoke across short / long / varlen / padded. + +Contract: prefill and decode share one fixed reduction order, so matched dtype +is bitwise-equal. CPU/fp32 -- no GPU/kernel build required. +""" + +from __future__ import annotations + +import pytest +import torch + +import torch.nn.functional as F + +from rl_engine.testing.kv_consistency import ( + AttentionSpec, + KVCache, + TinyCausalLM, + assert_path_parity, + attend_single_query, + expand_kv_heads, + fixed_order_attention, + parity_report, + replay_decode, +) +from rl_engine.testing.reference_ops import selected_logprobs_reference + +torch.manual_seed(0) + +# One GQA spec is the representative case: it exercises the GQA head expansion +# (num_heads > num_kv_heads) that dense/MQA are degenerate cases of. +SPEC = AttentionSpec(num_heads=8, num_kv_heads=2, head_dim=16) + + +def _make_qkv(batch, seqlen, spec, *, seed=0): + gen = torch.Generator().manual_seed(seed) + q = torch.randn(batch, seqlen, spec.num_heads, spec.head_dim, generator=gen) + k = torch.randn(batch, seqlen, spec.num_kv_heads, spec.head_dim, generator=gen) + v = torch.randn(batch, seqlen, spec.num_kv_heads, spec.head_dim, generator=gen) + return q, k, v + + +# --------------------------------------------------------------------------- # +# PR1 -- Full-prefill vs chunked-prefill consistency. +# --------------------------------------------------------------------------- # + + +def test_full_vs_chunked_prefill_bitwise(): + """Carrying KV across chunk boundaries must not change the reduction.""" + q, k, v = _make_qkv(2, 24, SPEC) + full = fixed_order_attention(q, k, v, spec=SPEC) + + b, s = q.shape[0], q.shape[1] + cache = KVCache(b, SPEC, s, dtype=q.dtype, device=q.device) + chunked = torch.zeros_like(full) + for start in range(0, s, 8): # chunk size 8 + for t in range(start, min(start + 8, s)): + cache.append(k[:, t], v[:, t]) + kc, vc = cache.context() + chunked[:, t] = attend_single_query(q[:, t], kc, vc, scale=SPEC.scale, spec=SPEC) + assert_path_parity(chunked, full, require_bitwise=True) + + +def _naive_batched_sdpa(q, k, v, spec): + """A naive whole-sequence SDPA -- the path whose reduction order we must NOT use.""" + qh = q.transpose(1, 2).float() # [B, H, S, D] + kh = expand_kv_heads(k, spec).transpose(1, 2).float() + vh = expand_kv_heads(v, spec).transpose(1, 2).float() + out = F.scaled_dot_product_attention(qh, kh, vh, is_causal=spec.causal, scale=spec.scale) + return out.transpose(1, 2) + + +def test_naive_sdpa_only_matches_within_tolerance(): + """The naive batched path drifts from the contract -- the bug #152 targets.""" + q, k, v = _make_qkv(2, 48, SPEC) + contract = fixed_order_attention(q, k, v, spec=SPEC) + report = parity_report(_naive_batched_sdpa(q, k, v, SPEC), contract) + assert not report.bitwise # close, but NOT bitwise: reduction-order drift + assert report.max_abs_error < 1e-4 + + +# --------------------------------------------------------------------------- # +# PR2 -- Decode (one query vs N cached KV) vs prefill. +# --------------------------------------------------------------------------- # + + +def test_decode_matches_prefill_bitwise(): + q, k, v = _make_qkv(4, 32, SPEC) # batch > 1 + prefill = fixed_order_attention(q, k, v, spec=SPEC) + decode = replay_decode(q, k, v, spec=SPEC) + assert_path_parity(decode, prefill, require_bitwise=True) + + +@pytest.mark.parametrize("pad_side", ["left", "right"]) +def test_decode_matches_prefill_with_padding(pad_side): + batch, seqlen = 3, 20 + q, k, v = _make_qkv(batch, seqlen, SPEC) + lengths = [20, 14, 8] + mask = torch.zeros(batch, seqlen, dtype=torch.bool) + for b, L in enumerate(lengths): + if pad_side == "right": + mask[b, :L] = True + else: + mask[b, seqlen - L :] = True + prefill = fixed_order_attention(q, k, v, spec=SPEC, key_mask=mask) + decode = replay_decode(q, k, v, spec=SPEC, key_mask=mask) + assert_path_parity(decode, prefill, require_bitwise=True, msg=f"pad={pad_side}") + + +# --------------------------------------------------------------------------- # +# PR3 -- Stored-KV layout/dtype: no writer-vs-reader precision drift. +# --------------------------------------------------------------------------- # + + +def test_stored_kv_matched_dtype_is_bitwise(): + """fp32 writer + fp32 reader -> the cache itself adds zero drift.""" + q, k, v = _make_qkv(2, 24, SPEC) + prefill = fixed_order_attention(q, k, v, spec=SPEC) + decode = replay_decode(q, k, v, spec=SPEC, kv_dtype=torch.float32) + assert_path_parity(decode, prefill, require_bitwise=True) + + +def test_stored_kv_low_precision_within_tolerance(): + """Low-precision storage is the ONLY drift source, and it is bounded.""" + q, k, v = _make_qkv(2, 32, SPEC) + prefill = fixed_order_attention(q, k, v, spec=SPEC) + decode = replay_decode(q, k, v, spec=SPEC, kv_dtype=torch.float16) + assert_path_parity(decode, prefill, atol=5e-3, rtol=5e-3) + + +# --------------------------------------------------------------------------- # +# PR4 -- "Generate then re-score" equivalence (full forward chain). +# --------------------------------------------------------------------------- # + + +def test_generate_then_rescore_equivalence(): + model = TinyCausalLM(vocab_size=64, d_model=48, spec=SPEC, seed=1) + prompt = torch.randint(0, 64, (2, 5)) + + full_ids, gen_step_logprobs = model.generate(prompt, max_new_tokens=7) + + # Re-score the produced sequence through the prefill (training) path. + prefill_logits = model.prefill_logits(full_ids) + rescored = selected_logprobs_reference(prefill_logits[:, :-1], full_ids[:, 1:]) + rescored_gen = rescored[:, prompt.shape[1] - 1 :] # generated positions only + assert_path_parity(rescored_gen, gen_step_logprobs, require_bitwise=True) + + +# --------------------------------------------------------------------------- # +# PR5 -- CI smoke test (short / long / varlen / padded). +# --------------------------------------------------------------------------- # + + +@pytest.mark.parametrize( + ("seqlen", "lengths"), + [ + pytest.param(4, None, id="short"), + pytest.param(256, None, id="long"), + pytest.param(32, [32, 17, 5, 1], id="varlen"), + pytest.param(24, [24, 24, 10, 3], id="padded"), + ], +) +def test_decode_smoke(seqlen, lengths): + batch = 4 if lengths else 2 + q, k, v = _make_qkv(batch, seqlen, SPEC, seed=7) + mask = None + if lengths: + mask = torch.zeros(batch, seqlen, dtype=torch.bool) + for b, L in enumerate(lengths): + mask[b, :L] = True + prefill = fixed_order_attention(q, k, v, spec=SPEC, key_mask=mask) + decode = replay_decode(q, k, v, spec=SPEC, key_mask=mask) + assert_path_parity(decode, prefill, require_bitwise=True) From 5d89687e81d1790e5e41cf2109b31e8f1f150120 Mon Sep 17 00:00:00 2001 From: Zhang Jian Date: Mon, 22 Jun 2026 16:54:40 +0000 Subject: [PATCH 2/2] wip: test cuda kernel performance Signed-off-by: Zhang Jian --- rl_engine/testing/kv_consistency.py | 8 +- tests/test_kv_cache_consistency.py | 41 ++++--- tests/test_kv_cache_consistency_cuda.py | 149 ++++++++++++++++++++++++ 3 files changed, 173 insertions(+), 25 deletions(-) create mode 100644 tests/test_kv_cache_consistency_cuda.py diff --git a/rl_engine/testing/kv_consistency.py b/rl_engine/testing/kv_consistency.py index 9ad078a..d122280 100644 --- a/rl_engine/testing/kv_consistency.py +++ b/rl_engine/testing/kv_consistency.py @@ -2,7 +2,7 @@ # Copyright (c) 2026 RL-Kernel Contributors """ -Prefill / decode KV-cache path-consistency harness (WS1, issue #152). +Prefill / decode KV-cache path-consistency harness. Rollout generates token-by-token through a *decode* path (one query against the cached KV of the preceding tokens). Training re-scores the same sequence through @@ -198,8 +198,8 @@ def replay_decode( With ``kv_dtype`` equal to ``q.dtype`` (default) the stored context is a bitwise copy of the prefill keys, so the output equals - :func:`fixed_order_attention` bitwise. ``kv_dtype`` lets PR3 exercise - lower-precision stored KV (writer-vs-reader drift). + :func:`fixed_order_attention` bitwise. A lower-precision ``kv_dtype`` + exercises writer-vs-reader drift introduced solely by the stored KV. """ _check_qkv(q, k, v, spec) @@ -221,7 +221,7 @@ def replay_decode( # --------------------------------------------------------------------------- # -# Parity assertion helper (aligns with the #108 tolerance conventions). +# Parity assertion helper (bitwise where required, else within tolerance). # --------------------------------------------------------------------------- # diff --git a/tests/test_kv_cache_consistency.py b/tests/test_kv_cache_consistency.py index 3eac79b..7679aab 100644 --- a/tests/test_kv_cache_consistency.py +++ b/tests/test_kv_cache_consistency.py @@ -2,25 +2,20 @@ # Copyright (c) 2026 RL-Kernel Contributors """ -Prefill / decode KV-cache path-consistency tests (WS1, issue #152). - -One representative test per claim across the five planned PRs: - - 1. Full-prefill == chunked-prefill, and naive batched SDPA is NOT bitwise. - 2. Decode (one query vs N cached KV) == prefill (incl. padding, GQA). - 3. Stored-KV: matched dtype is bitwise; low precision stays bounded. - 4. "Generate then re-score" equivalence. - 5. CI smoke across short / long / varlen / padded. - -Contract: prefill and decode share one fixed reduction order, so matched dtype -is bitwise-equal. CPU/fp32 -- no GPU/kernel build required. +Prefill / decode KV-cache path-consistency tests. + +Prefill (whole-sequence re-scoring) and decode (one query against the cached KV +of preceding tokens) share a single fixed reduction order, so for matched dtype +they produce bitwise-identical logits and logprobs. These tests assert that +equivalence across chunked prefill, padded and variable-length sequences, stored +KV dtypes, and an end-to-end generate-then-rescore round trip. They run on CPU in +fp32 and require no GPU or compiled kernels. """ from __future__ import annotations import pytest import torch - import torch.nn.functional as F from rl_engine.testing.kv_consistency import ( @@ -52,7 +47,7 @@ def _make_qkv(batch, seqlen, spec, *, seed=0): # --------------------------------------------------------------------------- # -# PR1 -- Full-prefill vs chunked-prefill consistency. +# Prefill reduction order # --------------------------------------------------------------------------- # @@ -81,17 +76,21 @@ def _naive_batched_sdpa(q, k, v, spec): return out.transpose(1, 2) -def test_naive_sdpa_only_matches_within_tolerance(): - """The naive batched path drifts from the contract -- the bug #152 targets.""" +def test_naive_sdpa_diverges_from_fixed_order(): + """A whole-sequence SDPA reduces in a different order and is not bitwise-equal. + + This guards the contract: it confirms the bitwise guarantee is meaningful + rather than vacuously true, while staying within close numerical tolerance. + """ q, k, v = _make_qkv(2, 48, SPEC) contract = fixed_order_attention(q, k, v, spec=SPEC) report = parity_report(_naive_batched_sdpa(q, k, v, SPEC), contract) - assert not report.bitwise # close, but NOT bitwise: reduction-order drift + assert not report.bitwise assert report.max_abs_error < 1e-4 # --------------------------------------------------------------------------- # -# PR2 -- Decode (one query vs N cached KV) vs prefill. +# Decode vs prefill parity # --------------------------------------------------------------------------- # @@ -119,7 +118,7 @@ def test_decode_matches_prefill_with_padding(pad_side): # --------------------------------------------------------------------------- # -# PR3 -- Stored-KV layout/dtype: no writer-vs-reader precision drift. +# Stored-KV dtype # --------------------------------------------------------------------------- # @@ -140,7 +139,7 @@ def test_stored_kv_low_precision_within_tolerance(): # --------------------------------------------------------------------------- # -# PR4 -- "Generate then re-score" equivalence (full forward chain). +# Generate then re-score # --------------------------------------------------------------------------- # @@ -158,7 +157,7 @@ def test_generate_then_rescore_equivalence(): # --------------------------------------------------------------------------- # -# PR5 -- CI smoke test (short / long / varlen / padded). +# Decode smoke coverage # --------------------------------------------------------------------------- # diff --git a/tests/test_kv_cache_consistency_cuda.py b/tests/test_kv_cache_consistency_cuda.py new file mode 100644 index 0000000..840db96 --- /dev/null +++ b/tests/test_kv_cache_consistency_cuda.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +""" +Prefill / decode parity for the real CUDA attention kernel. + +The CPU suite (``test_kv_cache_consistency.py``) validates the fixed reduction-order +*contract* with a reference implementation. This suite exercises the production +FlashAttention kernel itself: it checks that the value the kernel produces for query +position ``t`` is the same (within dtype tolerance) whether ``t`` is computed as part +of a full causal prefill or as a single-query decode step against the cached prefix. + +The repository exposes only a full-sequence attention kernel; a decode step is driven +by calling that kernel with ``seqlen_q == 1`` against the keys/values seen so far. Flash +kernels reduce in a hardware-dependent order, so parity here is tolerance-based rather +than bitwise. The whole module skips cleanly when no CUDA FlashAttention kernel is +available. +""" + +from __future__ import annotations + +import pytest +import torch + +AVAILABILITY_ERRORS = (ImportError, ModuleNotFoundError, OSError, RuntimeError) + +DTYPE_CASES = [ + pytest.param(torch.float16, 1e-3, 1e-3, id="fp16"), + pytest.param(torch.bfloat16, 2e-2, 2e-2, id="bf16"), +] + +# (batch, seqlen, nheads, nheads_k, head_dim) -- includes a GQA case (nheads_k < nheads). +SHAPE_CASES = [ + pytest.param(1, 96, 4, 4, 64, id="b1-s96-mha-d64"), + pytest.param(2, 64, 8, 2, 32, id="b2-s64-gqa-d32"), +] + + +def cuda_flash_attention_availability(): + if not torch.cuda.is_available(): + return False, "CUDA device is not available" + if torch.version.hip is not None: + return False, "current torch build is not the CUDA platform" + try: + from rl_engine.kernels.ops.cuda.attention.flash_attn import FlashAttentionOp + + FlashAttentionOp() + except AVAILABILITY_ERRORS as exc: + return False, f"CUDA FlashAttentionOp is unavailable: {exc}" + return True, "" + + +def _make_qkv(batch, seqlen, nheads, nheads_k, head_dim, dtype): + gen = torch.Generator(device="cuda").manual_seed(0) + shape_q = (batch, seqlen, nheads, head_dim) + shape_kv = (batch, seqlen, nheads_k, head_dim) + q = torch.randn(shape_q, device="cuda", dtype=dtype, generator=gen) + k = torch.randn(shape_kv, device="cuda", dtype=dtype, generator=gen) + v = torch.randn(shape_kv, device="cuda", dtype=dtype, generator=gen) + return q, k, v + + +def _decode_replay(op, q, k, v, softmax_scale): + """Drive a decode path with the real kernel: one query vs the cached prefix.""" + seqlen = q.shape[1] + steps = [] + for t in range(seqlen): + out_t = op( + q[:, t : t + 1].contiguous(), + k[:, : t + 1].contiguous(), + v[:, : t + 1].contiguous(), + softmax_scale=softmax_scale, + causal=False, # a single query legitimately attends to all cached keys 0..t + ) + steps.append(out_t[:, 0]) + return torch.stack(steps, dim=1) + + +@pytest.mark.parametrize(("dtype", "atol", "rtol"), DTYPE_CASES) +@pytest.mark.parametrize(("batch", "seqlen", "nheads", "nheads_k", "head_dim"), SHAPE_CASES) +def test_decode_matches_prefill_cuda(dtype, atol, rtol, batch, seqlen, nheads, nheads_k, head_dim): + available, reason = cuda_flash_attention_availability() + if not available: + pytest.skip(reason) + + from rl_engine.kernels.ops.cuda.attention.flash_attn import FlashAttentionOp + + op = FlashAttentionOp() + q, k, v = _make_qkv(batch, seqlen, nheads, nheads_k, head_dim, dtype) + softmax_scale = 1.0 / head_dim**0.5 + + prefill = op(q, k, v, softmax_scale=softmax_scale, causal=True) + decode = _decode_replay(op, q, k, v, softmax_scale) + + torch.testing.assert_close(decode.float(), prefill.float(), atol=atol, rtol=rtol) + + +# --------------------------------------------------------------------------- # +# Triton FlashAttention: prefill/decode consistency at block-aligned positions. +# +# Unlike FlashAttentionOp, the Triton kernel runs without the compiled _C +# extension, so this executes on any CUDA box. The kernel requires seqlen_q == +# seqlen_k and a block-aligned length, so true per-token decode is not +# expressible; instead we assert the equivalent invariance: the output at a +# block-aligned position must not depend on tokens that come after it, i.e. a +# length-L causal prefill and a longer length-S prefill agree at position L-1. +# This is bitwise on real hardware (verified fp16/bf16, Blackwell). +# --------------------------------------------------------------------------- # + + +def triton_attention_availability(): + if not torch.cuda.is_available(): + return False, "CUDA device is not available" + try: + from rl_engine.kernels.ops.triton.triton_attn import triton_flash_attention # noqa: F401 + except AVAILABILITY_ERRORS as exc: + return False, f"Triton FlashAttention is unavailable: {exc}" + return True, "" + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_triton_prefill_position_invariant_to_future_tokens(dtype): + available, reason = triton_attention_availability() + if not available: + pytest.skip(reason) + + from rl_engine.kernels.ops.triton.triton_attn import triton_flash_attention + + batch, heads, head_dim, block = 2, 4, 64, 64 + total = 4 * block + gen = torch.Generator(device="cuda").manual_seed(0) + + def randn(seq): + return torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=dtype, generator=gen) + + q, k, v = randn(total), randn(total), randn(total) + full = triton_flash_attention(q, k, v, causal=True) + + for length in range(block, total + 1, block): # block-aligned prefixes + prefix = triton_flash_attention( + q[:, :, :length].contiguous(), + k[:, :, :length].contiguous(), + v[:, :, :length].contiguous(), + causal=True, + ) + boundary = length - 1 + assert torch.equal(prefix[:, :, boundary], full[:, :, boundary]), ( + f"position {boundary} changed when future tokens were added (dtype={dtype})" + )