
【Hackathon 10th Spring No.53】[Feature][KVCache] Support head-wise SWA cache recycle in ResourceManagerV1 PR1 [cf]#7717

Open
bob-cloudforge wants to merge 8 commits into PaddlePaddle:develop from CloudForge-Solutions:task/h10-053-pr1-headwise-swa-v4

Conversation

@bob-cloudforge

@bob-cloudforge bob-cloudforge commented May 4, 2026

Motivation

Hackathon 10th Spring Task No.53 — Discrete KV Cache management and AppendAttention operator performance optimization (PR1 of 2). Spec: https://github.com/PaddlePaddle/community/blob/master/hackathon/hackathon_10th/【Hackathon_10th】开源贡献个人挑战赛春节特别季—任务合集.md#no53.

For models that mix Sliding-Window Attention (SWA) heads with full-attention heads inside the same layer, today's V1 KV-cache scheduling path (ResourceManagerV1 + PrefixCacheManager, gated by the default-on ENABLE_V1_KVCACHE_SCHEDULER=1) allocates one shared block_idx per layer for all heads. SWA heads finish their window long before full-attn heads, but their cache stays pinned until the whole layer evicts. Throughput suffers.

This PR teaches the V1 scheduler + PrefixCacheManager to manage block_idx per head (head-wise SWA layout) and recycle a SWA head's cache as soon as it crosses its window — the per-head equivalent of what PR #6702 did for V0.
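The recycle rule described above — free a SWA head's blocks once the decode cursor has moved past window + sink — can be sketched as follows. This is an illustrative helper, not the FastDeploy API; `window_blocks` and `sink_blocks` are hypothetical names.

```python
# Hypothetical sketch of the per-head SWA recycle rule: blocks that lie outside
# both the attention sink and the trailing sliding window are no longer read by
# a SWA head, so they can be returned to the free list early.

def recyclable_blocks(cursor_block: int, window_blocks: int, sink_blocks: int) -> range:
    """Blocks [sink_blocks, cursor_block - window_blocks) are recyclable."""
    window_start = cursor_block - window_blocks
    if window_start <= sink_blocks:
        return range(0, 0)  # window still overlaps the sink: nothing to free yet
    return range(sink_blocks, window_start)

# Full-attention heads keep every block, so only SWA heads enter this path.
```

With a 4-block window and a 2-block sink, a head whose cursor sits at block 10 can return blocks 2–5 while a full-attention head in the same layer keeps all ten.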

Authorship: this PR is independently designed and implemented by the submitter for Hackathon 10th Spring No.53. The earlier community PR #6702 (V0, not merged) is referenced as prior art only; no code is lifted unattributed. Any future contributor work will be acknowledged via per-commit Co-authored-by trailers.

RFC: PaddlePaddle/community#1364.

Modifications

| Area | Change |
|---|---|
| `fastdeploy/cache_manager/prefix_cache_manager.py` | Per-request head-wise GPU free list (`gpu_free_block_list_head_wise[head]`); `allocate_gpu_blocks_head_wise` / `recycle_gpu_blocks_head_wise`; TP-aware sizing (`num_key_value_heads // tp_size`) |
| `fastdeploy/engine/sched/resource_manager_v1.py` | `recycle_request_swa_head_cache` (per-head cursor advance ≥ window+sink); `_should_skip_swa_recycle_for_overlap` (per-request `cache_swap_metadata` / `cache_evict_metadata` inspection); P4 cleanup in `_free_blocks` |
| `fastdeploy/model_executor/models/paddleformers/base.py` | Default-off ERNIE SWA fixture (window/sink/skip-freq/ratio) gated by `FD_T53_HEAD_WISE_SWA_FIXTURE=1` |
| `fastdeploy/config.py` (+20) | Engine-main FDConfig fixture: mirror the `paddleformers/base.py` head-wise SWA attribute injection so `ResourceManagerV1._should_use_head_wise_swa` (engine-main) sees the same `model_config.head_wise_swa_ratio` as the worker. Gated on `FD_T53_HEAD_WISE_SWA_FIXTURE` |
| Mutual exclusion | `enable_prefix_caching=True` + `FD_HEAD_WISE_KV_CACHE=1` raises at `PrefixCacheManager.__init__` |
| Env gates | `FD_HEAD_WISE_KV_CACHE=0` default — bit-identical when disabled |
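The mutual-exclusion row above can be sketched as a standalone guard. This is a hedged illustration of the check's shape only — the real check lives inside `PrefixCacheManager.__init__`, and `check_head_wise_compat` is a hypothetical name.

```python
# Sketch (assumed env handling): prefix caching and head-wise KV cache are
# mutually exclusive, so enabling both must fail loudly at construction time.
import os

def check_head_wise_compat(enable_prefix_caching: bool) -> None:
    if enable_prefix_caching and os.getenv("FD_HEAD_WISE_KV_CACHE", "0") == "1":
        raise ValueError(
            "enable_prefix_caching=True is incompatible with FD_HEAD_WISE_KV_CACHE=1"
        )
```

Failing at `__init__` rather than at first allocation keeps the misconfiguration from surfacing as a hard-to-debug runtime cache inconsistency.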

Tests use real lightweight objects plus `object.__new__`/AST or shape oracles (no MagicMock-only tests). PR2, not PR1, owns the kernel-visible `block_tables_headwise` / FP8 scale-layout changes.

PR2 (separate) lands the AppendAttention rank-2 block_tables_headwise ABI + ForwardMeta wiring + kv_num_heads field as a frozen-shape parameter; PR1 keeps share_inputs.block_tables 2D and reaches the +30% recycle gate via cache-manager-side changes only.

Usage or Command

# Enable head-wise V1 cache + timely SWA recycle.
# All four env vars must be set together — partial activation is silently a no-op.
# Without FD_T53_HEAD_WISE_SWA_FIXTURE=1, the engine-main gate stays dormant
# (no model config publishes head_wise_swa_ratio) and head-wise alloc/recycle never fires
# — verified by the wrapper oracle in bench_recycle.sh.
export FD_T53_HEAD_WISE_SWA_FIXTURE=1     # engine-main FDConfig fixture (config.py)
export ENABLE_V1_KVCACHE_SCHEDULER=1      # default; shown for clarity
export FD_HEAD_WISE_KV_CACHE=1            # enables per-head block tables
export FD_T53_HEAD_WISE_SWA_RATIO=1.0     # SWA recycle ratio (>0 = recycle active)
python -m fastdeploy.entrypoints.openai.api_server \
    --model baidu/ERNIE-4.5-21B-A3B-Paddle \
    --max-model-len 32768

Accuracy Tests

Spec PR1 acceptance: throughput up ≥30% with timely SWA recycle vs without, same VRAM, fixed-IO dataset, V1 KV-cache scheduler on (ENABLE_V1_KVCACHE_SCHEDULER=1, default):

Round 2 (gate run — 128 prompts):

| Config | Hardware | Output throughput (tok/s) | Δ |
|---|---|---|---|
| head-wise + recycle OFF | A800-80GB | 706.29 | baseline |
| head-wise + recycle ON | A800-80GB | 1107.98 | +56.9% (≥30% gate ✓) |

Round 3 (full run — 1024 prompts):

| Config | Hardware | Output throughput (tok/s) | Δ |
|---|---|---|---|
| head-wise + recycle OFF | A800-80GB | 722.93 | baseline |
| head-wise + recycle ON | A800-80GB | 1270.87 | +75.8% (≥30% gate ✓) |

Round 3 integrity: completed=1024/1024 on both arms, errors=0, mean TTFT improved by 48.0% (2,708 s → 1,407 s).

Benchmark: FastDeploy/benchmarks/benchmark_serving.py — random fixed-IO dataset, input≈10.6k tokens avg / output≈4k tokens avg, request-rate=8, seed=42, --ignore-eos, server --max-concurrency=8192, YAML eb45-21b-a3b-32k-bf16-kv50-512s.yaml (kv_cache_max_ratio=0.50, max_seq_len=512). Fixed-IO integrity: both arms produce identical total_input_tokens=1,356,656 / total_output_tokens=518,946 for the 128-prompt gate run. Round 2 harness gate: completed=128, nonempty_errors=0. Round 3 target: completed=1024.

Hardware note for reviewers: spec does not pin PR1 hardware. Numbers above are A800-80GB (SM80) via Baidu AI Studio. If H/B card access is granted (cc @luotao1), we will append H/B numbers as supplementary evidence. PR2 (5% TTFT/TBT) does require H/B per spec; tracked separately.

Correctness:

  • CPU pytest coverage under tests/cache_manager/test_head_wise_*.py, tests/cache_manager/test_swa_recycle*.py, and tests/layers/test_append_attention_head_wise_shapes.py — real _FakeCacheManager + object.__new__(ResourceManagerV1) + AST/shape oracles. No MagicMock-only tests.
  • A800 smoke (bsz=4, seq=1024) + long-context recycle smoke — TBD, pending CI access
  • GSM8K parity (head-wise vs non-head-wise abs diff ≤ 0.5 pp) — TBD, deferred to follow-up validation pass

CI run: https://github.com/PaddlePaddle/FastDeploy/pull/7717/checks

Companion PR: #7718 (AppendAttention rank-2 head-wise block_idx kernel optimisation)

Checklist

  • pre-commit run --all-files clean
  • All CI checks green (Coverage / base_tests / codestyle / iluvatar / xpu)
  • Reviewer-requested changes addressed
  • No prohibited claims in PR body (verified by pre-push grep): "first in framework", "novel research", "unique to FastDeploy"
  • Authorship statement accurate (no unattributed lifted code)
  • Hardware label on every benchmark number matches the actual card used

@paddle-bot

paddle-bot Bot commented May 4, 2026

Thanks for your contribution!

@paddle-bot paddle-bot Bot added the contributor External developers label May 4, 2026
@CLAassistant

CLAassistant commented May 4, 2026

CLA assistant check
All committers have signed the CLA.

PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot

PaddlePaddle-bot commented May 4, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-07 04:13:08

CI report generated from the code below (refreshed every 30 minutes):


1 Task overview

All required tasks passed (this PR has no required tasks configured). One optional task failed, which does not block merging. Seven workflows are in action_required state awaiting manual approval.

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Waiting | Skipped |
|---|---|---|---|---|---|---|
| 2 (0) | 2 | 1 | 1 | 0 | 0 | 0 |

⚠️ Note: the following 7 workflows are in action_required state and run only after manual approval: CI_XPU, PR Build and Test, Check PR Template, Codestyle-Check, ILUVATAR-CI, Approval, CI_HPU.

Note: action_required workflows are not counted in the statistics above.


2 Task status summary

2.1 Required tasks — 0/0 passed

No required tasks are configured for this PR; merging is not affected.

2.2 Optional tasks — 1/2 passed

Optional tasks do not block merging; failures are informational only.

| Status | Task | Duration | Log | Rerun |
|---|---|---|---|---|
| ❌ | Trigger Jenkins for PR | 2m51s | Job | - |
| ✅ | Remaining 1 optional task passed | - | - | - |

3 Failure details (required only)

No failed required tasks.

@bob-cloudforge bob-cloudforge changed the title feat(kvcache): support head-wise SWA recycle 【Hackathon 10th Spring No.53】[Feature][KVCache] Support head-wise SWA cache recycle in ResourceManagerV1 [cf] May 4, 2026
bob-cloudforge added a commit to CloudForge-Solutions/FastDeploy that referenced this pull request May 6, 2026
The PR1 head-wise allocator (PaddlePaddle#7717) emits flat global block IDs in
[0, num_gpu_blocks * kv_num_heads) from a single shared min-heap, but
the PR2 discrete kernel (PaddlePaddle#7718) ABI L1 expects per-head local IDs in
{-1} ∪ [0, num_gpu_blocks). This causes cudaIllegalAddress on any
request whose allocated IDs cross the num_gpu_blocks boundary
(i.e. immediately on head index ≥ ceil(num_gpu_blocks / num_blocks)).

This commit normalizes IDs at the backend boundary in append_attn_backend.py
using `local = flat % num_gpu_blocks` (sentinel -1 preserved), with a
fail-fast assert to catch any residual OOB. The hotfix is bench-only;
the canonical fix (per-head independent allocator pools) is deferred to
PR1 v5 (RFC-PR1-reanchored.md §3).

Also adds FD_T53_HEAD_WISE_SWA_RATIO ∈ [0.0, 1.0] validator.

Refs: .checkpoints/h10/task-53/design/PR2-HOTFIX-SPEC.md (Option B, OPUS-GATE PASS)
     .checkpoints/h10/task-53/design/CONTRACT-ORACLE.md (I2, I7)
     .checkpoints/h10/task-53/design/RFC-PR2-reanchored.md (ABI L1)

Files: 2 changed (1 backend hotfix, 1 envs validator)
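The `local = flat % num_gpu_blocks` normalization in the commit message above can be sketched as a small boundary helper. This is an assumed shape for illustration; the actual `append_attn_backend.py` internals are not reproduced here.

```python
# Sketch: fold flat global block IDs in [0, num_gpu_blocks * kv_num_heads)
# into per-head local IDs in {-1} ∪ [0, num_gpu_blocks), preserving the
# -1 SWA-recycle sentinel and failing fast on any residual out-of-bounds id.

def normalize_block_id(flat_id: int, num_gpu_blocks: int) -> int:
    if flat_id == -1:
        return -1  # SWA recycle sentinel passes through untouched
    local = flat_id % num_gpu_blocks
    assert 0 <= local < num_gpu_blocks, f"residual OOB block id {flat_id}"
    return local
```

Without the fold, any flat id at or above `num_gpu_blocks` lands outside the kernel's expected `{-1} ∪ [0, num_gpu_blocks)` range, matching the cudaIllegalAddress failure mode described above.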
…not cache-ids)

PaddlePaddle-bot flagged that _init_head_wise_free_list and the
allocate/recycle paths exported the raw length of
gpu_free_head_wise_block_list as free_gpu_block_num. That list holds
num_gpu_blocks * kv_num_heads per-(block,head) cache ids, so the metric
inflated by kv_num_heads (e.g. 8x for ERNIE-21B-A3B-Paddle).

Divide by max(1, kv_num_heads) at all three sites so the exported
counter stays in logical-block units, consistent with the legacy
gpu_free_block_list semantics that downstream dashboards rely on.

Refs: review on PR PaddlePaddle#7717 (PaddlePaddle-bot)
Signed-off-by: bob-cloudforge <bob@cloudforge.solutions>
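The metric fix above amounts to one division at each export site. A minimal sketch, assuming the flat per-(block, head) free list described in the commit message:

```python
# Sketch: the head-wise free list holds num_gpu_blocks * kv_num_heads
# per-(block, head) cache ids, so the exported counter must divide by
# kv_num_heads to stay in logical-block units (legacy gpu_free_block_list
# semantics that downstream dashboards rely on).

def free_gpu_block_num(head_wise_free_list: list, kv_num_heads: int) -> int:
    return len(head_wise_free_list) // max(1, kv_num_heads)
```

For ERNIE-21B-A3B with 8 KV heads, a raw `len()` of 512 entries would have reported 512 free blocks; the divided counter correctly reports 64.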
…sink-safe)

PaddlePaddle-bot review on PR PaddlePaddle#7717 flagged the four 'if (block_id < 0)
{ block_id = 0; }' fallbacks in the c16 multiquery attention kernel as
potentially unsafe — accessing block 0 when block_id == -1 looks like a
silent OOB.

Document the actual contract: block_id == -1 is the SWA recycle sentinel
written by recycle_request_swa_head_cache (T53 PR1). The SWA mask built
from chunk_start/chunk_end zeroes any contribution from this aged-out
region in softmax, so the value loaded from block 0 is mathematically
masked away. SAFETY argument: when sink_size > 0, recycle_from_floor =
sink_blocks guarantees the sink window is never recycled, so block_id ==
-1 cannot occur inside the attended sink region.

This is a comment-only change. No code semantics altered.

Refs: review on PR PaddlePaddle#7717 (PaddlePaddle-bot)
Signed-off-by: bob-cloudforge <bob@cloudforge.solutions>

PR1 backport of PR2 commit 327a43b. Avoids integer-truncation
underestimating available KV blocks when head_free % kv_num_heads != 0,
which caused the scheduler to see 0 capacity on partial recycles and
trigger false OOM rejections.

Signed-off-by: bob-cloudforge <bob@cloudforge.solutions>
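One plausible reading of the backported fix (the actual FastDeploy change is not reproduced here): switch from floor to ceiling division, so a partial recycle that frees fewer than `kv_num_heads` per-head entries still reports nonzero logical capacity.

```python
# Sketch: ceiling division avoids reporting 0 capacity when
# head_free % kv_num_heads != 0 (the false-OOM case described above).

def available_logical_blocks(head_free: int, kv_num_heads: int) -> int:
    return -(-head_free // kv_num_heads)  # ceil division without math.ceil

# floor: 7 // 8 == 0  -> scheduler sees zero capacity, false OOM rejection
# ceil:  available_logical_blocks(7, 8) == 1
```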

bob-cloudforge added a commit to CloudForge-Solutions/FastDeploy that referenced this pull request May 6, 2026
No behavior change in PR1 (singular is the populated heap here);
keeps the property body identical across PR1/PR2 so future merges
do not drift. Closes the PaddlePaddle-bot 🟡 advisory on PaddlePaddle#7717.


@bob-cloudforge bob-cloudforge changed the title 【Hackathon 10th Spring No.53】[Feature][KVCache] Support head-wise SWA cache recycle in ResourceManagerV1 [cf] [Feature][KVCache] Support head-wise SWA cache recycle in ResourceManagerV1 May 6, 2026
@bob-cloudforge bob-cloudforge changed the title [Feature][KVCache] Support head-wise SWA cache recycle in ResourceManagerV1 【Hackathon 10th Spring No.53】[Feature][KVCache] Support head-wise SWA cache recycle in ResourceManagerV1 PR1 [cf] May 6, 2026
…lural)

getattr(self, 'gpu_free_head_wise_block_lists', None) always returned
None because the attribute is defined as gpu_free_head_wise_block_list
(singular). Also fix the calculation: the list is a single flat heap,
not a list-of-lists, so use len(free_list)//kv_num_heads not sum(len(h)).

Fixes P0 reported by PaddlePaddle-bot on PR PaddlePaddle#7717.

FIX 1: prefix_cache_manager.py L607 assert→RuntimeError for head-wise
  gpu free block underflow; raises at system boundary with context.

FIX 2: multiquery_attention_c16_impl.cuh — expand 4 SWA sentinel guard
  comments (Sites 1-4) to cover sink_size==0 case. All 4 sites now say
  'block_id >= 0 guaranteed for both sink_size > 0 and == 0' and cross-
  reference the top-of-loop comment that proves it.

FIX 3: tests/operators/test_head_wise_swa_sentinel_guard.py — pure-Python
  shadow oracle for the sentinel-guard contract. 3 tests, ZERO skip guards
  (T48/T49 gold standard). Verifies recycle positions never overlap the
  kernel's attended window for both sink_size>0 and sink_size==0.

@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 Paddle-CI-Agent | pr_review | 2026-05-07 03:58:33

📋 Review summary

PR overview: adds a head-wise SWA KV-cache recycle path to `ResourceManagerV1` + `PrefixCacheManager` (Hackathon 10th No.53 PR1), implementing per-head SWA block recycling in the V1 scheduler to improve KV-cache utilization.
Change scope: `cache_manager/`, `engine/sched/resource_manager_v1.py`, `custom_ops/gpu_ops/append_attn/`, `config.py`, `worker/`, `model_executor/models/`
Impact tags: [KVCache] [OP] [Scheduler] [FDConfig] [Models]


📝 PR template check

Two template issues: ① the title contains two official tags plus an unofficial prefix/suffix; ② the checklist uses custom entries, does not follow the §D2 template structure, and no items are checked.

Suggested title (copy-paste ready):

  • [Feature] Support head-wise SWA cache recycle in ResourceManagerV1 (PR1)

Suggested PR description (copy-paste ready; only the Checklist structure is fixed — the Motivation/Modifications/Usage/Accuracy Tests content is kept verbatim):

## Motivation

Hackathon 10th Spring Task **No.53 — Discrete KV Cache management and AppendAttention operator performance optimization** (PR1 of 2). Spec: https://github.com/PaddlePaddle/community/blob/master/hackathon/hackathon_10th/【Hackathon_10th】开源贡献个人挑战赛春节特别季—任务合集.md#no53.

For models that mix Sliding-Window Attention (SWA) heads with full-attention heads inside the **same layer**, today's V1 KV-cache scheduling path (`ResourceManagerV1` + `PrefixCacheManager`) allocates one shared `block_idx` per layer for **all** heads. SWA heads finish their window long before full-attn heads, but their cache stays pinned until the whole layer evicts. Throughput suffers.

This PR teaches the V1 scheduler + `PrefixCacheManager` to manage `block_idx` **per head** (head-wise SWA layout) and recycle a SWA head's cache as soon as it crosses its window.

## Modifications

| Area | Change |
|---|---|
| `fastdeploy/cache_manager/prefix_cache_manager.py` | Per-request head-wise GPU free list (`gpu_free_block_list_head_wise[head]`); `allocate_gpu_blocks_head_wise` / `recycle_gpu_blocks_head_wise`; TP-aware sizing (`num_key_value_heads // tp_size`) |
| `fastdeploy/engine/sched/resource_manager_v1.py` | `recycle_request_swa_head_cache` (per-head cursor advance ≥ window+sink); `_should_skip_swa_recycle_for_overlap`; P4 cleanup in `_free_blocks` |
| `fastdeploy/model_executor/models/paddleformers/base.py` | Default-off ERNIE SWA fixture gated by `FD_T53_HEAD_WISE_SWA_FIXTURE=1` |
| `fastdeploy/config.py` | Engine-main FDConfig fixture: mirror head-wise SWA attributes so `ResourceManagerV1._should_use_head_wise_swa` sees correct `model_config.head_wise_swa_ratio`. Gated on `FD_T53_HEAD_WISE_SWA_FIXTURE` |
| `custom_ops/gpu_ops/append_attn/` | `block_tables_headwise` optional parameter added to kernel signatures; sentinel guard for recycled blocks |
| Mutual exclusion | `enable_prefix_caching=True + FD_HEAD_WISE_KV_CACHE=1` raises at `PrefixCacheManager.__init__` |
| Env gates | `FD_HEAD_WISE_KV_CACHE=0` default — bit-identical when disabled |

## Usage or Command

```bash
export FD_T53_HEAD_WISE_SWA_FIXTURE=1
export ENABLE_V1_KVCACHE_SCHEDULER=1
export FD_HEAD_WISE_KV_CACHE=1
export FD_T53_HEAD_WISE_SWA_RATIO=1.0
python -m fastdeploy.entrypoints.openai.api_server \
    --model baidu/ERNIE-4.5-21B-A3B-Paddle \
    --max-model-len 32768
```

## Accuracy Tests

**Round 3 (full run — 1024 prompts, A800-80GB):**

| Config | Output throughput (tok/s) | Δ |
|---|---|---|
| head-wise + recycle OFF | 722.93 | baseline |
| head-wise + recycle ON | 1270.87 | **+75.8%** ≥30 ✓ |

Round 3 integrity: `completed=1024/1024`, `errors=0`, mean TTFT improved -48.0%.

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [x] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

Issues

| Severity | File | Summary |
|---|---|---|
| 📝 Template | Title | Contains two official tags ([Feature]+[KVCache]) plus an unofficial prefix/suffix; keep exactly one tag |
| 📝 Template | Checklist | Custom checklist entries; does not follow the §D2 standard template, and no items are checked (including completed ones) |
| 🟡 Suggestion | fastdeploy/worker/gpu_model_runner.py | A6: changes to the generic GPU worker path are not mirrored to the XPU/DCU/HPU/GCU/Iluvatar/Metax model_runners and may break other hardware CI |
| ❓ Question | fastdeploy/engine/sched/resource_manager_v1.py | The `total_tokens % block_size != 0` guard allows recycling only about once every `block_size` steps during single-step decode; please confirm this frequency meets the performance target |

Overall assessment

The PR is carefully designed: the env-gated default-off path guarantees bit-identical behavior when disabled, the block lifecycle analysis is complete (allocate/recycle/teardown logic is paired, and the P4 cross-request contamination is fixed), and the sentinel-guard comments are thorough. The main follow-ups are the impact of generic worker-path changes (e.g. `gpu_model_runner.py`) on other hardware CI, and the PR template fixes.

@bob-cloudforge
Author

Added tests/operators/test_head_wise_swa_sentinel_guard.py — a Python shadow oracle that locks the sentinel-guard invariant without requiring GPU hardware. Three methods: (1) test_sink_size_positive_no_sentinel_in_attended_window: verifies no -1 appears in the blocks accessed by a chunk at or after window_start with sink_size > 0; (2) test_sink_size_zero_no_sentinel_at_chunk_start: same check with no sink region; (3) test_recycled_gap_does_not_overlap_kernel_reads: confirms recycled-gap positions [sink_blocks, window_start_block) have no overlap with the _kernel_read_positions set for any valid chunk. The test runs unconditionally in the H1Z1 CI job.
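The shadow-oracle idea above can be sketched without GPU hardware: compute the block positions a SWA head actually reads for a chunk, compute the recycled gap, and assert they never intersect. The helpers below are illustrative; the real test file has its own naming.

```python
# Hedged sketch of the sentinel-guard oracle: the attended window is the sink
# plus the trailing sliding window, and the recycled gap is everything between
# them. The invariant is that the two sets are disjoint for every chunk.

def kernel_read_blocks(chunk_end: int, window_blocks: int, sink_blocks: int) -> set:
    """Blocks a SWA head reads for a chunk ending at chunk_end."""
    window_start = max(chunk_end - window_blocks, 0)
    return set(range(0, sink_blocks)) | set(range(window_start, chunk_end))

def recycled_gap(chunk_end: int, window_blocks: int, sink_blocks: int) -> set:
    """Recycled positions [sink_blocks, window_start) for the same chunk."""
    window_start = max(chunk_end - window_blocks, sink_blocks)
    return set(range(sink_blocks, window_start))

# The invariant holds for both sink_blocks > 0 and sink_blocks == 0.
for sink in (0, 2):
    for end in range(1, 32):
        assert not kernel_read_blocks(end, 4, sink) & recycled_gap(end, 4, sink)
```

Because the gap starts at `sink_blocks` and ends at `window_start`, disjointness with the attended window holds by construction, which is what makes the CPU-only check a valid oracle for the kernel's `-1` sentinel contract.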
