From a622ffcbbcb17d756d792bd0fd97a53ca377e6cd Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Fri, 3 Jul 2026 12:37:43 -0400 Subject: [PATCH 01/11] fix(qwen36_thor): match parent _long_tq_effective_k signature The shared frontend added an optional max_new_tokens parameter to _long_tq_effective_k and its call sites pass it positionally; the Thor override kept the two-argument form, so every long-context generate on Thor raised TypeError. Accept and forward the new parameter; the Thor K-cap behaviour is unchanged. --- flash_rt/frontends/torch/qwen36_thor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flash_rt/frontends/torch/qwen36_thor.py b/flash_rt/frontends/torch/qwen36_thor.py index 4239b76a..725c0653 100644 --- a/flash_rt/frontends/torch/qwen36_thor.py +++ b/flash_rt/frontends/torch/qwen36_thor.py @@ -977,8 +977,11 @@ def _thor_mtp_prefill_K_nvfp4( # NOT enable TurboQuant; FP8-KV is selected by # ``FLASHRT_QWEN36_LONG_KV_CACHE=fp8``. The env is honoured here # for explicit user overrides (bisection, ablation). - def _long_tq_effective_k(self, prompt_len: int, K: int) -> int: - target_k = super()._long_tq_effective_k(prompt_len, K) + def _long_tq_effective_k( + self, prompt_len: int, K: int, + max_new_tokens: int | None = None) -> int: + target_k = super()._long_tq_effective_k( + prompt_len, K, max_new_tokens) if os.environ.get('FLASHRT_QWEN36_TQ_SPEC_K', ''): return target_k if target_k > 6 and int(prompt_len) >= 12288: From 927063508f4d8e061d628ea84d408bc1c685c7d7 Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Fri, 3 Jul 2026 12:37:55 -0400 Subject: [PATCH 02/11] feat(qwen36): DFlash speculative decoding on Thor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The DFlash generate loop hard-coded the BF16-staged verify forward and the per-position prompt walk, both of which are RTX-shaped: the Thor K-row layer path at S=16 is single-XQA over the persistent FP8 KV cache, so the verify needs the FP8-KV mode flag active and the prompt rows must be present in that cache before the first verify. Split both stages behind arch hooks on the shared frontend, defaulting to the existing behaviour: * _dflash_verify_forward_K — the S=K verify used for warmup and graph capture; default remains forward_own_decode_K_nvfp4. * _dflash_prefill_nvfp4 — prompt prefill returning the first greedy token; default remains the per-position captured-graph walk. Thor overrides route the verify through the FP8-KV wrapper (same as the long-ctx spec verify), prefill through the chunked FP8 prefill (also the fast-TTFT path on Thor), and guarantee the FP8 KV cache exists after drafter load for short-context constructions. Verified on Thor at ctx=128/64 new tokens: greedy tokens identical to the production MTP spec path. --- flash_rt/frontends/torch/qwen36_rtx.py | 54 +++++++++++++++++-------- flash_rt/frontends/torch/qwen36_thor.py | 42 +++++++++++++++++++ 2 files changed, 80 insertions(+), 16 deletions(-) diff --git a/flash_rt/frontends/torch/qwen36_rtx.py b/flash_rt/frontends/torch/qwen36_rtx.py index 7fa00cb3..3126e3f9 100644 --- a/flash_rt/frontends/torch/qwen36_rtx.py +++ b/flash_rt/frontends/torch/qwen36_rtx.py @@ -11736,8 +11736,36 @@ def _restore(): self._captured_drafter_graphs_dflash, eff_ctx, g) return g + def _dflash_verify_forward_K(self, token_ids_K, cos_K, sin_K, + cur_pos: int, K: int, tap_buf): + """Arch hook: the S=K verify forward used by DFlash spec decode. + + Default is the BF16-staged KV verify. Subclasses whose K-row + layer path requires a different KV mode (Thor: FP8-KV) override + this with the matching wrapper; the DFlash orchestration and + graph capture above it stay shared. + """ + return self.forward_own_decode_K_nvfp4( + token_ids_K, cos_K, sin_K, cur_pos, K, tap_buf=tap_buf) + + def _dflash_prefill_nvfp4(self, input_ids): + """Arch hook: prompt prefill for DFlash spec decode. + + Default walks the prompt through the per-position S=1 captured + graphs, which writes the BF16 KV cache the default verify + forward reads. Subclasses whose verify attends over a different + KV store (Thor: FP8-KV) override this with a prefill that + populates that store. Returns the (1, 1) first greedy token. + """ + prompt_len = int(input_ids.shape[1]) + for p in range(prompt_len): + self._static_token_id.copy_(input_ids[:, p:p + 1]) + g_pf = self._ensure_graph_for_pos_nvfp4(p) + self._replay_pos_graph(g_pf, p) + return self._logits_buf.argmax(dim=-1, keepdim=True).view(1, 1) + def _ensure_verify_graph_dflash_nvfp4(self, cur_pos: int, K: int): - """Lazy CUDA Graph for forward_own_decode_K_nvfp4 WITH tap_buf. + """Lazy CUDA Graph for the DFlash verify forward WITH tap_buf. Mirror of ``_ensure_verify_graph_nvfp4`` but binds ``tap_buf=self._dflash_taps_buf`` at capture time so the 5 @@ -11777,9 +11805,9 @@ def _restore(): sin_K = self._verify_static_sin[:, :K] tap_buf = self._dflash_taps_buf for _ in range(2): - self.forward_own_decode_K_nvfp4( - tokens_K, cos_K, sin_K, cur_pos, K=K, - tap_buf=tap_buf) + self._dflash_verify_forward_K( + tokens_K, cos_K, sin_K, cur_pos, K, + tap_buf) _restore() gs.synchronize() @@ -11787,9 +11815,9 @@ def _restore(): with torch.cuda.graph( g, stream=gs, pool=self._graph_mempool, ), torch.no_grad(): - self.forward_own_decode_K_nvfp4( - tokens_K, cos_K, sin_K, cur_pos, K=K, - tap_buf=tap_buf) + self._dflash_verify_forward_K( + tokens_K, cos_K, sin_K, cur_pos, K, + tap_buf) with torch.cuda.stream(gs), torch.no_grad(): _restore() gs.synchronize() @@ -11854,15 +11882,9 @@ def generate_own_speculative_DFlash_nvfp4( self._dflash_taps_buf.zero_() with torch.no_grad(): - # 1) Prefill (same as MTP path) — sequential S=1 forwards - # via the per-cur_pos captured S=1 graph. - gs_pf = self._graph_stream - for p in range(prompt_len): - self._static_token_id.copy_(input_ids[:, p:p + 1]) - g_pf = self._ensure_graph_for_pos_nvfp4(p) - self._replay_pos_graph(g_pf, p) - tok = self._logits_buf.argmax( - dim=-1, keepdim=True).view(1, 1) + # 1) Prefill via the arch hook (default: sequential S=1 + # forwards through the per-cur_pos captured graphs). + tok = self._dflash_prefill_nvfp4(input_ids) generated = [tok] cur_pos = prompt_len diff --git a/flash_rt/frontends/torch/qwen36_thor.py b/flash_rt/frontends/torch/qwen36_thor.py index 725c0653..de6e040f 100644 --- a/flash_rt/frontends/torch/qwen36_thor.py +++ b/flash_rt/frontends/torch/qwen36_thor.py @@ -1092,3 +1092,45 @@ def _prefill_mtp_tail_kv_nvfp4( return self._thor_mtp_prefill_K_nvfp4( prev_h_rows, token_ids, pos_start, rows, cache_base_pos=cache_base_pos) + + # ---------- DFlash integration ---------- + # + # DFlash verifies at S=block_size (16), above + # ``_THOR_K_ROW_FAST_PATH_MAX``, so the K-row layers route to + # ``_thor_full_K_forward`` / ``_thor_lin_K_forward`` — and the + # full-attn K-row is single-XQA-path over the persistent FP8 KV + # cache. Three consequences, each handled by one override below: + # the drafter load must guarantee the FP8 cache exists, the prompt + # prefill must populate it, and the verify forward must run with + # the FP8-KV mode flag active. + + def _load_dflash_drafter(self, ckpt_dir: str | None = None) -> None: + super()._load_dflash_drafter(ckpt_dir) + # Short-ctx constructions (user_max_seq <= LONG_CTX_THRESHOLD) + # never allocate the persistent FP8 KV cache; the Thor DFlash + # verify cannot run without it. + if not hasattr(self, '_fp8_K_cache'): + self._load_fp8_kv_cache(max_seq=self._user_max_seq + 16) + self._long_kv_cache_mode = 'fp8' + + def _dflash_prefill_nvfp4(self, input_ids): + """Thor override: chunked FP8-KV prompt prefill. + + The default per-position walk writes only the BF16 KV cache; + the Thor verify attends over the FP8 cache, so the prompt rows + must land there. The chunked prefill is also the production + Thor TTFT path (batched XQA instead of one forward per token). + """ + _, logits = self._prefill_long_ctx_tq_chunked(input_ids) + return logits.argmax(dim=-1, keepdim=True).view(1, 1) + + def _dflash_verify_forward_K(self, token_ids_K, cos_K, sin_K, + cur_pos: int, K: int, tap_buf): + """Thor override: run the DFlash verify in FP8-KV mode. + + Same wrapper as the production long-ctx spec verify, so the + K-row layer dispatch sees ``_fp8_kv_verify_active`` for the + whole S=K forward. + """ + return self.forward_own_decode_K_nvfp4_fp8kv( + token_ids_K, cos_K, sin_K, cur_pos, K, tap_buf=tap_buf) From 5deebfd36486ffbd0ab29c7061c96917a16a94f4 Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Fri, 3 Jul 2026 14:29:07 -0400 Subject: [PATCH 03/11] perf(qwen36): constant-time DFlash partial-accept rollback on Thor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The DFlash spec loop handled a partial accept by restoring a snapshot and re-advancing the committed rows through a second tapped verify — a second full main-model forward per cycle. On Thor decode is weight- read bound, so with partial accepts on nearly every cycle this doubled the per-cycle weight traffic (measured 152.8 ms/cycle at ctx=128). Split the snapshot and rollback stages behind arch hooks with the existing behaviour as the default, and use the per-step state checkpoint machinery on Thor instead: * _dflash_snap_state / _dflash_partial_rollback hooks on the shared loop; defaults preserve the restore + re-advance flow byte-for-byte. * Thor grows _K_save_max (and the per-step lin/conv checkpoints) to the DFlash verify q_seq at drafter load, dropping any K-row graphs captured against the old buffers. * Thor K-row dispatch delegates BOTH layer types to the parent K-row for K <= _K_save_max. The lin path gains per-step state saves; the full-attn path keeps the verify rows on the same kernel family as the K<=7 re-advance/spec verifies. The latter is required for correctness: committing rows produced by one kernel family while recovery recomputes them with another surfaces their occasional rounding disagreements as greedy divergence (observed once in ~40 cycles, cascading from a single full-attn row). * Thor rollback is then two gpu_copy calls from checkpoint slot N — the same pattern the long-context MTP spec loop uses — and the snapshot stage becomes a no-op. ctx=128 steady state: 152.8 -> 92.9 ms/cycle (+65% decode tok/s at unchanged AL); greedy parity with the production MTP path holds over 256 tokens (3 runs, bit-identical); MTP baseline unchanged. --- flash_rt/frontends/torch/qwen36_rtx.py | 94 +++++++++++++++---------- flash_rt/frontends/torch/qwen36_thor.py | 85 +++++++++++++++++++++- 2 files changed, 140 insertions(+), 39 deletions(-) diff --git a/flash_rt/frontends/torch/qwen36_rtx.py b/flash_rt/frontends/torch/qwen36_rtx.py index 3126e3f9..b4742200 100644 --- a/flash_rt/frontends/torch/qwen36_rtx.py +++ b/flash_rt/frontends/torch/qwen36_rtx.py @@ -11764,6 +11764,59 @@ def _dflash_prefill_nvfp4(self, input_ids): self._replay_pos_graph(g_pf, p) return self._logits_buf.argmax(dim=-1, keepdim=True).view(1, 1) + def _dflash_snap_state(self, cur_pos: int, Kv: int) -> None: + """Arch hook: snapshot state the partial-accept rollback needs. + + Runs on the snap stream, overlapped with the drafter forward. + Subclasses whose rollback reads per-step state checkpoints + written during the verify itself (Thor) override this with a + no-op. + """ + self._snap_lin_buf.copy_(self._lin_state) + self._snap_conv_buf.copy_(self._lin_conv_state) + self._snap_K_buf[:, :Kv].copy_( + self._attn.K_cache[:, cur_pos:cur_pos + Kv]) + self._snap_V_buf[:, :Kv].copy_( + self._attn.V_cache[:, cur_pos:cur_pos + Kv]) + + def _dflash_partial_rollback(self, cur_pos: int, N: int, Kv: int, + tok, drafts, cos_KN, sin_KN) -> None: + """Arch hook: fix up state after a partial accept of N drafts. + + On exit the recurrent/conv state and KV must reflect exactly + the N+1 committed rows [tok, drafts[:N]] at + [cur_pos, cur_pos+N+1), and ``_dflash_taps_buf[:, N]`` must + hold the taps of the last committed row. + + Default: restore the pre-verify snapshot, then re-advance with + the committed rows via a tapped verify at K=N+1 (a second + main-model forward). Subclasses with per-step state saves in + the verify K-row (Thor) override this with constant-time state + copies instead. + """ + import torch + + self._lin_state.copy_(self._snap_lin_buf) + self._lin_conv_state.copy_(self._snap_conv_buf) + self._attn.K_cache[:, cur_pos:cur_pos + Kv].copy_( + self._snap_K_buf[:, :Kv]) + self._attn.V_cache[:, cur_pos:cur_pos + Kv].copy_( + self._snap_V_buf[:, :Kv]) + + Kr = N + 1 + self._verify_static_tokens[:, 0:1].copy_(tok) + if N > 0: + self._verify_static_tokens[:, 1:Kr].copy_( + drafts[:N].view(1, N)) + self._verify_static_cos[:, :Kr].copy_(cos_KN[:, :Kr]) + self._verify_static_sin[:, :Kr].copy_(sin_KN[:, :Kr]) + rg = self._ensure_verify_graph_dflash_nvfp4(cur_pos, Kr) + gs = self._graph_stream + gs.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(gs): + rg.replay() + torch.cuda.current_stream().wait_stream(gs) + def _ensure_verify_graph_dflash_nvfp4(self, cur_pos: int, K: int): """Lazy CUDA Graph for the DFlash verify forward WITH tap_buf. @@ -11901,14 +11954,7 @@ def generate_own_speculative_DFlash_nvfp4( snap_stream = self._snap_stream snap_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(snap_stream): - self._snap_lin_buf.copy_(self._lin_state) - self._snap_conv_buf.copy_(self._lin_conv_state) - self._snap_K_buf[:, :Kv].copy_( - self._attn.K_cache[ - :, cur_pos:cur_pos + Kv]) - self._snap_V_buf[:, :Kv].copy_( - self._attn.V_cache[ - :, cur_pos:cur_pos + Kv]) + self._dflash_snap_state(cur_pos, Kv) # 2b) Drafter forward (P7). # Caller writes static inputs (prev_token + hidden_taps). @@ -11988,38 +12034,12 @@ def generate_own_speculative_DFlash_nvfp4( for j in range(N + 1): if len(generated) < max_new_tokens: generated.append(argmax_at(j)) - # Restore pre-verify state. - self._lin_state.copy_(self._snap_lin_buf) - self._lin_conv_state.copy_(self._snap_conv_buf) - self._attn.K_cache[ - :, cur_pos:cur_pos + Kv].copy_( - self._snap_K_buf[:, :Kv]) - self._attn.V_cache[ - :, cur_pos:cur_pos + Kv].copy_( - self._snap_V_buf[:, :Kv]) - - # Re-advance with N+1 valid inputs via tapped verify - # at K=N+1 (always — including N=0; same code path - # as N>0). Re-uses the dflash verify graph cache. - Kr = N + 1 - rec_cos = cos_KN[:, :Kr] - rec_sin = sin_KN[:, :Kr] - self._verify_static_tokens[:, 0:1].copy_(tok) - if N > 0: - self._verify_static_tokens[:, 1:Kr].copy_( - drafts[:N].view(1, N)) - self._verify_static_cos[:, :Kr].copy_(rec_cos) - self._verify_static_sin[:, :Kr].copy_(rec_sin) - rg = self._ensure_verify_graph_dflash_nvfp4( - cur_pos, Kr) - gs.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(gs): - rg.replay() - torch.cuda.current_stream().wait_stream(gs) + self._dflash_partial_rollback( + cur_pos, N, Kv, tok, drafts, cos_KN, sin_KN) tok = argmax_at(N) self._dflash_taps_buf[:, 0].copy_( self._dflash_taps_buf[:, N]) - cur_pos += Kr + cur_pos += N + 1 if len(generated) > max_new_tokens: generated = generated[:max_new_tokens] diff --git a/flash_rt/frontends/torch/qwen36_thor.py b/flash_rt/frontends/torch/qwen36_thor.py index de6e040f..bbca785a 100644 --- a/flash_rt/frontends/torch/qwen36_thor.py +++ b/flash_rt/frontends/torch/qwen36_thor.py @@ -206,7 +206,13 @@ def _thor_alloc_K_row_scratch(self) -> None: # per-position sub-loop. Bit-exact to running K sequential single- # token forwards (see DESIGN §4.5 for the leaf-kernel set). def _layer_forward_lin_K_nvfp4(self, L, h_in_K, K): - if K <= self._THOR_K_ROW_FAST_PATH_MAX: + # Delegate the whole save-steps range to parent, not just the + # fast-path max: parent's per-step recurrent branch is active + # for K <= _K_save_max, stays per-token-equivalent on SM110, + # and writes the per-step state checkpoints the DFlash + # partial-accept rollback reads (_K_save_max is grown to the + # DFlash verify q_seq at drafter load). + if K <= max(self._THOR_K_ROW_FAST_PATH_MAX, self._K_save_max): return super()._layer_forward_lin_K_nvfp4(L, h_in_K, K) if K > self.MAX_Q_SEQ: return self._thor_lin_K_dispatch(L, h_in_K, K) @@ -214,7 +220,14 @@ def _layer_forward_lin_K_nvfp4(self, L, h_in_K, K): def _layer_forward_full_K_nvfp4( self, L, h_in_K, cos_K, sin_K, cur_pos, K): - if K <= self._THOR_K_ROW_FAST_PATH_MAX: + # Delegate the whole save-steps range to parent, mirroring the + # lin dispatch above. The DFlash spec loop commits per-row + # state/KV from its S=16 verify (slot-copy rollback); rows must + # therefore come from the SAME kernel family as the K<=7 spec + # verifies, or the two paths' occasional rounding disagreements + # surface as greedy divergence (measured: bit-identical for ~39 + # cycles, then a single full-attn row event cascades). + if K <= max(self._THOR_K_ROW_FAST_PATH_MAX, self._K_save_max): return super()._layer_forward_full_K_nvfp4( L, h_in_K, cos_K, sin_K, cur_pos, K) if K > self.MAX_Q_SEQ: @@ -1105,6 +1118,8 @@ def _prefill_mtp_tail_kv_nvfp4( # the FP8-KV mode flag active. def _load_dflash_drafter(self, ckpt_dir: str | None = None) -> None: + import torch + super()._load_dflash_drafter(ckpt_dir) # Short-ctx constructions (user_max_seq <= LONG_CTX_THRESHOLD) # never allocate the persistent FP8 KV cache; the Thor DFlash @@ -1112,6 +1127,35 @@ def _load_dflash_drafter(self, ckpt_dir: str | None = None) -> None: if not hasattr(self, '_fp8_K_cache'): self._load_fp8_kv_cache(max_seq=self._user_max_seq + 16) self._long_kv_cache_mode = 'fp8' + # Grow the per-step state checkpoints to the DFlash verify + # q_seq (block_size = _MAX_PUBLIC_SPEC_K + 1). The lin K-row + # save-steps branch then covers the whole verify, and the + # partial-accept rollback becomes two constant-time copies + # instead of a second main-model forward. + needed = self._MAX_PUBLIC_SPEC_K + 1 + if self._K_save_max < needed: + self._K_save_max = needed + self._K_lin_state_per_step = torch.empty( + needed, *self._lin_state.shape, + device=self._lin_state.device, + dtype=self._lin_state.dtype) + self._K_lin_conv_state_per_step = torch.empty( + needed, *self._lin_conv_state.shape, + device=self._lin_conv_state.device, + dtype=self._lin_conv_state.dtype) + # Any K-row graph captured before the grow baked the old + # checkpoint buffers — drop those graphs so they re-capture + # against the new allocations. + for cache_name in ( + '_captured_verify_graphs_fp8kv', + '_captured_prefill_graphs_fp8kv', + '_captured_verify_graphs_tq', + '_captured_prefill_graphs_tq', + '_captured_verify_graphs_dflash', + ): + cache = getattr(self, cache_name, None) + if cache: + cache.clear() def _dflash_prefill_nvfp4(self, input_ids): """Thor override: chunked FP8-KV prompt prefill. @@ -1134,3 +1178,40 @@ def _dflash_verify_forward_K(self, token_ids_K, cos_K, sin_K, """ return self.forward_own_decode_K_nvfp4_fp8kv( token_ids_K, cos_K, sin_K, cur_pos, K, tap_buf=tap_buf) + + def _dflash_snap_state(self, cur_pos: int, Kv: int) -> None: + """Thor override: nothing to snapshot. + + The rollback reads the per-step checkpoints written during the + verify K-row itself. The Thor verify never writes the BF16 KV + cache, and FP8 rows past the accept point are overwritten by + the next verify before any read. + """ + return + + def _dflash_partial_rollback(self, cur_pos: int, N: int, Kv: int, + tok, drafts, cos_KN, sin_KN) -> None: + """Thor override: constant-time state rollback. + + The verify at S=Kv ran the lin K-row save-steps branch + (``Kv <= _K_save_max`` after drafter load), so the state after + every verify row is checkpointed; committing N drafts is a copy + from slot N. Same pattern as the long-ctx MTP spec loop. Taps + for rows <= N are already in ``_dflash_taps_buf`` from the main + verify. + """ + import torch + + from flash_rt import flash_rt_kernels as fvk + + s = torch.cuda.current_stream().cuda_stream + fvk.gpu_copy( + self._lin_state.data_ptr(), + self._K_lin_state_per_step[N].data_ptr(), + self._lin_state.numel() * 2, s, + ) + fvk.gpu_copy( + self._lin_conv_state.data_ptr(), + self._K_lin_conv_state_per_step[N].data_ptr(), + self._lin_conv_state.numel() * 2, s, + ) From d991dc9155ab01d2c948c2990e2670196bdcbf37 Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Fri, 3 Jul 2026 14:58:50 -0400 Subject: [PATCH 04/11] feat(qwen36): per-token DFlash drafter window with prompt-tail seeding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The DFlash drafter window appended ONE fc-projected tap set per spec cycle, so its entries were ~AL committed tokens apart while the drafter attends to them as consecutive positions — starving the drafter of the context features it was trained on and capping the acceptance length well below the MTP chain's. Per-token window mode keeps a feature entry for EVERY committed token: * pertoken_window_append fc-projects the verify tap rows of all committed tokens (N+1 per cycle) and shift-writes them into a fixed-length window, outside the drafter graph. * dflash_drafter_forward_pertoken is a read-only forward over that window; graph capture needs no state snapshot/restore and there is exactly one graph per frontend. * The generate loop gains a _dflash_pertoken_window branch (default off; the shift-window path is untouched). Thor enables it at drafter load: FLASHRT_QWEN36_DFLASH_PERTOKEN (default on), FLASHRT_QWEN36_DFLASH_WINDOW (default 128). * The Thor prefill seeds the window from the prompt tail: the last min(window, prompt) tokens run as a tap-captured chunk, so the drafter starts with real context features instead of ramping from an empty window. FLASHRT_QWEN36_DFLASH_WINDOW_SEED=0 disables the seed (measured to help natural prompts and hurt only degenerate repeated-sentence prompts, where the tail features steer the drafter into repeating the prompt). Thor ctx<=128, steady state, against the FP8-KV MTP reference (greedy parity PASS on all four prompts): robot JSON plan MTP 33.7 tok/s -> DFlash 52.8 tok/s (AL 4.92) robot navigation MTP 30.5 -> DFlash 34.4 (AL 3.20) explain (prose) MTP 28.6 -> DFlash 30.8 (AL 2.87) repeated academic sentence: DFlash 22.9 (drafter is fed its own degenerate context; disable the seed to recover 4.27 AL) --- .../torch/_qwen36_rtx_dflash_forward.py | 139 ++++++++++++++++++ flash_rt/frontends/torch/qwen36_rtx.py | 76 +++++++++- flash_rt/frontends/torch/qwen36_thor.py | 47 +++++- 3 files changed, 258 insertions(+), 4 deletions(-) diff --git a/flash_rt/frontends/torch/_qwen36_rtx_dflash_forward.py b/flash_rt/frontends/torch/_qwen36_rtx_dflash_forward.py index ca77de4d..147ebd67 100644 --- a/flash_rt/frontends/torch/_qwen36_rtx_dflash_forward.py +++ b/flash_rt/frontends/torch/_qwen36_rtx_dflash_forward.py @@ -799,3 +799,142 @@ def dflash_drafter_forward_capture(frontend) -> torch.Tensor: buf['logits'].data_ptr(), M, VOCAB, H, s, widen=True) return buf['logits'] + + +# ==================================================================== +# Per-token window variant +# ==================================================================== +# +# The shift-window above appends ONE fc-projected tap set per spec +# cycle, so window entries are ~AL committed tokens apart while the +# drafter attends to them at consecutive positions. The per-token +# variant keeps a window of features for EVERY committed token: the +# orchestration appends N+1 entries after each accept (and seeds the +# window from the prompt tail at prefill), and the drafter forward +# below only READS the window — no fc, no shift — which also makes the +# graph capture side-effect free. + +def alloc_pertoken_window(frontend, win: int) -> None: + """Allocate the per-token feature window + append scratch.""" + buf = frontend._dflash_buf + if buf.get('pt_window') is not None and buf['pt_win'] == win: + return + if win > buf['max_ctx']: + raise ValueError( + f'window {win} exceeds drafter max_ctx {buf["max_ctx"]}') + H = buf['hidden'] + dev = frontend.device + buf['pt_window'] = torch.zeros( + win, H, dtype=torch.bfloat16, device=dev) + buf['pt_shift_scratch'] = torch.empty_like(buf['pt_window']) + buf['pt_proj_out'] = torch.empty( + buf['max_ctx'], H, dtype=torch.bfloat16, device=dev) + buf['pt_taps_rows'] = torch.empty( + max(buf['block'], win), 5, H, dtype=torch.bfloat16, device=dev) + buf['pt_seed_taps'] = torch.empty( + 5, win, H, dtype=torch.bfloat16, device=dev) + buf['pt_win'] = win + buf['pt_valid'] = 0 + + +def reset_pertoken_window(frontend) -> None: + """Clear per-token window state. Call at the start of a generate.""" + buf = frontend._dflash_buf + if buf.get('pt_window') is not None: + buf['pt_window'].zero_() + buf['pt_valid'] = 0 + + +def pertoken_window_append(frontend, taps_rows) -> None: + """Append fc-projected features of R committed rows to the window. + + taps_rows: (R, 5, hidden) bf16 — verify tap_buf rows of the + committed tokens, oldest first. Shift-left by R, write the R new + features at the tail. Runs eagerly on the current stream, outside + the drafter graph. + """ + from flash_rt import flash_rt_kernels as fvk + + buf = frontend._dflash_buf + d = frontend._weights.ptrs['dflash'] + s = torch.cuda.current_stream().cuda_stream + H = buf['hidden'] + FC_IN = buf['fc_in'] + eps = float(d['rms_norm_eps']) + win = buf['pt_window'] + W = buf['pt_win'] + R = int(taps_rows.shape[0]) + if R > W: + taps_rows = taps_rows[-W:] + R = W + + x = taps_rows.reshape(R, FC_IN).contiguous() + ap_t, sf_t = buf['act_Mctx_K5120'] + _quant_act(fvk, x, ap_t, sf_t, R, FC_IN, s) + _gemm_nvfp4(fvk, ap_t.data_ptr(), sf_t.data_ptr(), + d['fc_packed'], d['fc_sf'], d['fc_alpha'], + buf['pt_proj_out'].data_ptr(), R, H, FC_IN, s) + if R < W: + scratch = buf['pt_shift_scratch'] + scratch[:W - R].copy_(win[R:]) + win[:W - R].copy_(scratch[:W - R]) + fvk.rms_norm( + int(buf['pt_proj_out'].data_ptr()), int(d['hidden_norm_w']), + int(win[W - R:W].data_ptr()), + R, H, eps, int(s), + ) + buf['pt_valid'] = min(buf['pt_valid'] + R, W) + + +def dflash_drafter_forward_pertoken(frontend, + valid_ctx: int | None = None): + """Drafter forward over the per-token window (read-only). + + valid_ctx: number of valid tail rows to attend to. None means the + full window — the shape the captured graph bakes in. Callers pass + the actual valid count during ramp-up (window not yet full). + + Returns: logits (block, vocab) bf16 in buf['logits']. + """ + from flash_rt import flash_rt_kernels as fvk + + s = torch.cuda.current_stream().cuda_stream + buf = frontend._dflash_buf + d = frontend._weights.ptrs['dflash'] + M = buf['block'] + H = buf['hidden'] + VOCAB = buf['vocab'] + eps = float(d['rms_norm_eps']) + W = buf['pt_win'] + ctx_len = W if valid_ctx is None else int(valid_ctx) + if not (1 <= ctx_len <= W): + raise ValueError(f'valid_ctx={ctx_len} out of [1, {W}]') + win = buf['pt_window'][W - ctx_len:W] + + fvk.qwen36_embedding_lookup_bf16( + buf['ids_static'].data_ptr(), + int(frontend._weights.ptrs['embed_w']), + buf['embed_buf'].data_ptr(), M, H, s, + ) + fvk.gpu_copy( + buf['h_b'].data_ptr(), buf['embed_buf'].data_ptr(), + M * H * 2, s, + ) + h = buf['h_b'] + for L in range(buf['n_layers']): + h = _drafter_layer_forward( + frontend, fvk, L, h, win, ctx_len, s) + fvk.rms_norm( + int(h.data_ptr()), int(d['final_norm_w']), + int(buf['h_final_norm'].data_ptr()), + M, H, eps, int(s), + ) + ap_lm, sf_lm = buf['act_M16_K5120'] + _quant_act(fvk, buf['h_final_norm'], ap_lm, sf_lm, M, H, s) + _gemm_nvfp4(fvk, ap_lm.data_ptr(), sf_lm.data_ptr(), + frontend._weights.ptrs['lm_head_packed'], + frontend._weights.ptrs['lm_head_sf'], + frontend._weights.ptrs['lm_head_alpha'], + buf['logits'].data_ptr(), + M, VOCAB, H, s, widen=True) + return buf['logits'] diff --git a/flash_rt/frontends/torch/qwen36_rtx.py b/flash_rt/frontends/torch/qwen36_rtx.py index b4742200..e5c824f2 100644 --- a/flash_rt/frontends/torch/qwen36_rtx.py +++ b/flash_rt/frontends/torch/qwen36_rtx.py @@ -11817,6 +11817,40 @@ def _dflash_partial_rollback(self, cur_pos: int, N: int, Kv: int, rg.replay() torch.cuda.current_stream().wait_stream(gs) + def _ensure_drafter_graph_dflash_pertoken(self): + """Lazy CUDA Graph for the per-token-window drafter forward. + + The forward is read-only over the window (updates happen + outside the graph via ``pertoken_window_append``), so capture + needs no state snapshot/restore. One graph per frontend — the + window length is fixed at alloc time. + """ + import torch + + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( + dflash_drafter_forward_pertoken, + ) + + g = getattr(self, '_captured_drafter_graph_pertoken', None) + if g is not None: + return g + + gs = self._graph_stream + gs.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(gs), torch.no_grad(): + for _ in range(2): + dflash_drafter_forward_pertoken(self) + gs.synchronize() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph( + g, stream=gs, pool=self._graph_mempool, + ), torch.no_grad(): + dflash_drafter_forward_pertoken(self) + gs.synchronize() + torch.cuda.current_stream().wait_stream(gs) + self._captured_drafter_graph_pertoken = g + return g + def _ensure_verify_graph_dflash_nvfp4(self, cur_pos: int, K: int): """Lazy CUDA Graph for the DFlash verify forward WITH tap_buf. @@ -11930,6 +11964,20 @@ def generate_own_speculative_DFlash_nvfp4( eff_ctx = int(getattr(self, '_dflash_eff_ctx', 16)) alloc_drafter_capture_window(self, eff_ctx) reset_drafter_capture_state(self) + # Per-token window mode: the drafter attends to fc-projected + # features of every committed token instead of one entry per + # spec cycle. The prefill hook may seed the window from the + # prompt tail. + pertoken = bool(getattr(self, '_dflash_pertoken_window', False)) + if pertoken: + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( # noqa: E501 + alloc_pertoken_window, + pertoken_window_append, + reset_pertoken_window, + ) + alloc_pertoken_window( + self, int(getattr(self, '_dflash_pertoken_win', 128))) + reset_pertoken_window(self) # Initialize taps to zero — first drafter call gets no real # signal; AL on cycle 0 will be lower than steady-state. self._dflash_taps_buf.zero_() @@ -11964,15 +12012,29 @@ def generate_own_speculative_DFlash_nvfp4( # (avoids zero-dilution that hurts AL). Once the window # is full, replay the captured graph. self._dflash_buf['ids_static'][0:1].copy_(tok.view(1)) - self._dflash_buf['hidden_taps_static'].copy_( - self._dflash_taps_buf[:, 0]) - if self._spec_attempts < eff_ctx: + if pertoken: + valid = int(self._dflash_buf['pt_valid']) + if valid < int(self._dflash_buf['pt_win']): + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( # noqa: E501 + dflash_drafter_forward_pertoken, + ) + dflash_drafter_forward_pertoken( + self, max(1, valid)) + else: + drafter_g = ( + self._ensure_drafter_graph_dflash_pertoken()) + drafter_g.replay() + elif self._spec_attempts < eff_ctx: from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( # noqa: E501 dflash_drafter_forward_capture_eager, ) + self._dflash_buf['hidden_taps_static'].copy_( + self._dflash_taps_buf[:, 0]) valid_ctx = self._spec_attempts + 1 dflash_drafter_forward_capture_eager(self, valid_ctx) else: + self._dflash_buf['hidden_taps_static'].copy_( + self._dflash_taps_buf[:, 0]) drafter_g = self._ensure_drafter_graph_dflash_nvfp4( eff_ctx) drafter_g.replay() @@ -12040,6 +12102,14 @@ def generate_own_speculative_DFlash_nvfp4( self._dflash_taps_buf[:, 0].copy_( self._dflash_taps_buf[:, N]) cur_pos += N + 1 + if pertoken: + # Window gains one feature per state-advanced row + # (N+1 on partial accept, Kv on full accept). + R = (Kv if N == K else N + 1) + rows = self._dflash_buf['pt_taps_rows'][:R] + rows.copy_( + self._dflash_taps_buf[:, :R].permute(1, 0, 2)) + pertoken_window_append(self, rows) if len(generated) > max_new_tokens: generated = generated[:max_new_tokens] diff --git a/flash_rt/frontends/torch/qwen36_thor.py b/flash_rt/frontends/torch/qwen36_thor.py index bbca785a..586f24b3 100644 --- a/flash_rt/frontends/torch/qwen36_thor.py +++ b/flash_rt/frontends/torch/qwen36_thor.py @@ -1156,6 +1156,16 @@ def _load_dflash_drafter(self, ckpt_dir: str | None = None) -> None: cache = getattr(self, cache_name, None) if cache: cache.clear() + # Per-token drafter window (default on for Thor): the drafter + # attends to fc-projected features of every committed token. + # Measured on Thor at ctx=128: steady AL 2.53 -> 3.49 vs the + # one-entry-per-cycle shift window. + if not hasattr(self, '_dflash_pertoken_window'): + self._dflash_pertoken_window = os.environ.get( + 'FLASHRT_QWEN36_DFLASH_PERTOKEN', '1', + ).strip().lower() not in ('0', 'false', 'off') + self._dflash_pertoken_win = int(os.environ.get( + 'FLASHRT_QWEN36_DFLASH_WINDOW', '128') or '128') def _dflash_prefill_nvfp4(self, input_ids): """Thor override: chunked FP8-KV prompt prefill. @@ -1164,8 +1174,43 @@ def _dflash_prefill_nvfp4(self, input_ids): the Thor verify attends over the FP8 cache, so the prompt rows must land there. The chunked prefill is also the production Thor TTFT path (batched XQA instead of one forward per token). + + In per-token-window mode the last min(window, prompt) tokens + run as a separate tap-captured chunk so the drafter window + starts seeded with the prompt tail's features instead of + ramping from empty. """ - _, logits = self._prefill_long_ctx_tq_chunked(input_ids) + seed_window = ( + getattr(self, '_dflash_pertoken_window', False) + and os.environ.get( + 'FLASHRT_QWEN36_DFLASH_WINDOW_SEED', '1', + ).strip().lower() not in ('0', 'false', 'off')) + if not seed_window: + _, logits = self._prefill_long_ctx_tq_chunked(input_ids) + return logits.argmax(dim=-1, keepdim=True).view(1, 1) + + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( + alloc_pertoken_window, + pertoken_window_append, + ) + + alloc_pertoken_window( + self, int(getattr(self, '_dflash_pertoken_win', 128))) + buf = self._dflash_buf + P = int(input_ids.shape[1]) + tail = min(int(buf['pt_win']), P) + if P > tail: + self._prefill_long_ctx_tq_chunked(input_ids[:, :P - tail]) + d = self._rope_dim + cos_T = self._rope_cos_table[P - tail:P].view(1, tail, d) + sin_T = self._rope_sin_table[P - tail:P].view(1, tail, d) + seed = buf['pt_seed_taps'] + logits = self.forward_own_decode_K_nvfp4_fp8kv( + input_ids[:, P - tail:], cos_T, sin_T, P - tail, tail, + tap_buf=seed, logits_mode='last') + rows = buf['pt_taps_rows'][:tail] + rows.copy_(seed[:, :tail].permute(1, 0, 2)) + pertoken_window_append(self, rows) return logits.argmax(dim=-1, keepdim=True).view(1, 1) def _dflash_verify_forward_K(self, token_ids_K, cos_K, sin_K, From 2f66a6105963be3f985f6c841268e7e3291bb615 Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Fri, 3 Jul 2026 15:20:10 -0400 Subject: [PATCH 05/11] docs(qwen36): DFlash usage, window knobs, and Thor performance Add docs/qwen36_dflash.md covering the drafter checkpoint, the generate entry point, the per-token context window and its env knobs, measured Thor numbers against the FP8-KV MTP reference, and the benchmark caveats (seed vs verbatim-repeated prompts, KV-format-matched parity references). Register the new env vars in qwen36_usage.md, correct the stale init_dflash_drafter reference, and add the README news entry. --- README.md | 1 + docs/qwen36_dflash.md | 118 ++++++++++++++++++++++++++++++++++++++++++ docs/qwen36_usage.md | 5 +- 3 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 docs/qwen36_dflash.md diff --git a/README.md b/README.md index a503c5f4..5c04b544 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ See [Supported Models](#supported-models), [Hardware Support](#hardware-support) ## News +- [2026/07] **DFlash block-diffusion speculative decoding** for Qwen3.6-27B NVFP4 on Jetson AGX Thor: one 15-token drafter block per cycle, constant-time partial-accept rollback, and a per-token drafter context window with prompt-tail seeding — **52.8 tok/s on structured robot-plan prompts vs 33.7 tok/s for the MTP chain (+57%)**, lossless greedy output. See [Qwen3.6 DFlash](docs/qwen36_dflash.md). - [2026/06] **Higgs Audio v3 TTS-4B** lands on FlashRT with a kernelized FP8/BF16 decode path, streaming-friendly generation API, and a FastAPI serving host. See [Higgs usage](docs/higgs_audio_v3.md#3-quickstart), [Higgs performance](docs/higgs_audio_v3.md#performance), and [Higgs serving](serving/higgs_audio_agent/README.md). - [2026/06] **FlashRT HF Kernels** are available as Hugging Face Kernel Hub packages under the `flashrt` namespace. See [LiangSu8899/FlashRT-HF-kernels](https://github.com/LiangSu8899/FlashRT-HF-kernels) and [huggingface.co/flashrt](https://huggingface.co/flashrt). - [2026/06] The `serving/` layer is documented as the scenario-host layer for OpenAI-compatible LLM/audio serving and robot execution-state hosts. See [serving README](serving/README.md), [serving design](docs/serving_design.md), and [architecture](docs/architecture.md). diff --git a/docs/qwen36_dflash.md b/docs/qwen36_dflash.md new file mode 100644 index 00000000..4ef9848b --- /dev/null +++ b/docs/qwen36_dflash.md @@ -0,0 +1,118 @@ +# Qwen3.6-27B DFlash Speculative Decoding + +This document covers the DFlash block-diffusion drafter path for +Qwen3.6-27B NVFP4. DFlash replaces the sequential MTP draft chain with +a single drafter forward per speculation cycle: a 5-layer 2B drafter +proposes an entire 15-token block, and the target model verifies the +block in one S=16 forward. + +For the general Qwen3.6 NVFP4 model contract and parameter reference, +see [`qwen36_nvfp4.md`](qwen36_nvfp4.md) and +[`qwen36_usage.md`](qwen36_usage.md). + +## Requirements + +- Qwen3.6-27B NVFP4 main checkpoint (same as the MTP path). +- The z-lab DFlash drafter checkpoint: + +```bash +hf download z-lab/Qwen3.6-27B-DFlash --local-dir /models/Qwen3.6-27B-DFlash +``` + + The drafter ships as a single BF16 `model.safetensors` (~3.3 GB, + 5 layers, `block_size=16`, target hidden taps at layers + 1/16/31/46/61). FlashRT quantizes every drafter linear to NVFP4 at + load time (~825 MB resident); no separate conversion step. +- On Thor the DFlash verify runs over the persistent FP8 KV cache. + The frontend allocates it automatically at drafter load if the + construction did not already enable long-context mode. + +## Usage + +```python +import os + +from flash_rt.frontends.torch.qwen36_thor import Qwen36TorchFrontendThor + +os.environ["FLASHRT_QWEN36_MTP_CKPT_DIR"] = "/models/Qwen3.6-27B-FP8" +os.environ["FLASHRT_QWEN36_DFLASH_CKPT_DIR"] = "/models/Qwen3.6-27B-DFlash" +os.environ["FLASHRT_QWEN36_LONG_KV_CACHE"] = "fp8" + +fe = Qwen36TorchFrontendThor( + "/models/Qwen3.6-27B-NVFP4", + quant="nvfp4", + max_seq=32768, +) +fe._load_dflash_drafter() # reads FLASHRT_QWEN36_DFLASH_CKPT_DIR + +ids = fe._tokenizer.apply_chat_template( + [{"role": "user", "content": "Plan the pick-and-place task."}], + add_generation_prompt=True, return_tensors="pt").to(fe.device) + +out = fe.generate_own_speculative_DFlash_nvfp4( + ids, + max_new_tokens=256, + K=15, # speculative tokens per cycle +) +``` + +The RTX frontend exposes the same entry point; the drafter and verify +kernels are shared, only the KV plumbing differs per arch. + +## Drafter context window + +The drafter conditions on fc-projected target hidden features of the +committed context. Two window modes exist: + +- **Per-token window** (Thor default): one feature entry per committed + token, appended in bulk after each verify (N+1 entries per cycle). + On Thor the prompt prefill seeds the window with the features of the + last `min(window, prompt_len)` prompt tokens, so the drafter starts + at full context instead of ramping from empty. +- **Per-cycle shift window** (legacy, RTX default): one entry per + speculation cycle. Kept for compatibility; acceptance length is + measurably lower because window entries end up ~AL tokens apart. + +| Env | Default | Meaning | +|---|---|---| +| `FLASHRT_QWEN36_DFLASH_CKPT_DIR` | unset | Drafter checkpoint directory (required). | +| `FLASHRT_QWEN36_DFLASH_PERTOKEN` | `1` on Thor | Per-token window mode. | +| `FLASHRT_QWEN36_DFLASH_WINDOW` | `128` | Per-token window length (tokens, <= 256). | +| `FLASHRT_QWEN36_DFLASH_WINDOW_SEED` | `1` | Seed the window from the prompt tail at prefill (Thor). | + +## Measured performance (Thor, SM110) + +Steady-state decode at short context against the FP8-KV MTP spec path +(`generate_own_speculative_KN_nvfp4`, K=6) in the same process, greedy +decoding, 64/256-token delta method: + +| prompt | MTP AL / tok/s | DFlash AL / tok/s | +|---|---:|---:| +| robot task -> JSON plan | 2.87 / 33.7 | **4.92 / 52.8** | +| robot navigation plan | 2.59 / 30.5 | 3.20 / 34.4 | +| prose explanation | 2.43 / 28.6 | 2.87 / 30.8 | + +Cycle anatomy on Thor: one S=16 verify (~86 ms, weight-read bound) + +one drafter graph replay (~7 ms). A partial accept costs two +constant-time state copies from the per-step checkpoints written +during the verify itself — there is no recovery forward. + +Output quality is lossless: the verify pass is the greedy ground +truth, and generated tokens are byte-identical to the FP8-KV MTP +reference on all measured prompts. + +## Notes + +- Structured output (JSON plans, code) accepts much better than free + prose; the gains above track the drafter's training distribution. +- Degenerate prompts that repeat one sentence verbatim can steer the + seeded window into drafting more repetition. If you benchmark with + synthetic repeated text, disable the seed + (`FLASHRT_QWEN36_DFLASH_WINDOW_SEED=0`) for representative numbers. +- Greedy-parity comparisons must use the FP8-KV MTP route as the + reference (`FLASHRT_QWEN36_LONG_CTX_ROUTE_MIN_SEQ=0` forces it for + short prompts). The BF16 short route stores KV in a different + format, so token-exact comparison across the two is not meaningful. +- The published drafter checkpoint is marked by z-lab as still under + training; acceptance lengths should improve by dropping in a newer + drafter checkpoint without code changes. diff --git a/docs/qwen36_usage.md b/docs/qwen36_usage.md index 308b0f08..e0123b60 100644 --- a/docs/qwen36_usage.md +++ b/docs/qwen36_usage.md @@ -186,7 +186,10 @@ frontend is built has no effect. | `FLASHRT_QWEN36_MTP_CKPT_DIR` | Required for spec decode | unset | Directory containing `mtp.safetensors` (FP8 e4m3 block-128) from a paired Qwen3.6-Next-27B-FP8 ckpt. Loaded once at construction and converted FP8 → BF16 → NVFP4. If unset, MTP is `None` and `generate_own_speculative_KN_nvfp4` raises; pure-decode still works. | | `FLASHRT_QWEN36_MTP_KEEP_BF16` | Optional | BF16-source MTP: `1`; FP8-source MTP: n/a | For community BF16/native MTP checkpoints, keep BF16 projection weights and use them in the drafter hot path. This improves MTP alignment at the cost of extra VRAM. Set `0` to force the lower-memory NVFP4-converted MTP path. | | `FLASHRT_QWEN36_HF_PATCH` | Optional | unset | Path to a HF FP8 dispatch monkey-patch script. Only consulted by the legacy FP8 path; the NVFP4 path doesn't need it. If unset or path doesn't exist, the patch step is silently skipped. | -| `FLASHRT_QWEN36_DFLASH_CKPT_DIR` | Optional | unset | Drafter ckpt directory for the DFlash add-on path. Required only if you call `init_dflash_drafter()`; raises a clear error if unset and `ckpt_dir` is also not passed. | +| `FLASHRT_QWEN36_DFLASH_CKPT_DIR` | Optional | unset | Drafter ckpt directory for the DFlash path. Required only if you call `_load_dflash_drafter()`; raises a clear error if unset and `ckpt_dir` is also not passed. See [`qwen36_dflash.md`](qwen36_dflash.md). | +| `FLASHRT_QWEN36_DFLASH_PERTOKEN` | Optional | `1` on Thor | Per-token drafter context window (one feature entry per committed token). `0` falls back to the legacy per-cycle shift window. See [`qwen36_dflash.md`](qwen36_dflash.md). | +| `FLASHRT_QWEN36_DFLASH_WINDOW` | Optional | `128` | Per-token drafter window length in tokens (max 256). | +| `FLASHRT_QWEN36_DFLASH_WINDOW_SEED` | Optional | `1` | Seed the per-token window from the prompt tail during Thor prefill. Disable for benchmarks built from verbatim-repeated text. | | `FLASHRT_QWEN36_MAX_Q_SEQ` | Optional | `2048` | Maximum S=K working-set rows for verify/prefill buffers. Long prefill chunking is additionally capped by the retained BF16 working window. | | `FLASHRT_QWEN36_LONG_CTX_BF16_WINDOW` | Optional | `min(2048, MAX_Q_SEQ)` | Retained BF16 working-window rows in long-context mode. Raising this can enable larger prompt chunks but costs substantial VRAM. | | `FLASHRT_QWEN36_LONG_CTX_ROUTE_MIN_SEQ` | Optional | `512` in long-ctx mode | Prompt length at or above which a long-context frontend routes through the chunked compressed-KV path. The measured 128-token bucket is also routed through FP8-KV to avoid the legacy one-token BF16/spec prefill. Other short prompts stay on BF16/spec unless the full request exceeds the retained BF16 window. | From bdad0721bfd7bec133b76d2cb910b55464b77603 Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Fri, 3 Jul 2026 15:44:14 -0400 Subject: [PATCH 06/11] fix(qwen36): commit per-token window rows before the tap shuffle MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The per-token window append ran AFTER the end-of-cycle taps[:, 0] <- taps[:, N] shuffle, so the first committed row's feature entered the window as a duplicate of the last accepted row on every cycle — the window was not one-feature-per-committed-token as documented. Extract the append into _dflash_window_commit(N) and call it before the shuffle; the shuffle itself is hoisted out of the accept branches (N == K on a full accept, so a single taps[:, N] copy covers both, byte-identical for the non-per-token default path). Add structural tests (no checkpoint / no GPU): window-commit row order and copy semantics, a source-order guard on the generate loop, the missing-drafter fail-fast, the public drafter-init delegation, and Thor per-token env routing (default on, opt-out, window override). Re-measured on Thor with the corrected window (parity PASS on all prompts): robot JSON 48.9 tok/s (AL 4.57) vs MTP 33.7; robot navigation 34.8 (3.25); prose 31.7 (3.00). Docs and README updated to the corrected numbers. --- README.md | 2 +- docs/qwen36_dflash.md | 32 ++++- docs/qwen36_usage.md | 2 +- flash_rt/frontends/torch/qwen36_rtx.py | 46 +++++--- tests/test_qwen36_dflash_structural.py | 154 +++++++++++++++++++++++++ 5 files changed, 216 insertions(+), 20 deletions(-) create mode 100644 tests/test_qwen36_dflash_structural.py diff --git a/README.md b/README.md index 5c04b544..fa60db7a 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ See [Supported Models](#supported-models), [Hardware Support](#hardware-support) ## News -- [2026/07] **DFlash block-diffusion speculative decoding** for Qwen3.6-27B NVFP4 on Jetson AGX Thor: one 15-token drafter block per cycle, constant-time partial-accept rollback, and a per-token drafter context window with prompt-tail seeding — **52.8 tok/s on structured robot-plan prompts vs 33.7 tok/s for the MTP chain (+57%)**, lossless greedy output. See [Qwen3.6 DFlash](docs/qwen36_dflash.md). +- [2026/07] **DFlash block-diffusion speculative decoding** for Qwen3.6-27B NVFP4 on Jetson AGX Thor: one 15-token drafter block per cycle, constant-time partial-accept rollback, and a per-token drafter context window with prompt-tail seeding — **48.9 tok/s on structured robot-plan prompts vs 33.7 tok/s for the MTP chain (+45%)**, lossless greedy output. See [Qwen3.6 DFlash](docs/qwen36_dflash.md). - [2026/06] **Higgs Audio v3 TTS-4B** lands on FlashRT with a kernelized FP8/BF16 decode path, streaming-friendly generation API, and a FastAPI serving host. See [Higgs usage](docs/higgs_audio_v3.md#3-quickstart), [Higgs performance](docs/higgs_audio_v3.md#performance), and [Higgs serving](serving/higgs_audio_agent/README.md). - [2026/06] **FlashRT HF Kernels** are available as Hugging Face Kernel Hub packages under the `flashrt` namespace. See [LiangSu8899/FlashRT-HF-kernels](https://github.com/LiangSu8899/FlashRT-HF-kernels) and [huggingface.co/flashrt](https://huggingface.co/flashrt). - [2026/06] The `serving/` layer is documented as the scenario-host layer for OpenAI-compatible LLM/audio serving and robot execution-state hosts. See [serving README](serving/README.md), [serving design](docs/serving_design.md), and [architecture](docs/architecture.md). diff --git a/docs/qwen36_dflash.md b/docs/qwen36_dflash.md index 4ef9848b..6c2f883c 100644 --- a/docs/qwen36_dflash.md +++ b/docs/qwen36_dflash.md @@ -43,7 +43,7 @@ fe = Qwen36TorchFrontendThor( quant="nvfp4", max_seq=32768, ) -fe._load_dflash_drafter() # reads FLASHRT_QWEN36_DFLASH_CKPT_DIR +fe.init_dflash_drafter() # reads FLASHRT_QWEN36_DFLASH_CKPT_DIR ids = fe._tokenizer.apply_chat_template( [{"role": "user", "content": "Plan the pick-and-place task."}], @@ -88,19 +88,41 @@ decoding, 64/256-token delta method: | prompt | MTP AL / tok/s | DFlash AL / tok/s | |---|---:|---:| -| robot task -> JSON plan | 2.87 / 33.7 | **4.92 / 52.8** | -| robot navigation plan | 2.59 / 30.5 | 3.20 / 34.4 | -| prose explanation | 2.43 / 28.6 | 2.87 / 30.8 | +| robot task -> JSON plan | 2.87 / 33.7 | **4.57 / 48.9** | +| robot navigation plan | 2.59 / 30.5 | 3.25 / 34.8 | +| prose explanation | 2.43 / 28.5 | 3.00 / 31.7 | Cycle anatomy on Thor: one S=16 verify (~86 ms, weight-read bound) + one drafter graph replay (~7 ms). A partial accept costs two constant-time state copies from the per-step checkpoints written -during the verify itself — there is no recovery forward. +during the verify itself — there is no recovery forward. The accept +decision includes one host synchronization per cycle +(`argmin().item()` on the match mask); at ~10 us it is three orders +of magnitude below the verify cost and is included in every number +above. A device-side accept loop is possible follow-up work, not a +prerequisite. Output quality is lossless: the verify pass is the greedy ground truth, and generated tokens are byte-identical to the FP8-KV MTP reference on all measured prompts. +## Serving + +A stateless OpenAI-compatible host for this path lives in +[`serving/qwen36_dflash_agent`](../serving/qwen36_dflash_agent) — +single-stream request/response serving with per-request DFlash +generation and accept-length telemetry: + +```bash +python -m serving.qwen36_dflash_agent.server \ + --checkpoint /models/Qwen3.6-27B-NVFP4 --max-seq 32768 --K 15 +curl -s http://127.0.0.1:8000/health +``` + +For long-running agent sessions (prefix reuse, tool calling, SSE +streaming) use [`serving/qwen36_agent`](../serving/qwen36_agent), +which serves the MTP spec path. + ## Notes - Structured output (JSON plans, code) accepts much better than free diff --git a/docs/qwen36_usage.md b/docs/qwen36_usage.md index e0123b60..3efc2738 100644 --- a/docs/qwen36_usage.md +++ b/docs/qwen36_usage.md @@ -186,7 +186,7 @@ frontend is built has no effect. | `FLASHRT_QWEN36_MTP_CKPT_DIR` | Required for spec decode | unset | Directory containing `mtp.safetensors` (FP8 e4m3 block-128) from a paired Qwen3.6-Next-27B-FP8 ckpt. Loaded once at construction and converted FP8 → BF16 → NVFP4. If unset, MTP is `None` and `generate_own_speculative_KN_nvfp4` raises; pure-decode still works. | | `FLASHRT_QWEN36_MTP_KEEP_BF16` | Optional | BF16-source MTP: `1`; FP8-source MTP: n/a | For community BF16/native MTP checkpoints, keep BF16 projection weights and use them in the drafter hot path. This improves MTP alignment at the cost of extra VRAM. Set `0` to force the lower-memory NVFP4-converted MTP path. | | `FLASHRT_QWEN36_HF_PATCH` | Optional | unset | Path to a HF FP8 dispatch monkey-patch script. Only consulted by the legacy FP8 path; the NVFP4 path doesn't need it. If unset or path doesn't exist, the patch step is silently skipped. | -| `FLASHRT_QWEN36_DFLASH_CKPT_DIR` | Optional | unset | Drafter ckpt directory for the DFlash path. Required only if you call `_load_dflash_drafter()`; raises a clear error if unset and `ckpt_dir` is also not passed. See [`qwen36_dflash.md`](qwen36_dflash.md). | +| `FLASHRT_QWEN36_DFLASH_CKPT_DIR` | Optional | unset | Drafter ckpt directory for the DFlash path. Required only if you call `init_dflash_drafter()`; raises a clear error if unset and `ckpt_dir` is also not passed. See [`qwen36_dflash.md`](qwen36_dflash.md). | | `FLASHRT_QWEN36_DFLASH_PERTOKEN` | Optional | `1` on Thor | Per-token drafter context window (one feature entry per committed token). `0` falls back to the legacy per-cycle shift window. See [`qwen36_dflash.md`](qwen36_dflash.md). | | `FLASHRT_QWEN36_DFLASH_WINDOW` | Optional | `128` | Per-token drafter window length in tokens (max 256). | | `FLASHRT_QWEN36_DFLASH_WINDOW_SEED` | Optional | `1` | Seed the per-token window from the prompt tail during Thor prefill. Disable for benchmarks built from verbatim-repeated text. | diff --git a/flash_rt/frontends/torch/qwen36_rtx.py b/flash_rt/frontends/torch/qwen36_rtx.py index e5c824f2..1f5cd91f 100644 --- a/flash_rt/frontends/torch/qwen36_rtx.py +++ b/flash_rt/frontends/torch/qwen36_rtx.py @@ -11614,6 +11614,14 @@ def _tq_inject_kv(self, full_rank: int, cur_pos: int, # N6-A4: DFlash spec decode (block-diffusion drafter + chain verify) # ================================================================== + def init_dflash_drafter(self, ckpt_dir: str | None = None) -> None: + """Public entry: load the DFlash drafter for spec decode. + + ``ckpt_dir`` falls back to ``FLASHRT_QWEN36_DFLASH_CKPT_DIR``. + Must be called before ``generate_own_speculative_DFlash_nvfp4``. + """ + self._load_dflash_drafter(ckpt_dir) + def _load_dflash_drafter(self, ckpt_dir: str | None = None) -> None: """Load the z-lab/Qwen3.6-27B-DFlash drafter (NVFP4 W4A16). @@ -11764,6 +11772,23 @@ def _dflash_prefill_nvfp4(self, input_ids): self._replay_pos_graph(g_pf, p) return self._logits_buf.argmax(dim=-1, keepdim=True).view(1, 1) + def _dflash_window_commit(self, N: int) -> None: + """Append the committed rows' features to the per-token window. + + Tap rows 0..N are the state-advanced verify rows + [tok, drafts[:N]], oldest first. Callers must invoke this + BEFORE the end-of-cycle taps[:, 0] shuffle, which overwrites + row 0 with row N. + """ + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( + pertoken_window_append, + ) + + R = N + 1 + rows = self._dflash_buf['pt_taps_rows'][:R] + rows.copy_(self._dflash_taps_buf[:, :R].permute(1, 0, 2)) + pertoken_window_append(self, rows) + def _dflash_snap_state(self, cur_pos: int, Kv: int) -> None: """Arch hook: snapshot state the partial-accept rollback needs. @@ -11972,7 +11997,6 @@ def generate_own_speculative_DFlash_nvfp4( if pertoken: from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( # noqa: E501 alloc_pertoken_window, - pertoken_window_append, reset_pertoken_window, ) alloc_pertoken_window( @@ -12088,9 +12112,6 @@ def generate_own_speculative_DFlash_nvfp4( if len(generated) < max_new_tokens: generated.append(argmax_at(j)) tok = argmax_at(K) - # Move taps[K] -> taps[0] for next cycle - self._dflash_taps_buf[:, 0].copy_( - self._dflash_taps_buf[:, K]) cur_pos += Kv else: for j in range(N + 1): @@ -12099,17 +12120,16 @@ def generate_own_speculative_DFlash_nvfp4( self._dflash_partial_rollback( cur_pos, N, Kv, tok, drafts, cos_KN, sin_KN) tok = argmax_at(N) - self._dflash_taps_buf[:, 0].copy_( - self._dflash_taps_buf[:, N]) cur_pos += N + 1 if pertoken: - # Window gains one feature per state-advanced row - # (N+1 on partial accept, Kv on full accept). - R = (Kv if N == K else N + 1) - rows = self._dflash_buf['pt_taps_rows'][:R] - rows.copy_( - self._dflash_taps_buf[:, :R].permute(1, 0, 2)) - pertoken_window_append(self, rows) + # Must precede the taps[:, 0] shuffle below — it + # reads tap rows 0..N and the shuffle overwrites + # row 0. + self._dflash_window_commit(N) + # Move taps[N] -> taps[0] as the next drafter input + # (N == K on a full accept). + self._dflash_taps_buf[:, 0].copy_( + self._dflash_taps_buf[:, N]) if len(generated) > max_new_tokens: generated = generated[:max_new_tokens] diff --git a/tests/test_qwen36_dflash_structural.py b/tests/test_qwen36_dflash_structural.py new file mode 100644 index 00000000..95338582 --- /dev/null +++ b/tests/test_qwen36_dflash_structural.py @@ -0,0 +1,154 @@ +"""Structural tests for the Qwen3.6 DFlash spec-decode path. + +These run without model checkpoints or a GPU: they validate the +contracts that hardware benchmarks cannot guard cheaply — + + * the per-token window commit reads tap rows 0..N BEFORE the + end-of-cycle taps[:, 0] shuffle overwrites row 0, and stores a + copy (later tap mutation must not alias into the window input); + * the spec-decode loop keeps that ordering (source-order guard); + * generate fails fast with a clear error when no drafter is loaded; + * the public ``init_dflash_drafter`` wrapper delegates to the + loader; + * Thor's per-token window env routing (default on, opt-out, window + length override). + +GPU/end-to-end evidence for this path lives in the hardware-gated +benchmarks; see docs/qwen36_dflash.md. +""" + +from __future__ import annotations + +import inspect + +import pytest + +torch = pytest.importorskip("torch") + +from flash_rt.frontends.torch import _qwen36_rtx_dflash_forward as dff # noqa: E402 +from flash_rt.frontends.torch.qwen36_rtx import ( # noqa: E402 + Qwen36TorchFrontendRtx, +) +from flash_rt.frontends.torch.qwen36_thor import ( # noqa: E402 + Qwen36TorchFrontendThor, +) + + +HIDDEN = 8 +KV = 16 + + +def _stub_rtx(): + fe = Qwen36TorchFrontendRtx.__new__(Qwen36TorchFrontendRtx) + taps = torch.zeros(5, KV, HIDDEN) + for row in range(KV): + taps[:, row] = row + 1 + fe._dflash_taps_buf = taps + fe._dflash_buf = { + "pt_taps_rows": torch.zeros(KV, 5, HIDDEN), + } + return fe + + +def test_window_commit_reads_rows_before_shuffle(monkeypatch): + fe = _stub_rtx() + seen = [] + monkeypatch.setattr( + dff, "pertoken_window_append", + lambda frontend, rows: seen.append(rows)) + + N = 3 + fe._dflash_window_commit(N) + + assert len(seen) == 1 + rows = seen[0] + assert rows.shape == (N + 1, 5, HIDDEN) + # Row order: oldest committed row first, values 1..N+1 per the + # stub filling — row 0 must be the ORIGINAL row 0, not row N. + expect = torch.tensor([1.0, 2.0, 3.0, 4.0]) + assert torch.equal(rows[:, 0, 0], expect) + + # The end-of-cycle shuffle overwrites tap row 0 with row N; the + # committed rows must be a copy, not a view into the tap buffer. + fe._dflash_taps_buf[:, 0].copy_(fe._dflash_taps_buf[:, N]) + assert torch.equal(rows[:, 0, 0], expect) + + +def test_window_commit_full_accept_covers_all_rows(monkeypatch): + fe = _stub_rtx() + seen = [] + monkeypatch.setattr( + dff, "pertoken_window_append", + lambda frontend, rows: seen.append(rows)) + + fe._dflash_window_commit(KV - 1) + assert seen[0].shape[0] == KV + assert torch.equal( + seen[0][:, 0, 0], torch.arange(1.0, KV + 1)) + + +def test_generate_loop_commits_window_before_tap_shuffle(): + src = inspect.getsource( + Qwen36TorchFrontendRtx.generate_own_speculative_DFlash_nvfp4) + commit = src.index("_dflash_window_commit") + shuffle = src.index( + "_dflash_taps_buf[:, 0].copy_", commit) + assert commit < shuffle, ( + "the per-token window must be committed before the taps[:, 0] " + "shuffle overwrites row 0") + + +def test_generate_fails_fast_without_drafter(): + fe = Qwen36TorchFrontendRtx.__new__(Qwen36TorchFrontendRtx) + + class _Weights: + ptrs = {} + + fe._weights = _Weights() + with pytest.raises(RuntimeError, match="DFlash drafter not loaded"): + fe.generate_own_speculative_DFlash_nvfp4( + torch.zeros(1, 4, dtype=torch.long), max_new_tokens=4) + + +def test_public_drafter_init_delegates(monkeypatch): + fe = Qwen36TorchFrontendRtx.__new__(Qwen36TorchFrontendRtx) + calls = [] + monkeypatch.setattr( + Qwen36TorchFrontendRtx, "_load_dflash_drafter", + lambda self, ckpt_dir=None: calls.append(ckpt_dir)) + fe.init_dflash_drafter("/tmp/ckpt") + assert calls == ["/tmp/ckpt"] + + +def _thor_drafter_load(monkeypatch): + """Run Thor's _load_dflash_drafter with the base loader stubbed.""" + monkeypatch.setattr( + Qwen36TorchFrontendRtx, "_load_dflash_drafter", + lambda self, ckpt_dir=None: None) + fe = Qwen36TorchFrontendThor.__new__(Qwen36TorchFrontendThor) + fe._fp8_K_cache = torch.zeros(1) # skip FP8 cache allocation + fe._K_save_max = 16 # skip checkpoint-buffer grow + fe._MAX_PUBLIC_SPEC_K = 15 + fe._load_dflash_drafter() + return fe + + +def test_thor_pertoken_default_on(monkeypatch): + monkeypatch.delenv("FLASHRT_QWEN36_DFLASH_PERTOKEN", raising=False) + monkeypatch.delenv("FLASHRT_QWEN36_DFLASH_WINDOW", raising=False) + fe = _thor_drafter_load(monkeypatch) + assert fe._dflash_pertoken_window is True + assert fe._dflash_pertoken_win == 128 + + +def test_thor_pertoken_env_opt_out(monkeypatch): + monkeypatch.setenv("FLASHRT_QWEN36_DFLASH_PERTOKEN", "0") + fe = _thor_drafter_load(monkeypatch) + assert fe._dflash_pertoken_window is False + + +def test_thor_pertoken_window_env_override(monkeypatch): + monkeypatch.delenv("FLASHRT_QWEN36_DFLASH_PERTOKEN", raising=False) + monkeypatch.setenv("FLASHRT_QWEN36_DFLASH_WINDOW", "64") + fe = _thor_drafter_load(monkeypatch) + assert fe._dflash_pertoken_win == 64 From e84ec90503a82e1d64e1a44fc0a0d495702e04d6 Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Fri, 3 Jul 2026 15:44:36 -0400 Subject: [PATCH 07/11] feat(serving): stateless OpenAI-compatible host for the DFlash path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add serving/qwen36_dflash_agent: a request/response serving host for the Qwen3.6-27B DFlash spec-decode path, following the serving-layer contract — policy above the frontend, no session or KV verbs, no exec/ changes. Scope: stateless per request (full prefill each call), batch 1 with serialized requests, greedy decode, /v1/chat/completions + /v1/models + /health, frontend arch auto-detected (SM110 -> Thor, otherwise RTX). Responses carry a flashrt telemetry block with the speculation cycle count, realized accept length, and end-to-end latency. Long-running agent sessions (prefix reuse, tool calling, SSE streaming) remain the domain of serving/qwen36_agent on the MTP path; the README states the split. --- serving/qwen36_dflash_agent/README.md | 90 +++++++++++ serving/qwen36_dflash_agent/__init__.py | 1 + serving/qwen36_dflash_agent/server.py | 202 ++++++++++++++++++++++++ 3 files changed, 293 insertions(+) create mode 100644 serving/qwen36_dflash_agent/README.md create mode 100644 serving/qwen36_dflash_agent/__init__.py create mode 100644 serving/qwen36_dflash_agent/server.py diff --git a/serving/qwen36_dflash_agent/README.md b/serving/qwen36_dflash_agent/README.md new file mode 100644 index 00000000..e7acc9fb --- /dev/null +++ b/serving/qwen36_dflash_agent/README.md @@ -0,0 +1,90 @@ +# serving/qwen36_dflash_agent + +OpenAI-compatible serving host for Qwen3.6-27B NVFP4 with **DFlash +block-diffusion speculative decoding** +(see [`docs/qwen36_dflash.md`](../../docs/qwen36_dflash.md)). + +This directory is the policy layer above the FlashRT execution +contract: it owns request shaping and telemetry only, adds no session +or KV verbs to `exec/`, and keeps the frontend API untouched. + +## Scope + +| | this host | [`serving/qwen36_agent`](../qwen36_agent) | +|---|---|---| +| decode path | DFlash drafter (K=15 block) | MTP chain (K<=6) | +| session state | stateless — full prefill per request | exact-prefix reuse, capsules | +| tool calling / SSE streaming | no | yes | +| concurrency | batch 1, serialized | batch 1, scheduled sessions | + +Use this host for single-stream, short-context request/response +workloads (robot planners, structured-output services) where the +DFlash path measures fastest; use `qwen36_agent` for long-running +agent sessions. + +## Quickstart + +**Prerequisites**: FlashRT built for your GPU (`GPU_ARCH=110` on +Jetson AGX Thor), the Qwen3.6-27B NVFP4 checkpoint, the paired FP8 +MTP checkpoint (frontend construction requires it), and the DFlash +drafter checkpoint: + +```bash +hf download z-lab/Qwen3.6-27B-DFlash --local-dir /models/Qwen3.6-27B-DFlash +pip install fastapi uvicorn +``` + +**1. Start the server** + +```bash +export FLASHRT_QWEN36_MTP_CKPT_DIR=/models/Qwen3.6-27B-FP8 +export FLASHRT_QWEN36_DFLASH_CKPT_DIR=/models/Qwen3.6-27B-DFlash +export FLASHRT_QWEN36_LONG_KV_CACHE=fp8 + +python -m serving.qwen36_dflash_agent.server \ + --checkpoint /models/Qwen3.6-27B-NVFP4 \ + --max-seq 32768 --K 15 \ + --host 127.0.0.1 --port 8000 +``` + +The frontend arch is auto-detected (SM110 -> Thor, otherwise RTX); +override with `--arch thor|rtx`. + +**2. Check it is up** + +```bash +curl -s http://127.0.0.1:8000/health +# {"status":"ok","arch":"thor","path":"dflash","pertoken_window":true,...} +``` + +**3. Chat completion** + +```bash +curl -s http://127.0.0.1:8000/v1/chat/completions \ + -H 'Content-Type: application/json' -d '{ + "model": "qwen3.6-27b-dflash", + "messages": [{"role": "user", "content": + "Output a JSON action list to pick up the red cube and place it on the tray."}], + "max_tokens": 256 + }' +``` + +The response carries a `flashrt` telemetry block with the speculation +cycle count, realized accept length, and end-to-end latency. + +## Limits (v1) + +- Greedy decode only; sampling parameters are accepted and ignored. +- `stream` is not supported; responses return complete. +- The DFlash loop generates the full `max_tokens` budget and the + response is truncated at the first end token — budget generously + but not extravagantly. +- Qwen thinking mode is off by default; pass `"enable_thinking": true` + to opt in. + +## Tuning + +DFlash env knobs (`FLASHRT_QWEN36_DFLASH_PERTOKEN`, `..._WINDOW`, +`..._WINDOW_SEED`) are documented in +[`docs/qwen36_dflash.md`](../../docs/qwen36_dflash.md) together with +measured Thor performance. diff --git a/serving/qwen36_dflash_agent/__init__.py b/serving/qwen36_dflash_agent/__init__.py new file mode 100644 index 00000000..43c5a266 --- /dev/null +++ b/serving/qwen36_dflash_agent/__init__.py @@ -0,0 +1 @@ +"""OpenAI-compatible serving host for Qwen3.6-27B DFlash spec decode.""" diff --git a/serving/qwen36_dflash_agent/server.py b/serving/qwen36_dflash_agent/server.py new file mode 100644 index 00000000..76728be5 --- /dev/null +++ b/serving/qwen36_dflash_agent/server.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +"""FlashRT — Qwen3.6-27B DFlash OpenAI-compatible serving host. + +Serves /v1/chat/completions backed by the DFlash block-diffusion +speculative-decode path (`generate_own_speculative_DFlash_nvfp4`). +This is the policy layer above the FlashRT execution contract: it owns +request shaping and telemetry only, and adds no session or KV verbs. + +Scope (v1): + * Stateless per request — every call prefills the full prompt. + For long-running agent sessions with prefix reuse, tool calling, + and committed-token streaming, use ``serving/qwen36_agent``. + * Batch size 1; concurrent requests are serialized on one GPU. + * Greedy decode only — sampling parameters are accepted and ignored. + * The DFlash loop generates the full ``max_tokens`` budget; the + response is truncated at the first end token during detokenize. + +Usage: + pip install fastapi uvicorn + + export FLASHRT_QWEN36_MTP_CKPT_DIR=/models/Qwen3.6-27B-FP8 + export FLASHRT_QWEN36_DFLASH_CKPT_DIR=/models/Qwen3.6-27B-DFlash + export FLASHRT_QWEN36_LONG_KV_CACHE=fp8 + + python -m serving.qwen36_dflash_agent.server \\ + --checkpoint /models/Qwen3.6-27B-NVFP4 \\ + --max-seq 32768 --K 15 --port 8000 +""" +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import time +import uuid +from typing import Any, Dict, List, Optional + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', +) +log = logging.getLogger('qwen36_dflash_server') + + +def _build_frontend(args): + import torch + + cap = torch.cuda.get_device_capability() + arch = args.arch + if arch == 'auto': + arch = 'thor' if cap == (11, 0) else 'rtx' + if arch == 'thor': + from flash_rt.frontends.torch.qwen36_thor import ( + Qwen36TorchFrontendThor as Frontend, + ) + else: + from flash_rt.frontends.torch.qwen36_rtx import ( + Qwen36TorchFrontendRtx as Frontend, + ) + log.info('loading %s frontend (sm %s), checkpoint=%s', + arch, cap, args.checkpoint) + fe = Frontend(args.checkpoint, quant='nvfp4', max_seq=args.max_seq) + fe.init_dflash_drafter(args.dflash_checkpoint or None) + log.info('DFlash drafter ready (pertoken=%s window=%s)', + getattr(fe, '_dflash_pertoken_window', False), + getattr(fe, '_dflash_pertoken_win', None)) + return fe, arch + + +def _chat_ids(fe, messages: List[Dict[str, Any]], enable_thinking: bool): + return fe._tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=enable_thinking, + return_tensors='pt', + ).to(fe.device) + + +def create_app(args): + from fastapi import FastAPI, HTTPException + + fe, arch = _build_frontend(args) + tok = fe._tokenizer + end_ids = {tid for tid in ( + tok.eos_token_id, + tok.convert_tokens_to_ids('<|im_end|>'), + ) if isinstance(tid, int) and tid >= 0} + + app = FastAPI(title='FlashRT Qwen3.6 DFlash server') + gpu_lock = asyncio.Lock() + state = {'requests': 0} + + @app.get('/health') + async def health(): + return { + 'status': 'ok', + 'arch': arch, + 'path': 'dflash', + 'max_seq': args.max_seq, + 'K': args.K, + 'pertoken_window': bool( + getattr(fe, '_dflash_pertoken_window', False)), + 'window': getattr(fe, '_dflash_pertoken_win', None), + 'requests_served': state['requests'], + } + + @app.get('/v1/models') + async def models(): + return {'object': 'list', 'data': [{ + 'id': args.model_name, 'object': 'model', + 'owned_by': 'flashrt'}]} + + @app.post('/v1/chat/completions') + async def chat(body: Dict[str, Any]): + import torch + + messages = body.get('messages') + if not messages: + raise HTTPException(400, 'messages is required') + max_tokens = int(body.get('max_tokens') or args.default_max_tokens) + max_tokens = max(1, min(max_tokens, args.max_tokens_cap)) + enable_thinking = bool(body.get('enable_thinking', False)) + + async with gpu_lock: + t0 = time.perf_counter() + ids = _chat_ids(fe, messages, enable_thinking) + prompt_len = int(ids.shape[1]) + if prompt_len + max_tokens > args.max_seq: + raise HTTPException( + 400, f'prompt ({prompt_len}) + max_tokens ' + f'({max_tokens}) exceeds max_seq ({args.max_seq})') + out = await asyncio.to_thread( + fe.generate_own_speculative_DFlash_nvfp4, + ids, max_new_tokens=max_tokens, K=args.K) + torch.cuda.synchronize() + dt = time.perf_counter() - t0 + + new_ids = out[0, prompt_len:].tolist() + for i, t in enumerate(new_ids): + if t in end_ids: + new_ids = new_ids[:i] + break + text = tok.decode(new_ids, skip_special_tokens=True) + attempts = int(getattr(fe, '_spec_attempts', 0)) + state['requests'] += 1 + return { + 'id': f'chatcmpl-{uuid.uuid4().hex[:24]}', + 'object': 'chat.completion', + 'created': int(time.time()), + 'model': args.model_name, + 'choices': [{ + 'index': 0, + 'message': {'role': 'assistant', 'content': text}, + 'finish_reason': ( + 'stop' if len(new_ids) < max_tokens else 'length'), + }], + 'usage': { + 'prompt_tokens': prompt_len, + 'completion_tokens': len(new_ids), + 'total_tokens': prompt_len + len(new_ids), + }, + 'flashrt': { + 'path': 'dflash', + 'spec_cycles': attempts, + 'accept_length': ( + round(len(new_ids) / attempts, 2) if attempts else None), + 'e2e_ms': round(dt * 1e3, 1), + }, + } + + return app + + +def main() -> int: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument('--checkpoint', required=True, + help='Qwen3.6-27B NVFP4 checkpoint directory') + p.add_argument('--dflash-checkpoint', default='', + help='DFlash drafter directory (default: ' + 'FLASHRT_QWEN36_DFLASH_CKPT_DIR)') + p.add_argument('--model-name', default='qwen3.6-27b-dflash') + p.add_argument('--arch', choices=['auto', 'thor', 'rtx'], + default='auto') + p.add_argument('--max-seq', type=int, default=32768) + p.add_argument('--K', type=int, default=15, + help='speculative tokens per cycle (block_size - 1)') + p.add_argument('--default-max-tokens', type=int, default=256) + p.add_argument('--max-tokens-cap', type=int, default=4096) + p.add_argument('--host', default='127.0.0.1') + p.add_argument('--port', type=int, default=8000) + args = p.parse_args() + + os.environ.setdefault('FLASHRT_QWEN36_LONG_KV_CACHE', 'fp8') + + import uvicorn + uvicorn.run(create_app(args), host=args.host, port=args.port) + return 0 + + +if __name__ == '__main__': + raise SystemExit(main()) From d88b86cd67f5273714ef587bbce6eea0dc19851a Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Sat, 4 Jul 2026 06:08:04 -0400 Subject: [PATCH 08/11] feat(qwen36): relaxed thinking-phase acceptance for DFlash (opt-in) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Qwen3.6 reasons inside a block before answering, and the thinking stream dominates the token budget of short-context requests. Mirroring the TensorRT-LLM MTP relaxed-acceptance policy: inside the think block a draft is accepted when it is in the verify logits' top-k AND within a logit margin of the argmax (a raw-logit margin equals a log-prob margin), and the accepted token is the draft itself — the verify rows and per-step state already condition on the drafts, so state and KV stay consistent. Rows from the first draft that closes the think block fall back to strict argmax matching, keeping the visible answer exact-verified. Opt-in via FLASHRT_QWEN36_DFLASH_RELAXED_THINKING (default off; the strict path is byte-identical with it disabled, re-verified by greedy parity against the MTP reference). Knobs: FLASHRT_QWEN36_DFLASH_RELAXED_TOPK (3), FLASHRT_QWEN36_DFLASH_RELAXED_DELTA (1.0). The acceptance math lives in a static helper with CPU unit tests (top-k membership, margin cutoff, strict-after-close). Measured on Thor (thinking-enabled robot JSON-plan prompt, steady state): AL 3.78 -> 5.42, 40.4 -> 57.7 tok/s (+43%); a math prompt whose drafts rarely reach the top-k measured neutral. The thinking transcript is no longer token-identical to the strict run. --- flash_rt/frontends/torch/qwen36_rtx.py | 85 +++++++++++++++++++++++++- tests/test_qwen36_dflash_structural.py | 44 +++++++++++++ 2 files changed, 126 insertions(+), 3 deletions(-) diff --git a/flash_rt/frontends/torch/qwen36_rtx.py b/flash_rt/frontends/torch/qwen36_rtx.py index 1f5cd91f..c129c9db 100644 --- a/flash_rt/frontends/torch/qwen36_rtx.py +++ b/flash_rt/frontends/torch/qwen36_rtx.py @@ -11772,6 +11772,33 @@ def _dflash_prefill_nvfp4(self, input_ids): self._replay_pos_graph(g_pf, p) return self._logits_buf.argmax(dim=-1, keepdim=True).view(1, 1) + @staticmethod + def _dflash_relaxed_matches(logits_K, drafts, all_argmax, + topk: int, delta: float, close_id: int): + """Relaxed draft acceptance for the thinking phase. + + A draft row is accepted when its token is inside the verify + logits' top-``topk`` AND within ``delta`` of the argmax logit + (a raw-logit margin equals a log-prob margin). Rows from the + first draft that closes the think block fall back to strict + argmax matching so the visible answer stays exact-verified. + Returns a 0/1 tensor of shape (K,). + """ + import torch + + K = int(drafts.shape[0]) + topv, topi = torch.topk(logits_K, topk, dim=-1) + ok = ( + (topi == drafts.view(K, 1)) + & ((topv[:, :1] - topv) <= delta) + ).any(-1).long() + close_mask = drafts == close_id + if bool(close_mask.any().item()): + idx = int(close_mask.nonzero()[0].item()) + strict = (all_argmax[:K] == drafts).long() + ok[idx:] = strict[idx:] + return ok + def _dflash_window_commit(self, N: int) -> None: """Append the committed rows' features to the per-token window. @@ -12002,6 +12029,34 @@ def generate_own_speculative_DFlash_nvfp4( alloc_pertoken_window( self, int(getattr(self, '_dflash_pertoken_win', 128))) reset_pertoken_window(self) + # Relaxed acceptance for the thinking phase (opt-in; mirrors + # the TensorRT-LLM MTP policy): inside a block a draft + # is accepted when it is in the verify logits' top-k AND within + # a logit margin of the argmax; the accepted token is then the + # DRAFT (rows already condition on drafts, so state/KV stay + # consistent). Rows from the first draft that closes the think + # block fall back to strict argmax matching. Default off — the + # strict path is byte-identical with this disabled. + relaxed = None + if os.environ.get( + 'FLASHRT_QWEN36_DFLASH_RELAXED_THINKING', '0', + ).strip().lower() in ('1', 'true', 'on'): + think_open = self._tokenizer.convert_tokens_to_ids('') + think_close = self._tokenizer.convert_tokens_to_ids('') + if isinstance(think_open, int) and think_open >= 0: + relaxed = { + 'topk': max(1, int(os.environ.get( + 'FLASHRT_QWEN36_DFLASH_RELAXED_TOPK', '3'))), + 'delta': float(os.environ.get( + 'FLASHRT_QWEN36_DFLASH_RELAXED_DELTA', '1.0')), + 'open': int(think_open), + 'close': int(think_close), + } + # The chat template opens the think block at the end of the + # generation prompt, so the phase can start active. + in_think = bool( + relaxed is not None + and relaxed['open'] in input_ids[0, -8:].tolist()) # Initialize taps to zero — first drafter call gets no real # signal; AL on cycle 0 will be lower than steady-state. self._dflash_taps_buf.zero_() @@ -12094,7 +12149,14 @@ def generate_own_speculative_DFlash_nvfp4( # 2d) Argmax + accept-prefix all_argmax = logits_KN.argmax(dim=-1) # (Kv,) long - matches = (all_argmax[:K] == drafts).long() + relaxed_cycle = relaxed is not None and in_think + if relaxed_cycle: + matches = self._dflash_relaxed_matches( + logits_KN[:K], drafts, all_argmax, + relaxed['topk'], relaxed['delta'], + relaxed['close']) + else: + matches = (all_argmax[:K] == drafts).long() matches_pad = torch.cat([ matches, torch.zeros(1, device=matches.device, @@ -12105,22 +12167,39 @@ def generate_own_speculative_DFlash_nvfp4( self._spec_accepts += N argmax_at = (lambda j: all_argmax[j:j + 1].view(1, 1)) + if relaxed_cycle: + # Accepted rows commit the DRAFT token (the verify + # rows and per-step state condition on the drafts); + # the bonus row commits the argmax as usual. + commit_at = (lambda j: ( + drafts[j:j + 1].view(1, 1) if j < N + else argmax_at(j))) + else: + commit_at = argmax_at if N == K: self._spec_full += 1 for j in range(Kv): if len(generated) < max_new_tokens: - generated.append(argmax_at(j)) + generated.append(commit_at(j)) tok = argmax_at(K) cur_pos += Kv else: for j in range(N + 1): if len(generated) < max_new_tokens: - generated.append(argmax_at(j)) + generated.append(commit_at(j)) self._dflash_partial_rollback( cur_pos, N, Kv, tok, drafts, cos_KN, sin_KN) tok = argmax_at(N) cur_pos += N + 1 + if relaxed is not None: + ids = (drafts[:N].tolist() if N else []) + ids.append(int(all_argmax[N].item())) + for t in ids: + if t == relaxed['open']: + in_think = True + elif t == relaxed['close']: + in_think = False if pertoken: # Must precede the taps[:, 0] shuffle below — it # reads tap rows 0..N and the shuffle overwrites diff --git a/tests/test_qwen36_dflash_structural.py b/tests/test_qwen36_dflash_structural.py index 95338582..aef327aa 100644 --- a/tests/test_qwen36_dflash_structural.py +++ b/tests/test_qwen36_dflash_structural.py @@ -152,3 +152,47 @@ def test_thor_pertoken_window_env_override(monkeypatch): monkeypatch.setenv("FLASHRT_QWEN36_DFLASH_WINDOW", "64") fe = _thor_drafter_load(monkeypatch) assert fe._dflash_pertoken_win == 64 + + +def _relaxed(logits, drafts, topk=3, delta=1.0, close_id=99): + all_argmax = logits.argmax(dim=-1) + return Qwen36TorchFrontendRtx._dflash_relaxed_matches( + logits, drafts, all_argmax, topk, delta, close_id) + + +def test_relaxed_accepts_topk_within_margin(): + # row 0: draft is argmax; row 1: draft is 2nd-best inside margin; + # row 2: draft is 2nd-best OUTSIDE margin; row 3: draft not in topk + logits = torch.tensor([ + [5.0, 1.0, 0.0, 0.0], + [5.0, 4.5, 0.0, 0.0], + [5.0, 2.0, 0.0, 0.0], + [5.0, 4.9, 4.8, 4.7], + ]) + drafts = torch.tensor([0, 1, 1, 3]) + ok = _relaxed(logits, drafts, topk=3, delta=1.0) + assert ok.tolist() == [1, 1, 0, 0] + + +def test_relaxed_strict_after_think_close(): + # row 1 closes the think block -> rows 1+ require exact argmax + logits = torch.tensor([ + [5.0, 4.5, 0.0, 0.0], + [5.0, 4.9, 0.0, 0.0], + [5.0, 4.9, 0.0, 0.0], + ]) + drafts = torch.tensor([1, 2, 1]) # draft row 1 is close_id=2 + ok = _relaxed(logits, drafts, topk=3, delta=1.0, close_id=2) + # row 0 relaxed-accepted; row 1 (close) strict: argmax=0 != 2 -> 0; + # row 2 strict: argmax=0 != 1 -> 0 + assert ok.tolist() == [1, 0, 0] + + +def test_relaxed_strict_rows_match_argmax(): + logits = torch.tensor([ + [5.0, 4.5, 0.0], + [1.0, 6.0, 0.0], + ]) + drafts = torch.tensor([2, 1]) # row 0 closes -> strict from row 0 + ok = _relaxed(logits, drafts, topk=3, delta=10.0, close_id=2) + assert ok.tolist() == [0, 1] From a6562842edea7c34ceb63f44436eaf894801ae53 Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Sat, 4 Jul 2026 06:08:04 -0400 Subject: [PATCH 09/11] perf(qwen36): opt-in chunk-saves verify kernels for Thor DFlash MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add per-step-checkpoint variants of the Qwen3.6 chunk kernels: causal_conv1d_qwen36_update_chunk_saves_bf16 and qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16 dump the post-step state after every step (the conv window is bf16-exact in registers; the GDN kernel already rounds the carried state to bf16 each step, so slot s byte-matches the committed state of an S = s + 1 run). This serves the DFlash partial-accept rollback in one pass instead of the parent branch's per-step kernels with global state round-trips: verify 84.1 -> 79.5 ms, cycle 93.3 -> 88.8 ms (~+5% decode) at unchanged acceptance behavior. Default OFF (FLASHRT_QWEN36_THOR_LIN_CHUNK_SAVES=1 to enable): the route moves the S=8..16 verify onto the Thor kernel family while the greedy-parity reference (the MTP spec path) runs the parent family, and the families' occasional rounding disagreements surface as transcript-level divergence — measured on the repeated-sentence parity prompt. K <= 7 dispatch (the production MTP verify) is untouched either way. --- csrc/bindings.cpp | 44 ++++++ csrc/kernels/causal_conv1d_qwen36.cu | 99 ++++++++++++++ csrc/kernels/causal_conv1d_qwen36.cuh | 15 +++ csrc/kernels/gated_deltanet_qwen36.cu | 172 ++++++++++++++++++++++++ csrc/kernels/gated_deltanet_qwen36.cuh | 20 +++ flash_rt/frontends/torch/qwen36_thor.py | 118 +++++++++++----- 6 files changed, 438 insertions(+), 30 deletions(-) diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 86167975..9f09c598 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -4157,6 +4157,25 @@ PYBIND11_MODULE(flash_rt_kernels, m) { py::arg("B"), py::arg("S"), py::arg("conv_dim"), py::arg("k"), py::arg("apply_silu") = true, py::arg("stream") = 0); + m.def("causal_conv1d_qwen36_update_chunk_saves_bf16", + [](uintptr_t x, uintptr_t w, uintptr_t bias, + uintptr_t out, uintptr_t state, + uintptr_t state_steps, int64_t step_stride, + int B, int S, int conv_dim, int k, bool apply_silu, + uintptr_t stream) { + flash_rt::kernels::causal_conv1d_qwen36_update_chunk_saves_bf16( + to_ptr(x), to_ptr(w), + bias ? to_ptr(bias) : nullptr, + to_ptr(out), to_ptr(state), + to_ptr(state_steps), step_stride, + B, S, conv_dim, k, apply_silu, to_stream(stream)); + }, + py::arg("x"), py::arg("w"), py::arg("bias"), + py::arg("out"), py::arg("state"), + py::arg("state_steps"), py::arg("step_stride"), + py::arg("B"), py::arg("S"), py::arg("conv_dim"), py::arg("k"), + py::arg("apply_silu") = true, py::arg("stream") = 0); + m.def("causal_conv1d_qwen36_update_chunk_parallel_bf16", [](uintptr_t x, uintptr_t w, uintptr_t bias, uintptr_t out, uintptr_t state, @@ -4876,6 +4895,31 @@ PYBIND11_MODULE(flash_rt_kernels, m) { py::arg("a_stride"), py::arg("b_stride"), py::arg("use_qk_l2norm") = true, py::arg("stream") = 0); + m.def("qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16", + [](uintptr_t conv_out, uintptr_t a, uintptr_t b, + uintptr_t neg_exp_A_log, uintptr_t dt_bias, + uintptr_t state, uintptr_t state_steps, int64_t step_stride, + uintptr_t out, + int S, int num_v_heads, int a_stride, int b_stride, + bool use_qk_l2norm, uintptr_t stream) { + flash_rt::kernels:: + qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16( + to_ptr(conv_out), to_ptr(a), to_ptr(b), + reinterpret_cast(neg_exp_A_log), + reinterpret_cast(dt_bias), + to_ptr(state), to_ptr(state_steps), step_stride, + to_ptr(out), + S, num_v_heads, a_stride, b_stride, + use_qk_l2norm, to_stream(stream)); + }, + py::arg("conv_out"), py::arg("a"), py::arg("b"), + py::arg("neg_exp_A_log"), py::arg("dt_bias"), + py::arg("state"), py::arg("state_steps"), + py::arg("step_stride"), py::arg("out"), + py::arg("S"), py::arg("num_v_heads"), + py::arg("a_stride"), py::arg("b_stride"), + py::arg("use_qk_l2norm") = true, py::arg("stream") = 0); + m.def("qwen36_gdn_wy_norm_cumsum_bf16", [](uintptr_t q16, uintptr_t k16, uintptr_t g, uintptr_t q16_l2, uintptr_t k16_l2, uintptr_t g_cumsum, diff --git a/csrc/kernels/causal_conv1d_qwen36.cu b/csrc/kernels/causal_conv1d_qwen36.cu index f4e1c698..7aa15f06 100644 --- a/csrc/kernels/causal_conv1d_qwen36.cu +++ b/csrc/kernels/causal_conv1d_qwen36.cu @@ -193,6 +193,84 @@ __global__ void causal_conv1d_update_chunk_kernel( } } +// Per-step-checkpoint variant of the chunk kernel above: identical +// math (the carried window values are bf16-exact in fp32 registers), +// plus a bf16 dump of the post-shift state after every step into +// ``state_steps`` (step s at state_steps + s * step_stride). Slot s +// byte-matches the committed state of an S = s + 1 run, which is what +// the spec-decode partial-accept rollback copies. +__global__ void causal_conv1d_update_chunk_saves_kernel( + const __nv_bfloat16* __restrict__ x, + const __nv_bfloat16* __restrict__ w, + const __nv_bfloat16* __restrict__ bias, + __nv_bfloat16* __restrict__ out, + __nv_bfloat16* __restrict__ state, + __nv_bfloat16* __restrict__ state_steps, + int64_t step_stride, + int B, int S, int conv_dim, int k, + bool apply_silu) +{ + const int c = blockIdx.x * kThreadsX + threadIdx.x; + const int b = blockIdx.y; + if (c >= conv_dim) return; + + const int sk = k - 1; + const int state_base = (b * conv_dim + c) * sk; + + float wv[kMaxK]; + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + wv[i] = (i < k) ? static_cast(w[c * k + i]) : 0.0f; + } + + float sv[kMaxK]; + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + sv[i] = (i < sk) + ? static_cast(state[state_base + i]) + : 0.0f; + } + + for (int s = 0; s < S; ++s) { + const float x_v = static_cast( + x[(size_t)b * S * conv_dim + (size_t)s * conv_dim + c]); + + float acc = (bias != nullptr) ? static_cast(bias[c]) : 0.0f; + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + if (i < sk) acc = fmaf(sv[i], wv[i], acc); + } + acc = fmaf(x_v, wv[sk], acc); + + if (apply_silu) acc = silu(acc); + out[(size_t)b * S * conv_dim + (size_t)s * conv_dim + c] = + __float2bfloat16(acc); + + #pragma unroll + for (int i = 0; i < kMaxK - 1; ++i) { + if (i < sk - 1) sv[i] = sv[i + 1]; + } + if (sk >= 1) { + sv[sk - 1] = x_v; + } + + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + if (i < sk) { + state_steps[(size_t)s * step_stride + state_base + i] = + __float2bfloat16(sv[i]); + } + } + } + + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + if (i < sk) { + state[state_base + i] = __float2bfloat16(sv[i]); + } + } +} + __global__ void causal_conv1d_update_chunk_parallel_kernel( const __nv_bfloat16* __restrict__ x, const __nv_bfloat16* __restrict__ w, @@ -360,6 +438,27 @@ void causal_conv1d_qwen36_update_chunk_bf16( B, S, conv_dim, k, apply_silu); } +void causal_conv1d_qwen36_update_chunk_saves_bf16( + const void* x, const void* w, const void* bias, + void* out, void* state, + void* state_steps, int64_t step_stride, + int B, int S, int conv_dim, int k, + bool apply_silu, + cudaStream_t stream) +{ + dim3 grid((conv_dim + kThreadsX - 1) / kThreadsX, B); + dim3 block(kThreadsX); + causal_conv1d_update_chunk_saves_kernel<<>>( + reinterpret_cast(x), + reinterpret_cast(w), + reinterpret_cast(bias), + reinterpret_cast<__nv_bfloat16*>(out), + reinterpret_cast<__nv_bfloat16*>(state), + reinterpret_cast<__nv_bfloat16*>(state_steps), + step_stride, + B, S, conv_dim, k, apply_silu); +} + void causal_conv1d_qwen36_update_chunk_parallel_bf16( const void* x, const void* w, const void* bias, void* out, void* state, diff --git a/csrc/kernels/causal_conv1d_qwen36.cuh b/csrc/kernels/causal_conv1d_qwen36.cuh index e0e22432..fa6e9f1f 100644 --- a/csrc/kernels/causal_conv1d_qwen36.cuh +++ b/csrc/kernels/causal_conv1d_qwen36.cuh @@ -90,6 +90,21 @@ void causal_conv1d_qwen36_update_chunk_bf16( bool apply_silu, cudaStream_t stream); +// Chunk variant with per-step state checkpoints: dumps the post-step +// conv state to state_steps + s * step_stride for every step s, for +// the spec-decode partial-accept rollback. +void causal_conv1d_qwen36_update_chunk_saves_bf16( + const void* x, + const void* w, + const void* bias, + void* out, + void* state, + void* state_steps, + int64_t step_stride, + int B, int S, int conv_dim, int k, + bool apply_silu, + cudaStream_t stream); + // Parallel prefill variant: computes each (S, channel) output // independently, then updates the final state in a second tiny kernel. // This trades extra global loads for much higher S-dimension diff --git a/csrc/kernels/gated_deltanet_qwen36.cu b/csrc/kernels/gated_deltanet_qwen36.cu index c875e2c9..c80754c7 100644 --- a/csrc/kernels/gated_deltanet_qwen36.cu +++ b/csrc/kernels/gated_deltanet_qwen36.cu @@ -896,6 +896,134 @@ __global__ void qwen36_gdn_chunk_from_conv_smem_kernel( } } +// Per-step-checkpoint variant of the chunk kernel above: identical +// math and rounding cadence (the state is rounded to bf16 after every +// step exactly as the original does between steps), plus a dump of +// each step's rounded state into ``state_steps`` (step s at +// state_steps + s * step_stride). Slot s byte-matches the committed +// state of an S = s + 1 run, which is what the spec-decode +// partial-accept rollback copies. +template +__global__ void qwen36_gdn_chunk_from_conv_smem_saves_kernel( + const __nv_bfloat16* __restrict__ conv_out, + const __nv_bfloat16* __restrict__ a_in, + const __nv_bfloat16* __restrict__ b_in, + const float* __restrict__ neg_exp_A_log, + const float* __restrict__ dt_bias, + __nv_bfloat16* __restrict__ state, + __nv_bfloat16* __restrict__ state_steps, + int64_t step_stride, + __nv_bfloat16* __restrict__ out_, + int S, + int num_v_heads, + int a_stride, + int b_stride, + bool use_qk_l2norm) +{ + static_assert(HD == 128, "HD must be 128 for Qwen3.6"); + const int h = blockIdx.x; + const int b = blockIdx.y; + const int t = threadIdx.x; + if (t >= HD) return; + + extern __shared__ float smem[]; + float* state_s = smem; + float* qs = state_s + HD * HD; + float* ks = qs + HD; + float* scratch = ks + HD; + + const size_t state_h_off = + (((size_t)b * num_v_heads + h)) * HD * HD; + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + state_s[i * HD + t] = static_cast( + state[state_h_off + (size_t)i * HD + t]); + } + __syncthreads(); + + const int src_h = h / 3; + for (int s = 0; s < S; ++s) { + const size_t row = static_cast(s) * 10240; + const size_t out_off = ((size_t)s * num_v_heads + h) * HD + t; + qs[t] = static_cast(conv_out[row + src_h * HD + t]); + ks[t] = static_cast(conv_out[row + 2048 + src_h * HD + t]); + __syncthreads(); + + if (use_qk_l2norm) { + float q_sq = qs[t] * qs[t]; + float k_sq = ks[t] * ks[t]; + q_sq = block_reduce_sum(q_sq, scratch); + // See the non-saves kernel for why this barrier is required + // between the two block reductions sharing ``scratch``. + __syncthreads(); + k_sq = block_reduce_sum(k_sq, scratch); + const float q_inv = rsqrtf(q_sq + kEps); + const float k_inv = rsqrtf(k_sq + kEps); + qs[t] *= q_inv; + ks[t] *= k_inv; + __syncthreads(); + } + + qs[t] *= rsqrtf(static_cast(HD)); + __syncthreads(); + + const float av = + static_cast(a_in[s * a_stride + h]) + dt_bias[h]; + const float sp = log1pf(__expf(av)); + const float g_log = static_cast( + __float2bfloat16(neg_exp_A_log[h] * sp)); + const float g_t = __expf(g_log); + const float bv = static_cast(b_in[s * b_stride + h]); + const float beta_t = static_cast( + __float2bfloat16(1.0f / (1.0f + __expf(-bv)))); + + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + state_s[i * HD + t] *= g_t; + } + + float kv_mem = 0.0f; + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + kv_mem = fmaf(state_s[i * HD + t], ks[i], kv_mem); + } + + const float v_t = + static_cast(conv_out[row + 4096 + h * HD + t]); + const float delta = (v_t - kv_mem) * beta_t; + + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + state_s[i * HD + t] = + fmaf(ks[i], delta, state_s[i * HD + t]); + } + + float out_t = 0.0f; + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + out_t = fmaf(state_s[i * HD + t], qs[i], out_t); + } + out_[out_off] = __float2bfloat16(out_t); + + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + const __nv_bfloat16 v = + __float2bfloat16(state_s[i * HD + t]); + state_steps[ + (size_t)s * step_stride + state_h_off + (size_t)i * HD + t] = + v; + state_s[i * HD + t] = static_cast(v); + } + __syncthreads(); + } + + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + state[state_h_off + (size_t)i * HD + t] = + __float2bfloat16(state_s[i * HD + t]); + } +} + __global__ void qwen36_gdn_wy_norm_qk_kernel( const __nv_bfloat16* __restrict__ q16, const __nv_bfloat16* __restrict__ k16, @@ -1469,6 +1597,50 @@ void qwen36_gdn_chunk_from_conv_smem_strided_bf16( S, num_v_heads, a_stride, b_stride, use_qk_l2norm); } +void qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16( + const void* conv_out, + const void* a, + const void* b, + const float* neg_exp_A_log, + const float* dt_bias, + void* state, + void* state_steps, + int64_t step_stride, + void* out, + int S, + int num_v_heads, + int a_stride, + int b_stride, + bool use_qk_l2norm, + cudaStream_t stream) +{ + if (S <= 0 || num_v_heads <= 0) return; + dim3 grid(num_v_heads, 1); + dim3 block(kHD); + constexpr size_t kSmemBytes = + (kHD * kHD + 2 * kHD + 32) * sizeof(float); + static bool attr_set = false; + if (!attr_set) { + cudaFuncSetAttribute( + qwen36_gdn_chunk_from_conv_smem_saves_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(kSmemBytes)); + attr_set = true; + } + qwen36_gdn_chunk_from_conv_smem_saves_kernel<<< + grid, block, kSmemBytes, stream>>>( + reinterpret_cast(conv_out), + reinterpret_cast(a), + reinterpret_cast(b), + neg_exp_A_log, + dt_bias, + reinterpret_cast<__nv_bfloat16*>(state), + reinterpret_cast<__nv_bfloat16*>(state_steps), + step_stride, + reinterpret_cast<__nv_bfloat16*>(out), + S, num_v_heads, a_stride, b_stride, use_qk_l2norm); +} + void gated_deltanet_chunk_smem_qwen36_bf16( const void* q, const void* k, diff --git a/csrc/kernels/gated_deltanet_qwen36.cuh b/csrc/kernels/gated_deltanet_qwen36.cuh index d53de34b..f497eff9 100644 --- a/csrc/kernels/gated_deltanet_qwen36.cuh +++ b/csrc/kernels/gated_deltanet_qwen36.cuh @@ -216,6 +216,26 @@ void qwen36_gdn_chunk_from_conv_smem_strided_bf16( bool use_qk_l2norm, cudaStream_t stream); +// Chunk variant with per-step state checkpoints: dumps the post-step +// (bf16-rounded) recurrent state to state_steps + s * step_stride for +// every step s, for the spec-decode partial-accept rollback. +void qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16( + const void* conv_out, + const void* a, + const void* b, + const float* neg_exp_A_log, + const float* dt_bias, + void* state, + void* state_steps, + int64_t step_stride, + void* out, + int S, + int num_v_heads, + int a_stride, + int b_stride, + bool use_qk_l2norm, + cudaStream_t stream); + // Chunk/WY Gated DeltaNet building blocks. These are the native // FlashRT replacement path for the Python/Triton FLA chunk implementation. // First specialization targets Qwen3.6 shapes: diff --git a/flash_rt/frontends/torch/qwen36_thor.py b/flash_rt/frontends/torch/qwen36_thor.py index 586f24b3..c07800aa 100644 --- a/flash_rt/frontends/torch/qwen36_thor.py +++ b/flash_rt/frontends/torch/qwen36_thor.py @@ -206,28 +206,56 @@ def _thor_alloc_K_row_scratch(self) -> None: # per-position sub-loop. Bit-exact to running K sequential single- # token forwards (see DESIGN §4.5 for the leaf-kernel set). def _layer_forward_lin_K_nvfp4(self, L, h_in_K, K): - # Delegate the whole save-steps range to parent, not just the - # fast-path max: parent's per-step recurrent branch is active - # for K <= _K_save_max, stays per-token-equivalent on SM110, - # and writes the per-step state checkpoints the DFlash - # partial-accept rollback reads (_K_save_max is grown to the - # DFlash verify q_seq at drafter load). - if K <= max(self._THOR_K_ROW_FAST_PATH_MAX, self._K_save_max): + # K <= 7 stays on parent's per-step branch — the production + # MTP spec verify path, untouched. The 8..16 band (DFlash + # verify) defaults to parent as well: greedy parity against + # the MTP reference is anchored to parent-family rounding, and + # a Thor-family verify measurably drifts from it. The opt-in + # chunk-saves route (FLASHRT_QWEN36_THOR_LIN_CHUNK_SAVES=1) + # trades that token-exact parity for ~5% lower verify cost + # (chunk kernels + per-step checkpoints in one pass) — for + # deployments gating on task-level quality instead. + if K <= self._THOR_K_ROW_FAST_PATH_MAX: + return super()._layer_forward_lin_K_nvfp4(L, h_in_K, K) + if K <= self._K_save_max: + if self._thor_lin_chunk_saves_enabled(): + return self._thor_lin_K_forward(L, h_in_K, K) return super()._layer_forward_lin_K_nvfp4(L, h_in_K, K) if K > self.MAX_Q_SEQ: return self._thor_lin_K_dispatch(L, h_in_K, K) return self._thor_lin_K_forward(L, h_in_K, K) + def _thor_lin_chunk_saves_enabled(self) -> bool: + cached = getattr(self, '_thor_lin_saves_flag', None) + if cached is None: + from flash_rt import flash_rt_kernels as fvk + + cached = ( + hasattr(fvk, 'causal_conv1d_qwen36_update_chunk_saves_bf16') + and hasattr( + fvk, + 'qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16') + and os.environ.get( + 'FLASHRT_QWEN36_THOR_LIN_CHUNK_SAVES', '0', + ).strip().lower() in ('1', 'true', 'on')) + self._thor_lin_saves_flag = cached + return cached + def _layer_forward_full_K_nvfp4( self, L, h_in_K, cos_K, sin_K, cur_pos, K): - # Delegate the whole save-steps range to parent, mirroring the - # lin dispatch above. The DFlash spec loop commits per-row - # state/KV from its S=16 verify (slot-copy rollback); rows must - # therefore come from the SAME kernel family as the K<=7 spec - # verifies, or the two paths' occasional rounding disagreements - # surface as greedy divergence (measured: bit-identical for ~39 - # cycles, then a single full-attn row event cascades). - if K <= max(self._THOR_K_ROW_FAST_PATH_MAX, self._K_save_max): + # The verify must stay on ONE kernel family end to end: rows + # committed by one family while other rows (or the rollback + # checkpoints) come from another surface the families' + # occasional rounding disagreements as greedy divergence. + # K <= 7 (the production MTP verify) stays on parent. The + # 8..16 band follows the lin dispatch: Thor from-scratch when + # the chunk-saves kernels serve the lin layers, parent + # otherwise — mixing families across layer types measurably + # breaks greedy parity. + if K <= self._THOR_K_ROW_FAST_PATH_MAX: + return super()._layer_forward_full_K_nvfp4( + L, h_in_K, cos_K, sin_K, cur_pos, K) + if K <= self._K_save_max and not self._thor_lin_chunk_saves_enabled(): return super()._layer_forward_full_K_nvfp4( L, h_in_K, cos_K, sin_K, cur_pos, K) if K > self.MAX_Q_SEQ: @@ -315,12 +343,28 @@ def _thor_lin_K_forward(self, L, h_in_K, K): lin_rank = self._linear_layer_rank(L) conv_state = self._lin_conv_state[lin_rank] conv_out_K = self._K_lin_conv_out[:K] - fvk.causal_conv1d_qwen36_update_chunk_bf16( - out_qkv_K.data_ptr(), int(lw['conv1d_w']), - int(lw['conv1d_b']), - conv_out_K.data_ptr(), conv_state.data_ptr(), - 1, K, 10240, 4, True, s, - ) + # Inside the save-steps range, dump per-step state checkpoints + # for the spec-decode partial-accept rollback (same slots the + # parent per-step branch writes). + save_steps = ( + K <= self._K_save_max and self._thor_lin_chunk_saves_enabled()) + if save_steps: + conv_steps = self._K_lin_conv_state_per_step + fvk.causal_conv1d_qwen36_update_chunk_saves_bf16( + out_qkv_K.data_ptr(), int(lw['conv1d_w']), + int(lw['conv1d_b']), + conv_out_K.data_ptr(), conv_state.data_ptr(), + conv_steps[0, lin_rank].data_ptr(), + conv_steps.stride(0), + 1, K, 10240, 4, True, s, + ) + else: + fvk.causal_conv1d_qwen36_update_chunk_bf16( + out_qkv_K.data_ptr(), int(lw['conv1d_w']), + int(lw['conv1d_b']), + conv_out_K.data_ptr(), conv_state.data_ptr(), + 1, K, 10240, 4, True, s, + ) # (7-9) Fused conv_out -> split + Q/K broadcast + GDN gating # + GDN chunk recurrent in one launch. Replaces three separate @@ -329,15 +373,29 @@ def _thor_lin_K_forward(self, L, h_in_K, K): attn_out_K = self._K_lin_attn_out[:K] a_stride = a_vec_K.stride(0) b_stride = b_vec_K.stride(0) - fvk.qwen36_gdn_chunk_from_conv_smem_strided_bf16( - conv_out_K.data_ptr(), - a_vec_K.data_ptr(), b_vec_K.data_ptr(), - lw['neg_A_log_exp_fp32_t'].data_ptr(), - lw['dt_bias_fp32_t'].data_ptr(), - rec_state.data_ptr(), - attn_out_K.data_ptr(), - K, 48, a_stride, b_stride, True, s, - ) + if save_steps: + lin_steps = self._K_lin_state_per_step + fvk.qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16( + conv_out_K.data_ptr(), + a_vec_K.data_ptr(), b_vec_K.data_ptr(), + lw['neg_A_log_exp_fp32_t'].data_ptr(), + lw['dt_bias_fp32_t'].data_ptr(), + rec_state.data_ptr(), + lin_steps[0, lin_rank].data_ptr(), + lin_steps.stride(0), + attn_out_K.data_ptr(), + K, 48, a_stride, b_stride, True, s, + ) + else: + fvk.qwen36_gdn_chunk_from_conv_smem_strided_bf16( + conv_out_K.data_ptr(), + a_vec_K.data_ptr(), b_vec_K.data_ptr(), + lw['neg_A_log_exp_fp32_t'].data_ptr(), + lw['dt_bias_fp32_t'].data_ptr(), + rec_state.data_ptr(), + attn_out_K.data_ptr(), + K, 48, a_stride, b_stride, True, s, + ) # (10) rms_norm_gated_silu @ M=K*48, dim=128. attn_out_flat = attn_out_K.view(K * 48, 128) From 51bc3a614e1db93928ecadb463a9ae62577c738c Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Sat, 4 Jul 2026 06:08:14 -0400 Subject: [PATCH 10/11] docs(qwen36): relaxed thinking acceptance and chunk-saves knobs Document the two opt-in DFlash performance modes: relaxed thinking-phase acceptance (env knobs, +43% measured on a thinking-enabled robot-plan prompt, transcript-exactness tradeoff stated) and the Thor chunk-saves verify kernels (~5% cycle, kernel family vs the parity reference stated). Both default off; the default configuration remains token-identical to the FP8-KV MTP reference. --- docs/qwen36_dflash.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/qwen36_dflash.md b/docs/qwen36_dflash.md index 6c2f883c..416d900b 100644 --- a/docs/qwen36_dflash.md +++ b/docs/qwen36_dflash.md @@ -106,6 +106,39 @@ Output quality is lossless: the verify pass is the greedy ground truth, and generated tokens are byte-identical to the FP8-KV MTP reference on all measured prompts. +## Relaxed thinking-phase acceptance (opt-in) + +Qwen3.6 reasons inside a `` block before answering, and the +thinking stream dominates the token budget. Mirroring the +TensorRT-LLM MTP policy, relaxed acceptance treats a draft as accepted +inside the think block when it is in the verify logits' top-k and +within a logit margin of the argmax; the accepted token is the draft +itself. Rows from the first draft that closes the think block fall +back to strict matching, so everything after `` — the visible +answer — remains exact-verified greedy. + +| Env | Default | Meaning | +|---|---|---| +| `FLASHRT_QWEN36_DFLASH_RELAXED_THINKING` | `0` | Enable relaxed acceptance inside ``. | +| `FLASHRT_QWEN36_DFLASH_RELAXED_TOPK` | `3` | Candidate set size. | +| `FLASHRT_QWEN36_DFLASH_RELAXED_DELTA` | `1.0` | Logit margin vs the argmax (equals a log-prob margin). | + +Measured on Thor (thinking-enabled robot JSON-plan prompt, steady +state): AL 3.78 -> 5.42, **40.4 -> 57.7 tok/s (+43%)**. Prompts whose +drafts rarely reach the top-k see no change (a math prompt measured +neutral). The thinking stream is no longer token-identical to the +strict run — enable this only where the product gates on the final +answer, not the reasoning transcript. + +## Opt-in chunk-saves verify kernels (Thor) + +`FLASHRT_QWEN36_THOR_LIN_CHUNK_SAVES=1` routes the DFlash verify's +linear-attention layers to chunk kernels that emit the per-step +rollback checkpoints in one pass (~5% lower cycle time). This moves +the verify off the kernel family that the MTP reference path uses, so +greedy output is no longer token-identical to that reference — same +tradeoff class as relaxed acceptance. Default off. + ## Serving A stateless OpenAI-compatible host for this path lives in From d3eba3d47763d7ebf2cdba137c5f3060be8cc9d1 Mon Sep 17 00:00:00 2001 From: LiangSu8899 Date: Sat, 4 Jul 2026 13:01:33 -0400 Subject: [PATCH 11/11] docs(qwen36): drop the README news entry for DFlash Keep the DFlash announcement out of the top-level README for now; the feature remains fully documented in docs/qwen36_dflash.md and the serving README. The news entry can land separately once the release messaging is settled. --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index fa60db7a..a503c5f4 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,6 @@ See [Supported Models](#supported-models), [Hardware Support](#hardware-support) ## News -- [2026/07] **DFlash block-diffusion speculative decoding** for Qwen3.6-27B NVFP4 on Jetson AGX Thor: one 15-token drafter block per cycle, constant-time partial-accept rollback, and a per-token drafter context window with prompt-tail seeding — **48.9 tok/s on structured robot-plan prompts vs 33.7 tok/s for the MTP chain (+45%)**, lossless greedy output. See [Qwen3.6 DFlash](docs/qwen36_dflash.md). - [2026/06] **Higgs Audio v3 TTS-4B** lands on FlashRT with a kernelized FP8/BF16 decode path, streaming-friendly generation API, and a FastAPI serving host. See [Higgs usage](docs/higgs_audio_v3.md#3-quickstart), [Higgs performance](docs/higgs_audio_v3.md#performance), and [Higgs serving](serving/higgs_audio_agent/README.md). - [2026/06] **FlashRT HF Kernels** are available as Hugging Face Kernel Hub packages under the `flashrt` namespace. See [LiangSu8899/FlashRT-HF-kernels](https://github.com/LiangSu8899/FlashRT-HF-kernels) and [huggingface.co/flashrt](https://huggingface.co/flashrt). - [2026/06] The `serving/` layer is documented as the scenario-host layer for OpenAI-compatible LLM/audio serving and robot execution-state hosts. See [serving README](serving/README.md), [serving design](docs/serving_design.md), and [architecture](docs/architecture.md).