diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index 8616797..9f09c59 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -4157,6 +4157,25 @@ PYBIND11_MODULE(flash_rt_kernels, m) { py::arg("B"), py::arg("S"), py::arg("conv_dim"), py::arg("k"), py::arg("apply_silu") = true, py::arg("stream") = 0); + m.def("causal_conv1d_qwen36_update_chunk_saves_bf16", + [](uintptr_t x, uintptr_t w, uintptr_t bias, + uintptr_t out, uintptr_t state, + uintptr_t state_steps, int64_t step_stride, + int B, int S, int conv_dim, int k, bool apply_silu, + uintptr_t stream) { + flash_rt::kernels::causal_conv1d_qwen36_update_chunk_saves_bf16( + to_ptr(x), to_ptr(w), + bias ? to_ptr(bias) : nullptr, + to_ptr(out), to_ptr(state), + to_ptr(state_steps), step_stride, + B, S, conv_dim, k, apply_silu, to_stream(stream)); + }, + py::arg("x"), py::arg("w"), py::arg("bias"), + py::arg("out"), py::arg("state"), + py::arg("state_steps"), py::arg("step_stride"), + py::arg("B"), py::arg("S"), py::arg("conv_dim"), py::arg("k"), + py::arg("apply_silu") = true, py::arg("stream") = 0); + m.def("causal_conv1d_qwen36_update_chunk_parallel_bf16", [](uintptr_t x, uintptr_t w, uintptr_t bias, uintptr_t out, uintptr_t state, @@ -4876,6 +4895,31 @@ PYBIND11_MODULE(flash_rt_kernels, m) { py::arg("a_stride"), py::arg("b_stride"), py::arg("use_qk_l2norm") = true, py::arg("stream") = 0); + m.def("qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16", + [](uintptr_t conv_out, uintptr_t a, uintptr_t b, + uintptr_t neg_exp_A_log, uintptr_t dt_bias, + uintptr_t state, uintptr_t state_steps, int64_t step_stride, + uintptr_t out, + int S, int num_v_heads, int a_stride, int b_stride, + bool use_qk_l2norm, uintptr_t stream) { + flash_rt::kernels:: + qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16( + to_ptr(conv_out), to_ptr(a), to_ptr(b), + reinterpret_cast(neg_exp_A_log), + reinterpret_cast(dt_bias), + to_ptr(state), to_ptr(state_steps), step_stride, + to_ptr(out), + S, num_v_heads, a_stride, b_stride, + use_qk_l2norm, to_stream(stream)); + }, + py::arg("conv_out"), py::arg("a"), py::arg("b"), + py::arg("neg_exp_A_log"), py::arg("dt_bias"), + py::arg("state"), py::arg("state_steps"), + py::arg("step_stride"), py::arg("out"), + py::arg("S"), py::arg("num_v_heads"), + py::arg("a_stride"), py::arg("b_stride"), + py::arg("use_qk_l2norm") = true, py::arg("stream") = 0); + m.def("qwen36_gdn_wy_norm_cumsum_bf16", [](uintptr_t q16, uintptr_t k16, uintptr_t g, uintptr_t q16_l2, uintptr_t k16_l2, uintptr_t g_cumsum, diff --git a/csrc/kernels/causal_conv1d_qwen36.cu b/csrc/kernels/causal_conv1d_qwen36.cu index f4e1c69..7aa15f0 100644 --- a/csrc/kernels/causal_conv1d_qwen36.cu +++ b/csrc/kernels/causal_conv1d_qwen36.cu @@ -193,6 +193,84 @@ __global__ void causal_conv1d_update_chunk_kernel( } } +// Per-step-checkpoint variant of the chunk kernel above: identical +// math (the carried window values are bf16-exact in fp32 registers), +// plus a bf16 dump of the post-shift state after every step into +// ``state_steps`` (step s at state_steps + s * step_stride). Slot s +// byte-matches the committed state of an S = s + 1 run, which is what +// the spec-decode partial-accept rollback copies. +__global__ void causal_conv1d_update_chunk_saves_kernel( + const __nv_bfloat16* __restrict__ x, + const __nv_bfloat16* __restrict__ w, + const __nv_bfloat16* __restrict__ bias, + __nv_bfloat16* __restrict__ out, + __nv_bfloat16* __restrict__ state, + __nv_bfloat16* __restrict__ state_steps, + int64_t step_stride, + int B, int S, int conv_dim, int k, + bool apply_silu) +{ + const int c = blockIdx.x * kThreadsX + threadIdx.x; + const int b = blockIdx.y; + if (c >= conv_dim) return; + + const int sk = k - 1; + const int state_base = (b * conv_dim + c) * sk; + + float wv[kMaxK]; + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + wv[i] = (i < k) ? static_cast(w[c * k + i]) : 0.0f; + } + + float sv[kMaxK]; + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + sv[i] = (i < sk) + ? static_cast(state[state_base + i]) + : 0.0f; + } + + for (int s = 0; s < S; ++s) { + const float x_v = static_cast( + x[(size_t)b * S * conv_dim + (size_t)s * conv_dim + c]); + + float acc = (bias != nullptr) ? static_cast(bias[c]) : 0.0f; + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + if (i < sk) acc = fmaf(sv[i], wv[i], acc); + } + acc = fmaf(x_v, wv[sk], acc); + + if (apply_silu) acc = silu(acc); + out[(size_t)b * S * conv_dim + (size_t)s * conv_dim + c] = + __float2bfloat16(acc); + + #pragma unroll + for (int i = 0; i < kMaxK - 1; ++i) { + if (i < sk - 1) sv[i] = sv[i + 1]; + } + if (sk >= 1) { + sv[sk - 1] = x_v; + } + + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + if (i < sk) { + state_steps[(size_t)s * step_stride + state_base + i] = + __float2bfloat16(sv[i]); + } + } + } + + #pragma unroll + for (int i = 0; i < kMaxK; ++i) { + if (i < sk) { + state[state_base + i] = __float2bfloat16(sv[i]); + } + } +} + __global__ void causal_conv1d_update_chunk_parallel_kernel( const __nv_bfloat16* __restrict__ x, const __nv_bfloat16* __restrict__ w, @@ -360,6 +438,27 @@ void causal_conv1d_qwen36_update_chunk_bf16( B, S, conv_dim, k, apply_silu); } +void causal_conv1d_qwen36_update_chunk_saves_bf16( + const void* x, const void* w, const void* bias, + void* out, void* state, + void* state_steps, int64_t step_stride, + int B, int S, int conv_dim, int k, + bool apply_silu, + cudaStream_t stream) +{ + dim3 grid((conv_dim + kThreadsX - 1) / kThreadsX, B); + dim3 block(kThreadsX); + causal_conv1d_update_chunk_saves_kernel<<>>( + reinterpret_cast(x), + reinterpret_cast(w), + reinterpret_cast(bias), + reinterpret_cast<__nv_bfloat16*>(out), + reinterpret_cast<__nv_bfloat16*>(state), + reinterpret_cast<__nv_bfloat16*>(state_steps), + step_stride, + B, S, conv_dim, k, apply_silu); +} + void causal_conv1d_qwen36_update_chunk_parallel_bf16( const void* x, const void* w, const void* bias, void* out, void* state, diff --git a/csrc/kernels/causal_conv1d_qwen36.cuh b/csrc/kernels/causal_conv1d_qwen36.cuh index e0e2243..fa6e9f1 100644 --- a/csrc/kernels/causal_conv1d_qwen36.cuh +++ b/csrc/kernels/causal_conv1d_qwen36.cuh @@ -90,6 +90,21 @@ void causal_conv1d_qwen36_update_chunk_bf16( bool apply_silu, cudaStream_t stream); +// Chunk variant with per-step state checkpoints: dumps the post-step +// conv state to state_steps + s * step_stride for every step s, for +// the spec-decode partial-accept rollback. +void causal_conv1d_qwen36_update_chunk_saves_bf16( + const void* x, + const void* w, + const void* bias, + void* out, + void* state, + void* state_steps, + int64_t step_stride, + int B, int S, int conv_dim, int k, + bool apply_silu, + cudaStream_t stream); + // Parallel prefill variant: computes each (S, channel) output // independently, then updates the final state in a second tiny kernel. // This trades extra global loads for much higher S-dimension diff --git a/csrc/kernels/gated_deltanet_qwen36.cu b/csrc/kernels/gated_deltanet_qwen36.cu index c875e2c..c80754c 100644 --- a/csrc/kernels/gated_deltanet_qwen36.cu +++ b/csrc/kernels/gated_deltanet_qwen36.cu @@ -896,6 +896,134 @@ __global__ void qwen36_gdn_chunk_from_conv_smem_kernel( } } +// Per-step-checkpoint variant of the chunk kernel above: identical +// math and rounding cadence (the state is rounded to bf16 after every +// step exactly as the original does between steps), plus a dump of +// each step's rounded state into ``state_steps`` (step s at +// state_steps + s * step_stride). Slot s byte-matches the committed +// state of an S = s + 1 run, which is what the spec-decode +// partial-accept rollback copies. +template +__global__ void qwen36_gdn_chunk_from_conv_smem_saves_kernel( + const __nv_bfloat16* __restrict__ conv_out, + const __nv_bfloat16* __restrict__ a_in, + const __nv_bfloat16* __restrict__ b_in, + const float* __restrict__ neg_exp_A_log, + const float* __restrict__ dt_bias, + __nv_bfloat16* __restrict__ state, + __nv_bfloat16* __restrict__ state_steps, + int64_t step_stride, + __nv_bfloat16* __restrict__ out_, + int S, + int num_v_heads, + int a_stride, + int b_stride, + bool use_qk_l2norm) +{ + static_assert(HD == 128, "HD must be 128 for Qwen3.6"); + const int h = blockIdx.x; + const int b = blockIdx.y; + const int t = threadIdx.x; + if (t >= HD) return; + + extern __shared__ float smem[]; + float* state_s = smem; + float* qs = state_s + HD * HD; + float* ks = qs + HD; + float* scratch = ks + HD; + + const size_t state_h_off = + (((size_t)b * num_v_heads + h)) * HD * HD; + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + state_s[i * HD + t] = static_cast( + state[state_h_off + (size_t)i * HD + t]); + } + __syncthreads(); + + const int src_h = h / 3; + for (int s = 0; s < S; ++s) { + const size_t row = static_cast(s) * 10240; + const size_t out_off = ((size_t)s * num_v_heads + h) * HD + t; + qs[t] = static_cast(conv_out[row + src_h * HD + t]); + ks[t] = static_cast(conv_out[row + 2048 + src_h * HD + t]); + __syncthreads(); + + if (use_qk_l2norm) { + float q_sq = qs[t] * qs[t]; + float k_sq = ks[t] * ks[t]; + q_sq = block_reduce_sum(q_sq, scratch); + // See the non-saves kernel for why this barrier is required + // between the two block reductions sharing ``scratch``. + __syncthreads(); + k_sq = block_reduce_sum(k_sq, scratch); + const float q_inv = rsqrtf(q_sq + kEps); + const float k_inv = rsqrtf(k_sq + kEps); + qs[t] *= q_inv; + ks[t] *= k_inv; + __syncthreads(); + } + + qs[t] *= rsqrtf(static_cast(HD)); + __syncthreads(); + + const float av = + static_cast(a_in[s * a_stride + h]) + dt_bias[h]; + const float sp = log1pf(__expf(av)); + const float g_log = static_cast( + __float2bfloat16(neg_exp_A_log[h] * sp)); + const float g_t = __expf(g_log); + const float bv = static_cast(b_in[s * b_stride + h]); + const float beta_t = static_cast( + __float2bfloat16(1.0f / (1.0f + __expf(-bv)))); + + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + state_s[i * HD + t] *= g_t; + } + + float kv_mem = 0.0f; + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + kv_mem = fmaf(state_s[i * HD + t], ks[i], kv_mem); + } + + const float v_t = + static_cast(conv_out[row + 4096 + h * HD + t]); + const float delta = (v_t - kv_mem) * beta_t; + + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + state_s[i * HD + t] = + fmaf(ks[i], delta, state_s[i * HD + t]); + } + + float out_t = 0.0f; + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + out_t = fmaf(state_s[i * HD + t], qs[i], out_t); + } + out_[out_off] = __float2bfloat16(out_t); + + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + const __nv_bfloat16 v = + __float2bfloat16(state_s[i * HD + t]); + state_steps[ + (size_t)s * step_stride + state_h_off + (size_t)i * HD + t] = + v; + state_s[i * HD + t] = static_cast(v); + } + __syncthreads(); + } + + #pragma unroll 16 + for (int i = 0; i < HD; ++i) { + state[state_h_off + (size_t)i * HD + t] = + __float2bfloat16(state_s[i * HD + t]); + } +} + __global__ void qwen36_gdn_wy_norm_qk_kernel( const __nv_bfloat16* __restrict__ q16, const __nv_bfloat16* __restrict__ k16, @@ -1469,6 +1597,50 @@ void qwen36_gdn_chunk_from_conv_smem_strided_bf16( S, num_v_heads, a_stride, b_stride, use_qk_l2norm); } +void qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16( + const void* conv_out, + const void* a, + const void* b, + const float* neg_exp_A_log, + const float* dt_bias, + void* state, + void* state_steps, + int64_t step_stride, + void* out, + int S, + int num_v_heads, + int a_stride, + int b_stride, + bool use_qk_l2norm, + cudaStream_t stream) +{ + if (S <= 0 || num_v_heads <= 0) return; + dim3 grid(num_v_heads, 1); + dim3 block(kHD); + constexpr size_t kSmemBytes = + (kHD * kHD + 2 * kHD + 32) * sizeof(float); + static bool attr_set = false; + if (!attr_set) { + cudaFuncSetAttribute( + qwen36_gdn_chunk_from_conv_smem_saves_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(kSmemBytes)); + attr_set = true; + } + qwen36_gdn_chunk_from_conv_smem_saves_kernel<<< + grid, block, kSmemBytes, stream>>>( + reinterpret_cast(conv_out), + reinterpret_cast(a), + reinterpret_cast(b), + neg_exp_A_log, + dt_bias, + reinterpret_cast<__nv_bfloat16*>(state), + reinterpret_cast<__nv_bfloat16*>(state_steps), + step_stride, + reinterpret_cast<__nv_bfloat16*>(out), + S, num_v_heads, a_stride, b_stride, use_qk_l2norm); +} + void gated_deltanet_chunk_smem_qwen36_bf16( const void* q, const void* k, diff --git a/csrc/kernels/gated_deltanet_qwen36.cuh b/csrc/kernels/gated_deltanet_qwen36.cuh index d53de34..f497eff 100644 --- a/csrc/kernels/gated_deltanet_qwen36.cuh +++ b/csrc/kernels/gated_deltanet_qwen36.cuh @@ -216,6 +216,26 @@ void qwen36_gdn_chunk_from_conv_smem_strided_bf16( bool use_qk_l2norm, cudaStream_t stream); +// Chunk variant with per-step state checkpoints: dumps the post-step +// (bf16-rounded) recurrent state to state_steps + s * step_stride for +// every step s, for the spec-decode partial-accept rollback. +void qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16( + const void* conv_out, + const void* a, + const void* b, + const float* neg_exp_A_log, + const float* dt_bias, + void* state, + void* state_steps, + int64_t step_stride, + void* out, + int S, + int num_v_heads, + int a_stride, + int b_stride, + bool use_qk_l2norm, + cudaStream_t stream); + // Chunk/WY Gated DeltaNet building blocks. These are the native // FlashRT replacement path for the Python/Triton FLA chunk implementation. // First specialization targets Qwen3.6 shapes: diff --git a/docs/qwen36_dflash.md b/docs/qwen36_dflash.md new file mode 100644 index 0000000..416d900 --- /dev/null +++ b/docs/qwen36_dflash.md @@ -0,0 +1,173 @@ +# Qwen3.6-27B DFlash Speculative Decoding + +This document covers the DFlash block-diffusion drafter path for +Qwen3.6-27B NVFP4. DFlash replaces the sequential MTP draft chain with +a single drafter forward per speculation cycle: a 5-layer 2B drafter +proposes an entire 15-token block, and the target model verifies the +block in one S=16 forward. + +For the general Qwen3.6 NVFP4 model contract and parameter reference, +see [`qwen36_nvfp4.md`](qwen36_nvfp4.md) and +[`qwen36_usage.md`](qwen36_usage.md). + +## Requirements + +- Qwen3.6-27B NVFP4 main checkpoint (same as the MTP path). +- The z-lab DFlash drafter checkpoint: + +```bash +hf download z-lab/Qwen3.6-27B-DFlash --local-dir /models/Qwen3.6-27B-DFlash +``` + + The drafter ships as a single BF16 `model.safetensors` (~3.3 GB, + 5 layers, `block_size=16`, target hidden taps at layers + 1/16/31/46/61). FlashRT quantizes every drafter linear to NVFP4 at + load time (~825 MB resident); no separate conversion step. +- On Thor the DFlash verify runs over the persistent FP8 KV cache. + The frontend allocates it automatically at drafter load if the + construction did not already enable long-context mode. + +## Usage + +```python +import os + +from flash_rt.frontends.torch.qwen36_thor import Qwen36TorchFrontendThor + +os.environ["FLASHRT_QWEN36_MTP_CKPT_DIR"] = "/models/Qwen3.6-27B-FP8" +os.environ["FLASHRT_QWEN36_DFLASH_CKPT_DIR"] = "/models/Qwen3.6-27B-DFlash" +os.environ["FLASHRT_QWEN36_LONG_KV_CACHE"] = "fp8" + +fe = Qwen36TorchFrontendThor( + "/models/Qwen3.6-27B-NVFP4", + quant="nvfp4", + max_seq=32768, +) +fe.init_dflash_drafter() # reads FLASHRT_QWEN36_DFLASH_CKPT_DIR + +ids = fe._tokenizer.apply_chat_template( + [{"role": "user", "content": "Plan the pick-and-place task."}], + add_generation_prompt=True, return_tensors="pt").to(fe.device) + +out = fe.generate_own_speculative_DFlash_nvfp4( + ids, + max_new_tokens=256, + K=15, # speculative tokens per cycle +) +``` + +The RTX frontend exposes the same entry point; the drafter and verify +kernels are shared, only the KV plumbing differs per arch. + +## Drafter context window + +The drafter conditions on fc-projected target hidden features of the +committed context. Two window modes exist: + +- **Per-token window** (Thor default): one feature entry per committed + token, appended in bulk after each verify (N+1 entries per cycle). + On Thor the prompt prefill seeds the window with the features of the + last `min(window, prompt_len)` prompt tokens, so the drafter starts + at full context instead of ramping from empty. +- **Per-cycle shift window** (legacy, RTX default): one entry per + speculation cycle. Kept for compatibility; acceptance length is + measurably lower because window entries end up ~AL tokens apart. + +| Env | Default | Meaning | +|---|---|---| +| `FLASHRT_QWEN36_DFLASH_CKPT_DIR` | unset | Drafter checkpoint directory (required). | +| `FLASHRT_QWEN36_DFLASH_PERTOKEN` | `1` on Thor | Per-token window mode. | +| `FLASHRT_QWEN36_DFLASH_WINDOW` | `128` | Per-token window length (tokens, <= 256). | +| `FLASHRT_QWEN36_DFLASH_WINDOW_SEED` | `1` | Seed the window from the prompt tail at prefill (Thor). | + +## Measured performance (Thor, SM110) + +Steady-state decode at short context against the FP8-KV MTP spec path +(`generate_own_speculative_KN_nvfp4`, K=6) in the same process, greedy +decoding, 64/256-token delta method: + +| prompt | MTP AL / tok/s | DFlash AL / tok/s | +|---|---:|---:| +| robot task -> JSON plan | 2.87 / 33.7 | **4.57 / 48.9** | +| robot navigation plan | 2.59 / 30.5 | 3.25 / 34.8 | +| prose explanation | 2.43 / 28.5 | 3.00 / 31.7 | + +Cycle anatomy on Thor: one S=16 verify (~86 ms, weight-read bound) + +one drafter graph replay (~7 ms). A partial accept costs two +constant-time state copies from the per-step checkpoints written +during the verify itself — there is no recovery forward. The accept +decision includes one host synchronization per cycle +(`argmin().item()` on the match mask); at ~10 us it is three orders +of magnitude below the verify cost and is included in every number +above. A device-side accept loop is possible follow-up work, not a +prerequisite. + +Output quality is lossless: the verify pass is the greedy ground +truth, and generated tokens are byte-identical to the FP8-KV MTP +reference on all measured prompts. + +## Relaxed thinking-phase acceptance (opt-in) + +Qwen3.6 reasons inside a `` block before answering, and the +thinking stream dominates the token budget. Mirroring the +TensorRT-LLM MTP policy, relaxed acceptance treats a draft as accepted +inside the think block when it is in the verify logits' top-k and +within a logit margin of the argmax; the accepted token is the draft +itself. Rows from the first draft that closes the think block fall +back to strict matching, so everything after `` — the visible +answer — remains exact-verified greedy. + +| Env | Default | Meaning | +|---|---|---| +| `FLASHRT_QWEN36_DFLASH_RELAXED_THINKING` | `0` | Enable relaxed acceptance inside ``. | +| `FLASHRT_QWEN36_DFLASH_RELAXED_TOPK` | `3` | Candidate set size. | +| `FLASHRT_QWEN36_DFLASH_RELAXED_DELTA` | `1.0` | Logit margin vs the argmax (equals a log-prob margin). | + +Measured on Thor (thinking-enabled robot JSON-plan prompt, steady +state): AL 3.78 -> 5.42, **40.4 -> 57.7 tok/s (+43%)**. Prompts whose +drafts rarely reach the top-k see no change (a math prompt measured +neutral). The thinking stream is no longer token-identical to the +strict run — enable this only where the product gates on the final +answer, not the reasoning transcript. + +## Opt-in chunk-saves verify kernels (Thor) + +`FLASHRT_QWEN36_THOR_LIN_CHUNK_SAVES=1` routes the DFlash verify's +linear-attention layers to chunk kernels that emit the per-step +rollback checkpoints in one pass (~5% lower cycle time). This moves +the verify off the kernel family that the MTP reference path uses, so +greedy output is no longer token-identical to that reference — same +tradeoff class as relaxed acceptance. Default off. + +## Serving + +A stateless OpenAI-compatible host for this path lives in +[`serving/qwen36_dflash_agent`](../serving/qwen36_dflash_agent) — +single-stream request/response serving with per-request DFlash +generation and accept-length telemetry: + +```bash +python -m serving.qwen36_dflash_agent.server \ + --checkpoint /models/Qwen3.6-27B-NVFP4 --max-seq 32768 --K 15 +curl -s http://127.0.0.1:8000/health +``` + +For long-running agent sessions (prefix reuse, tool calling, SSE +streaming) use [`serving/qwen36_agent`](../serving/qwen36_agent), +which serves the MTP spec path. + +## Notes + +- Structured output (JSON plans, code) accepts much better than free + prose; the gains above track the drafter's training distribution. +- Degenerate prompts that repeat one sentence verbatim can steer the + seeded window into drafting more repetition. If you benchmark with + synthetic repeated text, disable the seed + (`FLASHRT_QWEN36_DFLASH_WINDOW_SEED=0`) for representative numbers. +- Greedy-parity comparisons must use the FP8-KV MTP route as the + reference (`FLASHRT_QWEN36_LONG_CTX_ROUTE_MIN_SEQ=0` forces it for + short prompts). The BF16 short route stores KV in a different + format, so token-exact comparison across the two is not meaningful. +- The published drafter checkpoint is marked by z-lab as still under + training; acceptance lengths should improve by dropping in a newer + drafter checkpoint without code changes. diff --git a/docs/qwen36_usage.md b/docs/qwen36_usage.md index 308b0f0..3efc273 100644 --- a/docs/qwen36_usage.md +++ b/docs/qwen36_usage.md @@ -186,7 +186,10 @@ frontend is built has no effect. | `FLASHRT_QWEN36_MTP_CKPT_DIR` | Required for spec decode | unset | Directory containing `mtp.safetensors` (FP8 e4m3 block-128) from a paired Qwen3.6-Next-27B-FP8 ckpt. Loaded once at construction and converted FP8 → BF16 → NVFP4. If unset, MTP is `None` and `generate_own_speculative_KN_nvfp4` raises; pure-decode still works. | | `FLASHRT_QWEN36_MTP_KEEP_BF16` | Optional | BF16-source MTP: `1`; FP8-source MTP: n/a | For community BF16/native MTP checkpoints, keep BF16 projection weights and use them in the drafter hot path. This improves MTP alignment at the cost of extra VRAM. Set `0` to force the lower-memory NVFP4-converted MTP path. | | `FLASHRT_QWEN36_HF_PATCH` | Optional | unset | Path to a HF FP8 dispatch monkey-patch script. Only consulted by the legacy FP8 path; the NVFP4 path doesn't need it. If unset or path doesn't exist, the patch step is silently skipped. | -| `FLASHRT_QWEN36_DFLASH_CKPT_DIR` | Optional | unset | Drafter ckpt directory for the DFlash add-on path. Required only if you call `init_dflash_drafter()`; raises a clear error if unset and `ckpt_dir` is also not passed. | +| `FLASHRT_QWEN36_DFLASH_CKPT_DIR` | Optional | unset | Drafter ckpt directory for the DFlash path. Required only if you call `init_dflash_drafter()`; raises a clear error if unset and `ckpt_dir` is also not passed. See [`qwen36_dflash.md`](qwen36_dflash.md). | +| `FLASHRT_QWEN36_DFLASH_PERTOKEN` | Optional | `1` on Thor | Per-token drafter context window (one feature entry per committed token). `0` falls back to the legacy per-cycle shift window. See [`qwen36_dflash.md`](qwen36_dflash.md). | +| `FLASHRT_QWEN36_DFLASH_WINDOW` | Optional | `128` | Per-token drafter window length in tokens (max 256). | +| `FLASHRT_QWEN36_DFLASH_WINDOW_SEED` | Optional | `1` | Seed the per-token window from the prompt tail during Thor prefill. Disable for benchmarks built from verbatim-repeated text. | | `FLASHRT_QWEN36_MAX_Q_SEQ` | Optional | `2048` | Maximum S=K working-set rows for verify/prefill buffers. Long prefill chunking is additionally capped by the retained BF16 working window. | | `FLASHRT_QWEN36_LONG_CTX_BF16_WINDOW` | Optional | `min(2048, MAX_Q_SEQ)` | Retained BF16 working-window rows in long-context mode. Raising this can enable larger prompt chunks but costs substantial VRAM. | | `FLASHRT_QWEN36_LONG_CTX_ROUTE_MIN_SEQ` | Optional | `512` in long-ctx mode | Prompt length at or above which a long-context frontend routes through the chunked compressed-KV path. The measured 128-token bucket is also routed through FP8-KV to avoid the legacy one-token BF16/spec prefill. Other short prompts stay on BF16/spec unless the full request exceeds the retained BF16 window. | diff --git a/flash_rt/frontends/torch/_qwen36_rtx_dflash_forward.py b/flash_rt/frontends/torch/_qwen36_rtx_dflash_forward.py index ca77de4..147ebd6 100644 --- a/flash_rt/frontends/torch/_qwen36_rtx_dflash_forward.py +++ b/flash_rt/frontends/torch/_qwen36_rtx_dflash_forward.py @@ -799,3 +799,142 @@ def dflash_drafter_forward_capture(frontend) -> torch.Tensor: buf['logits'].data_ptr(), M, VOCAB, H, s, widen=True) return buf['logits'] + + +# ==================================================================== +# Per-token window variant +# ==================================================================== +# +# The shift-window above appends ONE fc-projected tap set per spec +# cycle, so window entries are ~AL committed tokens apart while the +# drafter attends to them at consecutive positions. The per-token +# variant keeps a window of features for EVERY committed token: the +# orchestration appends N+1 entries after each accept (and seeds the +# window from the prompt tail at prefill), and the drafter forward +# below only READS the window — no fc, no shift — which also makes the +# graph capture side-effect free. + +def alloc_pertoken_window(frontend, win: int) -> None: + """Allocate the per-token feature window + append scratch.""" + buf = frontend._dflash_buf + if buf.get('pt_window') is not None and buf['pt_win'] == win: + return + if win > buf['max_ctx']: + raise ValueError( + f'window {win} exceeds drafter max_ctx {buf["max_ctx"]}') + H = buf['hidden'] + dev = frontend.device + buf['pt_window'] = torch.zeros( + win, H, dtype=torch.bfloat16, device=dev) + buf['pt_shift_scratch'] = torch.empty_like(buf['pt_window']) + buf['pt_proj_out'] = torch.empty( + buf['max_ctx'], H, dtype=torch.bfloat16, device=dev) + buf['pt_taps_rows'] = torch.empty( + max(buf['block'], win), 5, H, dtype=torch.bfloat16, device=dev) + buf['pt_seed_taps'] = torch.empty( + 5, win, H, dtype=torch.bfloat16, device=dev) + buf['pt_win'] = win + buf['pt_valid'] = 0 + + +def reset_pertoken_window(frontend) -> None: + """Clear per-token window state. Call at the start of a generate.""" + buf = frontend._dflash_buf + if buf.get('pt_window') is not None: + buf['pt_window'].zero_() + buf['pt_valid'] = 0 + + +def pertoken_window_append(frontend, taps_rows) -> None: + """Append fc-projected features of R committed rows to the window. + + taps_rows: (R, 5, hidden) bf16 — verify tap_buf rows of the + committed tokens, oldest first. Shift-left by R, write the R new + features at the tail. Runs eagerly on the current stream, outside + the drafter graph. + """ + from flash_rt import flash_rt_kernels as fvk + + buf = frontend._dflash_buf + d = frontend._weights.ptrs['dflash'] + s = torch.cuda.current_stream().cuda_stream + H = buf['hidden'] + FC_IN = buf['fc_in'] + eps = float(d['rms_norm_eps']) + win = buf['pt_window'] + W = buf['pt_win'] + R = int(taps_rows.shape[0]) + if R > W: + taps_rows = taps_rows[-W:] + R = W + + x = taps_rows.reshape(R, FC_IN).contiguous() + ap_t, sf_t = buf['act_Mctx_K5120'] + _quant_act(fvk, x, ap_t, sf_t, R, FC_IN, s) + _gemm_nvfp4(fvk, ap_t.data_ptr(), sf_t.data_ptr(), + d['fc_packed'], d['fc_sf'], d['fc_alpha'], + buf['pt_proj_out'].data_ptr(), R, H, FC_IN, s) + if R < W: + scratch = buf['pt_shift_scratch'] + scratch[:W - R].copy_(win[R:]) + win[:W - R].copy_(scratch[:W - R]) + fvk.rms_norm( + int(buf['pt_proj_out'].data_ptr()), int(d['hidden_norm_w']), + int(win[W - R:W].data_ptr()), + R, H, eps, int(s), + ) + buf['pt_valid'] = min(buf['pt_valid'] + R, W) + + +def dflash_drafter_forward_pertoken(frontend, + valid_ctx: int | None = None): + """Drafter forward over the per-token window (read-only). + + valid_ctx: number of valid tail rows to attend to. None means the + full window — the shape the captured graph bakes in. Callers pass + the actual valid count during ramp-up (window not yet full). + + Returns: logits (block, vocab) bf16 in buf['logits']. + """ + from flash_rt import flash_rt_kernels as fvk + + s = torch.cuda.current_stream().cuda_stream + buf = frontend._dflash_buf + d = frontend._weights.ptrs['dflash'] + M = buf['block'] + H = buf['hidden'] + VOCAB = buf['vocab'] + eps = float(d['rms_norm_eps']) + W = buf['pt_win'] + ctx_len = W if valid_ctx is None else int(valid_ctx) + if not (1 <= ctx_len <= W): + raise ValueError(f'valid_ctx={ctx_len} out of [1, {W}]') + win = buf['pt_window'][W - ctx_len:W] + + fvk.qwen36_embedding_lookup_bf16( + buf['ids_static'].data_ptr(), + int(frontend._weights.ptrs['embed_w']), + buf['embed_buf'].data_ptr(), M, H, s, + ) + fvk.gpu_copy( + buf['h_b'].data_ptr(), buf['embed_buf'].data_ptr(), + M * H * 2, s, + ) + h = buf['h_b'] + for L in range(buf['n_layers']): + h = _drafter_layer_forward( + frontend, fvk, L, h, win, ctx_len, s) + fvk.rms_norm( + int(h.data_ptr()), int(d['final_norm_w']), + int(buf['h_final_norm'].data_ptr()), + M, H, eps, int(s), + ) + ap_lm, sf_lm = buf['act_M16_K5120'] + _quant_act(fvk, buf['h_final_norm'], ap_lm, sf_lm, M, H, s) + _gemm_nvfp4(fvk, ap_lm.data_ptr(), sf_lm.data_ptr(), + frontend._weights.ptrs['lm_head_packed'], + frontend._weights.ptrs['lm_head_sf'], + frontend._weights.ptrs['lm_head_alpha'], + buf['logits'].data_ptr(), + M, VOCAB, H, s, widen=True) + return buf['logits'] diff --git a/flash_rt/frontends/torch/qwen36_rtx.py b/flash_rt/frontends/torch/qwen36_rtx.py index 7fa00cb..c129c9d 100644 --- a/flash_rt/frontends/torch/qwen36_rtx.py +++ b/flash_rt/frontends/torch/qwen36_rtx.py @@ -11614,6 +11614,14 @@ def _tq_inject_kv(self, full_rank: int, cur_pos: int, # N6-A4: DFlash spec decode (block-diffusion drafter + chain verify) # ================================================================== + def init_dflash_drafter(self, ckpt_dir: str | None = None) -> None: + """Public entry: load the DFlash drafter for spec decode. + + ``ckpt_dir`` falls back to ``FLASHRT_QWEN36_DFLASH_CKPT_DIR``. + Must be called before ``generate_own_speculative_DFlash_nvfp4``. + """ + self._load_dflash_drafter(ckpt_dir) + def _load_dflash_drafter(self, ckpt_dir: str | None = None) -> None: """Load the z-lab/Qwen3.6-27B-DFlash drafter (NVFP4 W4A16). @@ -11736,8 +11744,167 @@ def _restore(): self._captured_drafter_graphs_dflash, eff_ctx, g) return g + def _dflash_verify_forward_K(self, token_ids_K, cos_K, sin_K, + cur_pos: int, K: int, tap_buf): + """Arch hook: the S=K verify forward used by DFlash spec decode. + + Default is the BF16-staged KV verify. Subclasses whose K-row + layer path requires a different KV mode (Thor: FP8-KV) override + this with the matching wrapper; the DFlash orchestration and + graph capture above it stay shared. + """ + return self.forward_own_decode_K_nvfp4( + token_ids_K, cos_K, sin_K, cur_pos, K, tap_buf=tap_buf) + + def _dflash_prefill_nvfp4(self, input_ids): + """Arch hook: prompt prefill for DFlash spec decode. + + Default walks the prompt through the per-position S=1 captured + graphs, which writes the BF16 KV cache the default verify + forward reads. Subclasses whose verify attends over a different + KV store (Thor: FP8-KV) override this with a prefill that + populates that store. Returns the (1, 1) first greedy token. + """ + prompt_len = int(input_ids.shape[1]) + for p in range(prompt_len): + self._static_token_id.copy_(input_ids[:, p:p + 1]) + g_pf = self._ensure_graph_for_pos_nvfp4(p) + self._replay_pos_graph(g_pf, p) + return self._logits_buf.argmax(dim=-1, keepdim=True).view(1, 1) + + @staticmethod + def _dflash_relaxed_matches(logits_K, drafts, all_argmax, + topk: int, delta: float, close_id: int): + """Relaxed draft acceptance for the thinking phase. + + A draft row is accepted when its token is inside the verify + logits' top-``topk`` AND within ``delta`` of the argmax logit + (a raw-logit margin equals a log-prob margin). Rows from the + first draft that closes the think block fall back to strict + argmax matching so the visible answer stays exact-verified. + Returns a 0/1 tensor of shape (K,). + """ + import torch + + K = int(drafts.shape[0]) + topv, topi = torch.topk(logits_K, topk, dim=-1) + ok = ( + (topi == drafts.view(K, 1)) + & ((topv[:, :1] - topv) <= delta) + ).any(-1).long() + close_mask = drafts == close_id + if bool(close_mask.any().item()): + idx = int(close_mask.nonzero()[0].item()) + strict = (all_argmax[:K] == drafts).long() + ok[idx:] = strict[idx:] + return ok + + def _dflash_window_commit(self, N: int) -> None: + """Append the committed rows' features to the per-token window. + + Tap rows 0..N are the state-advanced verify rows + [tok, drafts[:N]], oldest first. Callers must invoke this + BEFORE the end-of-cycle taps[:, 0] shuffle, which overwrites + row 0 with row N. + """ + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( + pertoken_window_append, + ) + + R = N + 1 + rows = self._dflash_buf['pt_taps_rows'][:R] + rows.copy_(self._dflash_taps_buf[:, :R].permute(1, 0, 2)) + pertoken_window_append(self, rows) + + def _dflash_snap_state(self, cur_pos: int, Kv: int) -> None: + """Arch hook: snapshot state the partial-accept rollback needs. + + Runs on the snap stream, overlapped with the drafter forward. + Subclasses whose rollback reads per-step state checkpoints + written during the verify itself (Thor) override this with a + no-op. + """ + self._snap_lin_buf.copy_(self._lin_state) + self._snap_conv_buf.copy_(self._lin_conv_state) + self._snap_K_buf[:, :Kv].copy_( + self._attn.K_cache[:, cur_pos:cur_pos + Kv]) + self._snap_V_buf[:, :Kv].copy_( + self._attn.V_cache[:, cur_pos:cur_pos + Kv]) + + def _dflash_partial_rollback(self, cur_pos: int, N: int, Kv: int, + tok, drafts, cos_KN, sin_KN) -> None: + """Arch hook: fix up state after a partial accept of N drafts. + + On exit the recurrent/conv state and KV must reflect exactly + the N+1 committed rows [tok, drafts[:N]] at + [cur_pos, cur_pos+N+1), and ``_dflash_taps_buf[:, N]`` must + hold the taps of the last committed row. + + Default: restore the pre-verify snapshot, then re-advance with + the committed rows via a tapped verify at K=N+1 (a second + main-model forward). Subclasses with per-step state saves in + the verify K-row (Thor) override this with constant-time state + copies instead. + """ + import torch + + self._lin_state.copy_(self._snap_lin_buf) + self._lin_conv_state.copy_(self._snap_conv_buf) + self._attn.K_cache[:, cur_pos:cur_pos + Kv].copy_( + self._snap_K_buf[:, :Kv]) + self._attn.V_cache[:, cur_pos:cur_pos + Kv].copy_( + self._snap_V_buf[:, :Kv]) + + Kr = N + 1 + self._verify_static_tokens[:, 0:1].copy_(tok) + if N > 0: + self._verify_static_tokens[:, 1:Kr].copy_( + drafts[:N].view(1, N)) + self._verify_static_cos[:, :Kr].copy_(cos_KN[:, :Kr]) + self._verify_static_sin[:, :Kr].copy_(sin_KN[:, :Kr]) + rg = self._ensure_verify_graph_dflash_nvfp4(cur_pos, Kr) + gs = self._graph_stream + gs.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(gs): + rg.replay() + torch.cuda.current_stream().wait_stream(gs) + + def _ensure_drafter_graph_dflash_pertoken(self): + """Lazy CUDA Graph for the per-token-window drafter forward. + + The forward is read-only over the window (updates happen + outside the graph via ``pertoken_window_append``), so capture + needs no state snapshot/restore. One graph per frontend — the + window length is fixed at alloc time. + """ + import torch + + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( + dflash_drafter_forward_pertoken, + ) + + g = getattr(self, '_captured_drafter_graph_pertoken', None) + if g is not None: + return g + + gs = self._graph_stream + gs.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(gs), torch.no_grad(): + for _ in range(2): + dflash_drafter_forward_pertoken(self) + gs.synchronize() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph( + g, stream=gs, pool=self._graph_mempool, + ), torch.no_grad(): + dflash_drafter_forward_pertoken(self) + gs.synchronize() + torch.cuda.current_stream().wait_stream(gs) + self._captured_drafter_graph_pertoken = g + return g + def _ensure_verify_graph_dflash_nvfp4(self, cur_pos: int, K: int): - """Lazy CUDA Graph for forward_own_decode_K_nvfp4 WITH tap_buf. + """Lazy CUDA Graph for the DFlash verify forward WITH tap_buf. Mirror of ``_ensure_verify_graph_nvfp4`` but binds ``tap_buf=self._dflash_taps_buf`` at capture time so the 5 @@ -11777,9 +11944,9 @@ def _restore(): sin_K = self._verify_static_sin[:, :K] tap_buf = self._dflash_taps_buf for _ in range(2): - self.forward_own_decode_K_nvfp4( - tokens_K, cos_K, sin_K, cur_pos, K=K, - tap_buf=tap_buf) + self._dflash_verify_forward_K( + tokens_K, cos_K, sin_K, cur_pos, K, + tap_buf) _restore() gs.synchronize() @@ -11787,9 +11954,9 @@ def _restore(): with torch.cuda.graph( g, stream=gs, pool=self._graph_mempool, ), torch.no_grad(): - self.forward_own_decode_K_nvfp4( - tokens_K, cos_K, sin_K, cur_pos, K=K, - tap_buf=tap_buf) + self._dflash_verify_forward_K( + tokens_K, cos_K, sin_K, cur_pos, K, + tap_buf) with torch.cuda.stream(gs), torch.no_grad(): _restore() gs.synchronize() @@ -11849,20 +12016,55 @@ def generate_own_speculative_DFlash_nvfp4( eff_ctx = int(getattr(self, '_dflash_eff_ctx', 16)) alloc_drafter_capture_window(self, eff_ctx) reset_drafter_capture_state(self) + # Per-token window mode: the drafter attends to fc-projected + # features of every committed token instead of one entry per + # spec cycle. The prefill hook may seed the window from the + # prompt tail. + pertoken = bool(getattr(self, '_dflash_pertoken_window', False)) + if pertoken: + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( # noqa: E501 + alloc_pertoken_window, + reset_pertoken_window, + ) + alloc_pertoken_window( + self, int(getattr(self, '_dflash_pertoken_win', 128))) + reset_pertoken_window(self) + # Relaxed acceptance for the thinking phase (opt-in; mirrors + # the TensorRT-LLM MTP policy): inside a block a draft + # is accepted when it is in the verify logits' top-k AND within + # a logit margin of the argmax; the accepted token is then the + # DRAFT (rows already condition on drafts, so state/KV stay + # consistent). Rows from the first draft that closes the think + # block fall back to strict argmax matching. Default off — the + # strict path is byte-identical with this disabled. + relaxed = None + if os.environ.get( + 'FLASHRT_QWEN36_DFLASH_RELAXED_THINKING', '0', + ).strip().lower() in ('1', 'true', 'on'): + think_open = self._tokenizer.convert_tokens_to_ids('') + think_close = self._tokenizer.convert_tokens_to_ids('') + if isinstance(think_open, int) and think_open >= 0: + relaxed = { + 'topk': max(1, int(os.environ.get( + 'FLASHRT_QWEN36_DFLASH_RELAXED_TOPK', '3'))), + 'delta': float(os.environ.get( + 'FLASHRT_QWEN36_DFLASH_RELAXED_DELTA', '1.0')), + 'open': int(think_open), + 'close': int(think_close), + } + # The chat template opens the think block at the end of the + # generation prompt, so the phase can start active. + in_think = bool( + relaxed is not None + and relaxed['open'] in input_ids[0, -8:].tolist()) # Initialize taps to zero — first drafter call gets no real # signal; AL on cycle 0 will be lower than steady-state. self._dflash_taps_buf.zero_() with torch.no_grad(): - # 1) Prefill (same as MTP path) — sequential S=1 forwards - # via the per-cur_pos captured S=1 graph. - gs_pf = self._graph_stream - for p in range(prompt_len): - self._static_token_id.copy_(input_ids[:, p:p + 1]) - g_pf = self._ensure_graph_for_pos_nvfp4(p) - self._replay_pos_graph(g_pf, p) - tok = self._logits_buf.argmax( - dim=-1, keepdim=True).view(1, 1) + # 1) Prefill via the arch hook (default: sequential S=1 + # forwards through the per-cur_pos captured graphs). + tok = self._dflash_prefill_nvfp4(input_ids) generated = [tok] cur_pos = prompt_len @@ -11879,14 +12081,7 @@ def generate_own_speculative_DFlash_nvfp4( snap_stream = self._snap_stream snap_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(snap_stream): - self._snap_lin_buf.copy_(self._lin_state) - self._snap_conv_buf.copy_(self._lin_conv_state) - self._snap_K_buf[:, :Kv].copy_( - self._attn.K_cache[ - :, cur_pos:cur_pos + Kv]) - self._snap_V_buf[:, :Kv].copy_( - self._attn.V_cache[ - :, cur_pos:cur_pos + Kv]) + self._dflash_snap_state(cur_pos, Kv) # 2b) Drafter forward (P7). # Caller writes static inputs (prev_token + hidden_taps). @@ -11896,15 +12091,29 @@ def generate_own_speculative_DFlash_nvfp4( # (avoids zero-dilution that hurts AL). Once the window # is full, replay the captured graph. self._dflash_buf['ids_static'][0:1].copy_(tok.view(1)) - self._dflash_buf['hidden_taps_static'].copy_( - self._dflash_taps_buf[:, 0]) - if self._spec_attempts < eff_ctx: + if pertoken: + valid = int(self._dflash_buf['pt_valid']) + if valid < int(self._dflash_buf['pt_win']): + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( # noqa: E501 + dflash_drafter_forward_pertoken, + ) + dflash_drafter_forward_pertoken( + self, max(1, valid)) + else: + drafter_g = ( + self._ensure_drafter_graph_dflash_pertoken()) + drafter_g.replay() + elif self._spec_attempts < eff_ctx: from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( # noqa: E501 dflash_drafter_forward_capture_eager, ) + self._dflash_buf['hidden_taps_static'].copy_( + self._dflash_taps_buf[:, 0]) valid_ctx = self._spec_attempts + 1 dflash_drafter_forward_capture_eager(self, valid_ctx) else: + self._dflash_buf['hidden_taps_static'].copy_( + self._dflash_taps_buf[:, 0]) drafter_g = self._ensure_drafter_graph_dflash_nvfp4( eff_ctx) drafter_g.replay() @@ -11940,7 +12149,14 @@ def generate_own_speculative_DFlash_nvfp4( # 2d) Argmax + accept-prefix all_argmax = logits_KN.argmax(dim=-1) # (Kv,) long - matches = (all_argmax[:K] == drafts).long() + relaxed_cycle = relaxed is not None and in_think + if relaxed_cycle: + matches = self._dflash_relaxed_matches( + logits_KN[:K], drafts, all_argmax, + relaxed['topk'], relaxed['delta'], + relaxed['close']) + else: + matches = (all_argmax[:K] == drafts).long() matches_pad = torch.cat([ matches, torch.zeros(1, device=matches.device, @@ -11951,53 +12167,48 @@ def generate_own_speculative_DFlash_nvfp4( self._spec_accepts += N argmax_at = (lambda j: all_argmax[j:j + 1].view(1, 1)) + if relaxed_cycle: + # Accepted rows commit the DRAFT token (the verify + # rows and per-step state condition on the drafts); + # the bonus row commits the argmax as usual. + commit_at = (lambda j: ( + drafts[j:j + 1].view(1, 1) if j < N + else argmax_at(j))) + else: + commit_at = argmax_at if N == K: self._spec_full += 1 for j in range(Kv): if len(generated) < max_new_tokens: - generated.append(argmax_at(j)) + generated.append(commit_at(j)) tok = argmax_at(K) - # Move taps[K] -> taps[0] for next cycle - self._dflash_taps_buf[:, 0].copy_( - self._dflash_taps_buf[:, K]) cur_pos += Kv else: for j in range(N + 1): if len(generated) < max_new_tokens: - generated.append(argmax_at(j)) - # Restore pre-verify state. - self._lin_state.copy_(self._snap_lin_buf) - self._lin_conv_state.copy_(self._snap_conv_buf) - self._attn.K_cache[ - :, cur_pos:cur_pos + Kv].copy_( - self._snap_K_buf[:, :Kv]) - self._attn.V_cache[ - :, cur_pos:cur_pos + Kv].copy_( - self._snap_V_buf[:, :Kv]) - - # Re-advance with N+1 valid inputs via tapped verify - # at K=N+1 (always — including N=0; same code path - # as N>0). Re-uses the dflash verify graph cache. - Kr = N + 1 - rec_cos = cos_KN[:, :Kr] - rec_sin = sin_KN[:, :Kr] - self._verify_static_tokens[:, 0:1].copy_(tok) - if N > 0: - self._verify_static_tokens[:, 1:Kr].copy_( - drafts[:N].view(1, N)) - self._verify_static_cos[:, :Kr].copy_(rec_cos) - self._verify_static_sin[:, :Kr].copy_(rec_sin) - rg = self._ensure_verify_graph_dflash_nvfp4( - cur_pos, Kr) - gs.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(gs): - rg.replay() - torch.cuda.current_stream().wait_stream(gs) + generated.append(commit_at(j)) + self._dflash_partial_rollback( + cur_pos, N, Kv, tok, drafts, cos_KN, sin_KN) tok = argmax_at(N) - self._dflash_taps_buf[:, 0].copy_( - self._dflash_taps_buf[:, N]) - cur_pos += Kr + cur_pos += N + 1 + if relaxed is not None: + ids = (drafts[:N].tolist() if N else []) + ids.append(int(all_argmax[N].item())) + for t in ids: + if t == relaxed['open']: + in_think = True + elif t == relaxed['close']: + in_think = False + if pertoken: + # Must precede the taps[:, 0] shuffle below — it + # reads tap rows 0..N and the shuffle overwrites + # row 0. + self._dflash_window_commit(N) + # Move taps[N] -> taps[0] as the next drafter input + # (N == K on a full accept). + self._dflash_taps_buf[:, 0].copy_( + self._dflash_taps_buf[:, N]) if len(generated) > max_new_tokens: generated = generated[:max_new_tokens] diff --git a/flash_rt/frontends/torch/qwen36_thor.py b/flash_rt/frontends/torch/qwen36_thor.py index 4239b76..c07800a 100644 --- a/flash_rt/frontends/torch/qwen36_thor.py +++ b/flash_rt/frontends/torch/qwen36_thor.py @@ -206,17 +206,58 @@ def _thor_alloc_K_row_scratch(self) -> None: # per-position sub-loop. Bit-exact to running K sequential single- # token forwards (see DESIGN §4.5 for the leaf-kernel set). def _layer_forward_lin_K_nvfp4(self, L, h_in_K, K): + # K <= 7 stays on parent's per-step branch — the production + # MTP spec verify path, untouched. The 8..16 band (DFlash + # verify) defaults to parent as well: greedy parity against + # the MTP reference is anchored to parent-family rounding, and + # a Thor-family verify measurably drifts from it. The opt-in + # chunk-saves route (FLASHRT_QWEN36_THOR_LIN_CHUNK_SAVES=1) + # trades that token-exact parity for ~5% lower verify cost + # (chunk kernels + per-step checkpoints in one pass) — for + # deployments gating on task-level quality instead. if K <= self._THOR_K_ROW_FAST_PATH_MAX: return super()._layer_forward_lin_K_nvfp4(L, h_in_K, K) + if K <= self._K_save_max: + if self._thor_lin_chunk_saves_enabled(): + return self._thor_lin_K_forward(L, h_in_K, K) + return super()._layer_forward_lin_K_nvfp4(L, h_in_K, K) if K > self.MAX_Q_SEQ: return self._thor_lin_K_dispatch(L, h_in_K, K) return self._thor_lin_K_forward(L, h_in_K, K) + def _thor_lin_chunk_saves_enabled(self) -> bool: + cached = getattr(self, '_thor_lin_saves_flag', None) + if cached is None: + from flash_rt import flash_rt_kernels as fvk + + cached = ( + hasattr(fvk, 'causal_conv1d_qwen36_update_chunk_saves_bf16') + and hasattr( + fvk, + 'qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16') + and os.environ.get( + 'FLASHRT_QWEN36_THOR_LIN_CHUNK_SAVES', '0', + ).strip().lower() in ('1', 'true', 'on')) + self._thor_lin_saves_flag = cached + return cached + def _layer_forward_full_K_nvfp4( self, L, h_in_K, cos_K, sin_K, cur_pos, K): + # The verify must stay on ONE kernel family end to end: rows + # committed by one family while other rows (or the rollback + # checkpoints) come from another surface the families' + # occasional rounding disagreements as greedy divergence. + # K <= 7 (the production MTP verify) stays on parent. The + # 8..16 band follows the lin dispatch: Thor from-scratch when + # the chunk-saves kernels serve the lin layers, parent + # otherwise — mixing families across layer types measurably + # breaks greedy parity. if K <= self._THOR_K_ROW_FAST_PATH_MAX: return super()._layer_forward_full_K_nvfp4( L, h_in_K, cos_K, sin_K, cur_pos, K) + if K <= self._K_save_max and not self._thor_lin_chunk_saves_enabled(): + return super()._layer_forward_full_K_nvfp4( + L, h_in_K, cos_K, sin_K, cur_pos, K) if K > self.MAX_Q_SEQ: return self._thor_full_K_dispatch( L, h_in_K, cos_K, sin_K, cur_pos, K) @@ -302,12 +343,28 @@ def _thor_lin_K_forward(self, L, h_in_K, K): lin_rank = self._linear_layer_rank(L) conv_state = self._lin_conv_state[lin_rank] conv_out_K = self._K_lin_conv_out[:K] - fvk.causal_conv1d_qwen36_update_chunk_bf16( - out_qkv_K.data_ptr(), int(lw['conv1d_w']), - int(lw['conv1d_b']), - conv_out_K.data_ptr(), conv_state.data_ptr(), - 1, K, 10240, 4, True, s, - ) + # Inside the save-steps range, dump per-step state checkpoints + # for the spec-decode partial-accept rollback (same slots the + # parent per-step branch writes). + save_steps = ( + K <= self._K_save_max and self._thor_lin_chunk_saves_enabled()) + if save_steps: + conv_steps = self._K_lin_conv_state_per_step + fvk.causal_conv1d_qwen36_update_chunk_saves_bf16( + out_qkv_K.data_ptr(), int(lw['conv1d_w']), + int(lw['conv1d_b']), + conv_out_K.data_ptr(), conv_state.data_ptr(), + conv_steps[0, lin_rank].data_ptr(), + conv_steps.stride(0), + 1, K, 10240, 4, True, s, + ) + else: + fvk.causal_conv1d_qwen36_update_chunk_bf16( + out_qkv_K.data_ptr(), int(lw['conv1d_w']), + int(lw['conv1d_b']), + conv_out_K.data_ptr(), conv_state.data_ptr(), + 1, K, 10240, 4, True, s, + ) # (7-9) Fused conv_out -> split + Q/K broadcast + GDN gating # + GDN chunk recurrent in one launch. Replaces three separate @@ -316,15 +373,29 @@ def _thor_lin_K_forward(self, L, h_in_K, K): attn_out_K = self._K_lin_attn_out[:K] a_stride = a_vec_K.stride(0) b_stride = b_vec_K.stride(0) - fvk.qwen36_gdn_chunk_from_conv_smem_strided_bf16( - conv_out_K.data_ptr(), - a_vec_K.data_ptr(), b_vec_K.data_ptr(), - lw['neg_A_log_exp_fp32_t'].data_ptr(), - lw['dt_bias_fp32_t'].data_ptr(), - rec_state.data_ptr(), - attn_out_K.data_ptr(), - K, 48, a_stride, b_stride, True, s, - ) + if save_steps: + lin_steps = self._K_lin_state_per_step + fvk.qwen36_gdn_chunk_from_conv_smem_strided_saves_bf16( + conv_out_K.data_ptr(), + a_vec_K.data_ptr(), b_vec_K.data_ptr(), + lw['neg_A_log_exp_fp32_t'].data_ptr(), + lw['dt_bias_fp32_t'].data_ptr(), + rec_state.data_ptr(), + lin_steps[0, lin_rank].data_ptr(), + lin_steps.stride(0), + attn_out_K.data_ptr(), + K, 48, a_stride, b_stride, True, s, + ) + else: + fvk.qwen36_gdn_chunk_from_conv_smem_strided_bf16( + conv_out_K.data_ptr(), + a_vec_K.data_ptr(), b_vec_K.data_ptr(), + lw['neg_A_log_exp_fp32_t'].data_ptr(), + lw['dt_bias_fp32_t'].data_ptr(), + rec_state.data_ptr(), + attn_out_K.data_ptr(), + K, 48, a_stride, b_stride, True, s, + ) # (10) rms_norm_gated_silu @ M=K*48, dim=128. attn_out_flat = attn_out_K.view(K * 48, 128) @@ -977,8 +1048,11 @@ def _thor_mtp_prefill_K_nvfp4( # NOT enable TurboQuant; FP8-KV is selected by # ``FLASHRT_QWEN36_LONG_KV_CACHE=fp8``. The env is honoured here # for explicit user overrides (bisection, ablation). - def _long_tq_effective_k(self, prompt_len: int, K: int) -> int: - target_k = super()._long_tq_effective_k(prompt_len, K) + def _long_tq_effective_k( + self, prompt_len: int, K: int, + max_new_tokens: int | None = None) -> int: + target_k = super()._long_tq_effective_k( + prompt_len, K, max_new_tokens) if os.environ.get('FLASHRT_QWEN36_TQ_SPEC_K', ''): return target_k if target_k > 6 and int(prompt_len) >= 12288: @@ -1089,3 +1163,158 @@ def _prefill_mtp_tail_kv_nvfp4( return self._thor_mtp_prefill_K_nvfp4( prev_h_rows, token_ids, pos_start, rows, cache_base_pos=cache_base_pos) + + # ---------- DFlash integration ---------- + # + # DFlash verifies at S=block_size (16), above + # ``_THOR_K_ROW_FAST_PATH_MAX``, so the K-row layers route to + # ``_thor_full_K_forward`` / ``_thor_lin_K_forward`` — and the + # full-attn K-row is single-XQA-path over the persistent FP8 KV + # cache. Three consequences, each handled by one override below: + # the drafter load must guarantee the FP8 cache exists, the prompt + # prefill must populate it, and the verify forward must run with + # the FP8-KV mode flag active. + + def _load_dflash_drafter(self, ckpt_dir: str | None = None) -> None: + import torch + + super()._load_dflash_drafter(ckpt_dir) + # Short-ctx constructions (user_max_seq <= LONG_CTX_THRESHOLD) + # never allocate the persistent FP8 KV cache; the Thor DFlash + # verify cannot run without it. + if not hasattr(self, '_fp8_K_cache'): + self._load_fp8_kv_cache(max_seq=self._user_max_seq + 16) + self._long_kv_cache_mode = 'fp8' + # Grow the per-step state checkpoints to the DFlash verify + # q_seq (block_size = _MAX_PUBLIC_SPEC_K + 1). The lin K-row + # save-steps branch then covers the whole verify, and the + # partial-accept rollback becomes two constant-time copies + # instead of a second main-model forward. + needed = self._MAX_PUBLIC_SPEC_K + 1 + if self._K_save_max < needed: + self._K_save_max = needed + self._K_lin_state_per_step = torch.empty( + needed, *self._lin_state.shape, + device=self._lin_state.device, + dtype=self._lin_state.dtype) + self._K_lin_conv_state_per_step = torch.empty( + needed, *self._lin_conv_state.shape, + device=self._lin_conv_state.device, + dtype=self._lin_conv_state.dtype) + # Any K-row graph captured before the grow baked the old + # checkpoint buffers — drop those graphs so they re-capture + # against the new allocations. + for cache_name in ( + '_captured_verify_graphs_fp8kv', + '_captured_prefill_graphs_fp8kv', + '_captured_verify_graphs_tq', + '_captured_prefill_graphs_tq', + '_captured_verify_graphs_dflash', + ): + cache = getattr(self, cache_name, None) + if cache: + cache.clear() + # Per-token drafter window (default on for Thor): the drafter + # attends to fc-projected features of every committed token. + # Measured on Thor at ctx=128: steady AL 2.53 -> 3.49 vs the + # one-entry-per-cycle shift window. + if not hasattr(self, '_dflash_pertoken_window'): + self._dflash_pertoken_window = os.environ.get( + 'FLASHRT_QWEN36_DFLASH_PERTOKEN', '1', + ).strip().lower() not in ('0', 'false', 'off') + self._dflash_pertoken_win = int(os.environ.get( + 'FLASHRT_QWEN36_DFLASH_WINDOW', '128') or '128') + + def _dflash_prefill_nvfp4(self, input_ids): + """Thor override: chunked FP8-KV prompt prefill. + + The default per-position walk writes only the BF16 KV cache; + the Thor verify attends over the FP8 cache, so the prompt rows + must land there. The chunked prefill is also the production + Thor TTFT path (batched XQA instead of one forward per token). + + In per-token-window mode the last min(window, prompt) tokens + run as a separate tap-captured chunk so the drafter window + starts seeded with the prompt tail's features instead of + ramping from empty. + """ + seed_window = ( + getattr(self, '_dflash_pertoken_window', False) + and os.environ.get( + 'FLASHRT_QWEN36_DFLASH_WINDOW_SEED', '1', + ).strip().lower() not in ('0', 'false', 'off')) + if not seed_window: + _, logits = self._prefill_long_ctx_tq_chunked(input_ids) + return logits.argmax(dim=-1, keepdim=True).view(1, 1) + + from flash_rt.frontends.torch._qwen36_rtx_dflash_forward import ( + alloc_pertoken_window, + pertoken_window_append, + ) + + alloc_pertoken_window( + self, int(getattr(self, '_dflash_pertoken_win', 128))) + buf = self._dflash_buf + P = int(input_ids.shape[1]) + tail = min(int(buf['pt_win']), P) + if P > tail: + self._prefill_long_ctx_tq_chunked(input_ids[:, :P - tail]) + d = self._rope_dim + cos_T = self._rope_cos_table[P - tail:P].view(1, tail, d) + sin_T = self._rope_sin_table[P - tail:P].view(1, tail, d) + seed = buf['pt_seed_taps'] + logits = self.forward_own_decode_K_nvfp4_fp8kv( + input_ids[:, P - tail:], cos_T, sin_T, P - tail, tail, + tap_buf=seed, logits_mode='last') + rows = buf['pt_taps_rows'][:tail] + rows.copy_(seed[:, :tail].permute(1, 0, 2)) + pertoken_window_append(self, rows) + return logits.argmax(dim=-1, keepdim=True).view(1, 1) + + def _dflash_verify_forward_K(self, token_ids_K, cos_K, sin_K, + cur_pos: int, K: int, tap_buf): + """Thor override: run the DFlash verify in FP8-KV mode. + + Same wrapper as the production long-ctx spec verify, so the + K-row layer dispatch sees ``_fp8_kv_verify_active`` for the + whole S=K forward. + """ + return self.forward_own_decode_K_nvfp4_fp8kv( + token_ids_K, cos_K, sin_K, cur_pos, K, tap_buf=tap_buf) + + def _dflash_snap_state(self, cur_pos: int, Kv: int) -> None: + """Thor override: nothing to snapshot. + + The rollback reads the per-step checkpoints written during the + verify K-row itself. The Thor verify never writes the BF16 KV + cache, and FP8 rows past the accept point are overwritten by + the next verify before any read. + """ + return + + def _dflash_partial_rollback(self, cur_pos: int, N: int, Kv: int, + tok, drafts, cos_KN, sin_KN) -> None: + """Thor override: constant-time state rollback. + + The verify at S=Kv ran the lin K-row save-steps branch + (``Kv <= _K_save_max`` after drafter load), so the state after + every verify row is checkpointed; committing N drafts is a copy + from slot N. Same pattern as the long-ctx MTP spec loop. Taps + for rows <= N are already in ``_dflash_taps_buf`` from the main + verify. + """ + import torch + + from flash_rt import flash_rt_kernels as fvk + + s = torch.cuda.current_stream().cuda_stream + fvk.gpu_copy( + self._lin_state.data_ptr(), + self._K_lin_state_per_step[N].data_ptr(), + self._lin_state.numel() * 2, s, + ) + fvk.gpu_copy( + self._lin_conv_state.data_ptr(), + self._K_lin_conv_state_per_step[N].data_ptr(), + self._lin_conv_state.numel() * 2, s, + ) diff --git a/serving/qwen36_dflash_agent/README.md b/serving/qwen36_dflash_agent/README.md new file mode 100644 index 0000000..e7acc9f --- /dev/null +++ b/serving/qwen36_dflash_agent/README.md @@ -0,0 +1,90 @@ +# serving/qwen36_dflash_agent + +OpenAI-compatible serving host for Qwen3.6-27B NVFP4 with **DFlash +block-diffusion speculative decoding** +(see [`docs/qwen36_dflash.md`](../../docs/qwen36_dflash.md)). + +This directory is the policy layer above the FlashRT execution +contract: it owns request shaping and telemetry only, adds no session +or KV verbs to `exec/`, and keeps the frontend API untouched. + +## Scope + +| | this host | [`serving/qwen36_agent`](../qwen36_agent) | +|---|---|---| +| decode path | DFlash drafter (K=15 block) | MTP chain (K<=6) | +| session state | stateless — full prefill per request | exact-prefix reuse, capsules | +| tool calling / SSE streaming | no | yes | +| concurrency | batch 1, serialized | batch 1, scheduled sessions | + +Use this host for single-stream, short-context request/response +workloads (robot planners, structured-output services) where the +DFlash path measures fastest; use `qwen36_agent` for long-running +agent sessions. + +## Quickstart + +**Prerequisites**: FlashRT built for your GPU (`GPU_ARCH=110` on +Jetson AGX Thor), the Qwen3.6-27B NVFP4 checkpoint, the paired FP8 +MTP checkpoint (frontend construction requires it), and the DFlash +drafter checkpoint: + +```bash +hf download z-lab/Qwen3.6-27B-DFlash --local-dir /models/Qwen3.6-27B-DFlash +pip install fastapi uvicorn +``` + +**1. Start the server** + +```bash +export FLASHRT_QWEN36_MTP_CKPT_DIR=/models/Qwen3.6-27B-FP8 +export FLASHRT_QWEN36_DFLASH_CKPT_DIR=/models/Qwen3.6-27B-DFlash +export FLASHRT_QWEN36_LONG_KV_CACHE=fp8 + +python -m serving.qwen36_dflash_agent.server \ + --checkpoint /models/Qwen3.6-27B-NVFP4 \ + --max-seq 32768 --K 15 \ + --host 127.0.0.1 --port 8000 +``` + +The frontend arch is auto-detected (SM110 -> Thor, otherwise RTX); +override with `--arch thor|rtx`. + +**2. Check it is up** + +```bash +curl -s http://127.0.0.1:8000/health +# {"status":"ok","arch":"thor","path":"dflash","pertoken_window":true,...} +``` + +**3. Chat completion** + +```bash +curl -s http://127.0.0.1:8000/v1/chat/completions \ + -H 'Content-Type: application/json' -d '{ + "model": "qwen3.6-27b-dflash", + "messages": [{"role": "user", "content": + "Output a JSON action list to pick up the red cube and place it on the tray."}], + "max_tokens": 256 + }' +``` + +The response carries a `flashrt` telemetry block with the speculation +cycle count, realized accept length, and end-to-end latency. + +## Limits (v1) + +- Greedy decode only; sampling parameters are accepted and ignored. +- `stream` is not supported; responses return complete. +- The DFlash loop generates the full `max_tokens` budget and the + response is truncated at the first end token — budget generously + but not extravagantly. +- Qwen thinking mode is off by default; pass `"enable_thinking": true` + to opt in. + +## Tuning + +DFlash env knobs (`FLASHRT_QWEN36_DFLASH_PERTOKEN`, `..._WINDOW`, +`..._WINDOW_SEED`) are documented in +[`docs/qwen36_dflash.md`](../../docs/qwen36_dflash.md) together with +measured Thor performance. diff --git a/serving/qwen36_dflash_agent/__init__.py b/serving/qwen36_dflash_agent/__init__.py new file mode 100644 index 0000000..43c5a26 --- /dev/null +++ b/serving/qwen36_dflash_agent/__init__.py @@ -0,0 +1 @@ +"""OpenAI-compatible serving host for Qwen3.6-27B DFlash spec decode.""" diff --git a/serving/qwen36_dflash_agent/server.py b/serving/qwen36_dflash_agent/server.py new file mode 100644 index 0000000..76728be --- /dev/null +++ b/serving/qwen36_dflash_agent/server.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +"""FlashRT — Qwen3.6-27B DFlash OpenAI-compatible serving host. + +Serves /v1/chat/completions backed by the DFlash block-diffusion +speculative-decode path (`generate_own_speculative_DFlash_nvfp4`). +This is the policy layer above the FlashRT execution contract: it owns +request shaping and telemetry only, and adds no session or KV verbs. + +Scope (v1): + * Stateless per request — every call prefills the full prompt. + For long-running agent sessions with prefix reuse, tool calling, + and committed-token streaming, use ``serving/qwen36_agent``. + * Batch size 1; concurrent requests are serialized on one GPU. + * Greedy decode only — sampling parameters are accepted and ignored. + * The DFlash loop generates the full ``max_tokens`` budget; the + response is truncated at the first end token during detokenize. + +Usage: + pip install fastapi uvicorn + + export FLASHRT_QWEN36_MTP_CKPT_DIR=/models/Qwen3.6-27B-FP8 + export FLASHRT_QWEN36_DFLASH_CKPT_DIR=/models/Qwen3.6-27B-DFlash + export FLASHRT_QWEN36_LONG_KV_CACHE=fp8 + + python -m serving.qwen36_dflash_agent.server \\ + --checkpoint /models/Qwen3.6-27B-NVFP4 \\ + --max-seq 32768 --K 15 --port 8000 +""" +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import time +import uuid +from typing import Any, Dict, List, Optional + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(message)s', +) +log = logging.getLogger('qwen36_dflash_server') + + +def _build_frontend(args): + import torch + + cap = torch.cuda.get_device_capability() + arch = args.arch + if arch == 'auto': + arch = 'thor' if cap == (11, 0) else 'rtx' + if arch == 'thor': + from flash_rt.frontends.torch.qwen36_thor import ( + Qwen36TorchFrontendThor as Frontend, + ) + else: + from flash_rt.frontends.torch.qwen36_rtx import ( + Qwen36TorchFrontendRtx as Frontend, + ) + log.info('loading %s frontend (sm %s), checkpoint=%s', + arch, cap, args.checkpoint) + fe = Frontend(args.checkpoint, quant='nvfp4', max_seq=args.max_seq) + fe.init_dflash_drafter(args.dflash_checkpoint or None) + log.info('DFlash drafter ready (pertoken=%s window=%s)', + getattr(fe, '_dflash_pertoken_window', False), + getattr(fe, '_dflash_pertoken_win', None)) + return fe, arch + + +def _chat_ids(fe, messages: List[Dict[str, Any]], enable_thinking: bool): + return fe._tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + enable_thinking=enable_thinking, + return_tensors='pt', + ).to(fe.device) + + +def create_app(args): + from fastapi import FastAPI, HTTPException + + fe, arch = _build_frontend(args) + tok = fe._tokenizer + end_ids = {tid for tid in ( + tok.eos_token_id, + tok.convert_tokens_to_ids('<|im_end|>'), + ) if isinstance(tid, int) and tid >= 0} + + app = FastAPI(title='FlashRT Qwen3.6 DFlash server') + gpu_lock = asyncio.Lock() + state = {'requests': 0} + + @app.get('/health') + async def health(): + return { + 'status': 'ok', + 'arch': arch, + 'path': 'dflash', + 'max_seq': args.max_seq, + 'K': args.K, + 'pertoken_window': bool( + getattr(fe, '_dflash_pertoken_window', False)), + 'window': getattr(fe, '_dflash_pertoken_win', None), + 'requests_served': state['requests'], + } + + @app.get('/v1/models') + async def models(): + return {'object': 'list', 'data': [{ + 'id': args.model_name, 'object': 'model', + 'owned_by': 'flashrt'}]} + + @app.post('/v1/chat/completions') + async def chat(body: Dict[str, Any]): + import torch + + messages = body.get('messages') + if not messages: + raise HTTPException(400, 'messages is required') + max_tokens = int(body.get('max_tokens') or args.default_max_tokens) + max_tokens = max(1, min(max_tokens, args.max_tokens_cap)) + enable_thinking = bool(body.get('enable_thinking', False)) + + async with gpu_lock: + t0 = time.perf_counter() + ids = _chat_ids(fe, messages, enable_thinking) + prompt_len = int(ids.shape[1]) + if prompt_len + max_tokens > args.max_seq: + raise HTTPException( + 400, f'prompt ({prompt_len}) + max_tokens ' + f'({max_tokens}) exceeds max_seq ({args.max_seq})') + out = await asyncio.to_thread( + fe.generate_own_speculative_DFlash_nvfp4, + ids, max_new_tokens=max_tokens, K=args.K) + torch.cuda.synchronize() + dt = time.perf_counter() - t0 + + new_ids = out[0, prompt_len:].tolist() + for i, t in enumerate(new_ids): + if t in end_ids: + new_ids = new_ids[:i] + break + text = tok.decode(new_ids, skip_special_tokens=True) + attempts = int(getattr(fe, '_spec_attempts', 0)) + state['requests'] += 1 + return { + 'id': f'chatcmpl-{uuid.uuid4().hex[:24]}', + 'object': 'chat.completion', + 'created': int(time.time()), + 'model': args.model_name, + 'choices': [{ + 'index': 0, + 'message': {'role': 'assistant', 'content': text}, + 'finish_reason': ( + 'stop' if len(new_ids) < max_tokens else 'length'), + }], + 'usage': { + 'prompt_tokens': prompt_len, + 'completion_tokens': len(new_ids), + 'total_tokens': prompt_len + len(new_ids), + }, + 'flashrt': { + 'path': 'dflash', + 'spec_cycles': attempts, + 'accept_length': ( + round(len(new_ids) / attempts, 2) if attempts else None), + 'e2e_ms': round(dt * 1e3, 1), + }, + } + + return app + + +def main() -> int: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument('--checkpoint', required=True, + help='Qwen3.6-27B NVFP4 checkpoint directory') + p.add_argument('--dflash-checkpoint', default='', + help='DFlash drafter directory (default: ' + 'FLASHRT_QWEN36_DFLASH_CKPT_DIR)') + p.add_argument('--model-name', default='qwen3.6-27b-dflash') + p.add_argument('--arch', choices=['auto', 'thor', 'rtx'], + default='auto') + p.add_argument('--max-seq', type=int, default=32768) + p.add_argument('--K', type=int, default=15, + help='speculative tokens per cycle (block_size - 1)') + p.add_argument('--default-max-tokens', type=int, default=256) + p.add_argument('--max-tokens-cap', type=int, default=4096) + p.add_argument('--host', default='127.0.0.1') + p.add_argument('--port', type=int, default=8000) + args = p.parse_args() + + os.environ.setdefault('FLASHRT_QWEN36_LONG_KV_CACHE', 'fp8') + + import uvicorn + uvicorn.run(create_app(args), host=args.host, port=args.port) + return 0 + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/tests/test_qwen36_dflash_structural.py b/tests/test_qwen36_dflash_structural.py new file mode 100644 index 0000000..aef327a --- /dev/null +++ b/tests/test_qwen36_dflash_structural.py @@ -0,0 +1,198 @@ +"""Structural tests for the Qwen3.6 DFlash spec-decode path. + +These run without model checkpoints or a GPU: they validate the +contracts that hardware benchmarks cannot guard cheaply — + + * the per-token window commit reads tap rows 0..N BEFORE the + end-of-cycle taps[:, 0] shuffle overwrites row 0, and stores a + copy (later tap mutation must not alias into the window input); + * the spec-decode loop keeps that ordering (source-order guard); + * generate fails fast with a clear error when no drafter is loaded; + * the public ``init_dflash_drafter`` wrapper delegates to the + loader; + * Thor's per-token window env routing (default on, opt-out, window + length override). + +GPU/end-to-end evidence for this path lives in the hardware-gated +benchmarks; see docs/qwen36_dflash.md. +""" + +from __future__ import annotations + +import inspect + +import pytest + +torch = pytest.importorskip("torch") + +from flash_rt.frontends.torch import _qwen36_rtx_dflash_forward as dff # noqa: E402 +from flash_rt.frontends.torch.qwen36_rtx import ( # noqa: E402 + Qwen36TorchFrontendRtx, +) +from flash_rt.frontends.torch.qwen36_thor import ( # noqa: E402 + Qwen36TorchFrontendThor, +) + + +HIDDEN = 8 +KV = 16 + + +def _stub_rtx(): + fe = Qwen36TorchFrontendRtx.__new__(Qwen36TorchFrontendRtx) + taps = torch.zeros(5, KV, HIDDEN) + for row in range(KV): + taps[:, row] = row + 1 + fe._dflash_taps_buf = taps + fe._dflash_buf = { + "pt_taps_rows": torch.zeros(KV, 5, HIDDEN), + } + return fe + + +def test_window_commit_reads_rows_before_shuffle(monkeypatch): + fe = _stub_rtx() + seen = [] + monkeypatch.setattr( + dff, "pertoken_window_append", + lambda frontend, rows: seen.append(rows)) + + N = 3 + fe._dflash_window_commit(N) + + assert len(seen) == 1 + rows = seen[0] + assert rows.shape == (N + 1, 5, HIDDEN) + # Row order: oldest committed row first, values 1..N+1 per the + # stub filling — row 0 must be the ORIGINAL row 0, not row N. + expect = torch.tensor([1.0, 2.0, 3.0, 4.0]) + assert torch.equal(rows[:, 0, 0], expect) + + # The end-of-cycle shuffle overwrites tap row 0 with row N; the + # committed rows must be a copy, not a view into the tap buffer. + fe._dflash_taps_buf[:, 0].copy_(fe._dflash_taps_buf[:, N]) + assert torch.equal(rows[:, 0, 0], expect) + + +def test_window_commit_full_accept_covers_all_rows(monkeypatch): + fe = _stub_rtx() + seen = [] + monkeypatch.setattr( + dff, "pertoken_window_append", + lambda frontend, rows: seen.append(rows)) + + fe._dflash_window_commit(KV - 1) + assert seen[0].shape[0] == KV + assert torch.equal( + seen[0][:, 0, 0], torch.arange(1.0, KV + 1)) + + +def test_generate_loop_commits_window_before_tap_shuffle(): + src = inspect.getsource( + Qwen36TorchFrontendRtx.generate_own_speculative_DFlash_nvfp4) + commit = src.index("_dflash_window_commit") + shuffle = src.index( + "_dflash_taps_buf[:, 0].copy_", commit) + assert commit < shuffle, ( + "the per-token window must be committed before the taps[:, 0] " + "shuffle overwrites row 0") + + +def test_generate_fails_fast_without_drafter(): + fe = Qwen36TorchFrontendRtx.__new__(Qwen36TorchFrontendRtx) + + class _Weights: + ptrs = {} + + fe._weights = _Weights() + with pytest.raises(RuntimeError, match="DFlash drafter not loaded"): + fe.generate_own_speculative_DFlash_nvfp4( + torch.zeros(1, 4, dtype=torch.long), max_new_tokens=4) + + +def test_public_drafter_init_delegates(monkeypatch): + fe = Qwen36TorchFrontendRtx.__new__(Qwen36TorchFrontendRtx) + calls = [] + monkeypatch.setattr( + Qwen36TorchFrontendRtx, "_load_dflash_drafter", + lambda self, ckpt_dir=None: calls.append(ckpt_dir)) + fe.init_dflash_drafter("/tmp/ckpt") + assert calls == ["/tmp/ckpt"] + + +def _thor_drafter_load(monkeypatch): + """Run Thor's _load_dflash_drafter with the base loader stubbed.""" + monkeypatch.setattr( + Qwen36TorchFrontendRtx, "_load_dflash_drafter", + lambda self, ckpt_dir=None: None) + fe = Qwen36TorchFrontendThor.__new__(Qwen36TorchFrontendThor) + fe._fp8_K_cache = torch.zeros(1) # skip FP8 cache allocation + fe._K_save_max = 16 # skip checkpoint-buffer grow + fe._MAX_PUBLIC_SPEC_K = 15 + fe._load_dflash_drafter() + return fe + + +def test_thor_pertoken_default_on(monkeypatch): + monkeypatch.delenv("FLASHRT_QWEN36_DFLASH_PERTOKEN", raising=False) + monkeypatch.delenv("FLASHRT_QWEN36_DFLASH_WINDOW", raising=False) + fe = _thor_drafter_load(monkeypatch) + assert fe._dflash_pertoken_window is True + assert fe._dflash_pertoken_win == 128 + + +def test_thor_pertoken_env_opt_out(monkeypatch): + monkeypatch.setenv("FLASHRT_QWEN36_DFLASH_PERTOKEN", "0") + fe = _thor_drafter_load(monkeypatch) + assert fe._dflash_pertoken_window is False + + +def test_thor_pertoken_window_env_override(monkeypatch): + monkeypatch.delenv("FLASHRT_QWEN36_DFLASH_PERTOKEN", raising=False) + monkeypatch.setenv("FLASHRT_QWEN36_DFLASH_WINDOW", "64") + fe = _thor_drafter_load(monkeypatch) + assert fe._dflash_pertoken_win == 64 + + +def _relaxed(logits, drafts, topk=3, delta=1.0, close_id=99): + all_argmax = logits.argmax(dim=-1) + return Qwen36TorchFrontendRtx._dflash_relaxed_matches( + logits, drafts, all_argmax, topk, delta, close_id) + + +def test_relaxed_accepts_topk_within_margin(): + # row 0: draft is argmax; row 1: draft is 2nd-best inside margin; + # row 2: draft is 2nd-best OUTSIDE margin; row 3: draft not in topk + logits = torch.tensor([ + [5.0, 1.0, 0.0, 0.0], + [5.0, 4.5, 0.0, 0.0], + [5.0, 2.0, 0.0, 0.0], + [5.0, 4.9, 4.8, 4.7], + ]) + drafts = torch.tensor([0, 1, 1, 3]) + ok = _relaxed(logits, drafts, topk=3, delta=1.0) + assert ok.tolist() == [1, 1, 0, 0] + + +def test_relaxed_strict_after_think_close(): + # row 1 closes the think block -> rows 1+ require exact argmax + logits = torch.tensor([ + [5.0, 4.5, 0.0, 0.0], + [5.0, 4.9, 0.0, 0.0], + [5.0, 4.9, 0.0, 0.0], + ]) + drafts = torch.tensor([1, 2, 1]) # draft row 1 is close_id=2 + ok = _relaxed(logits, drafts, topk=3, delta=1.0, close_id=2) + # row 0 relaxed-accepted; row 1 (close) strict: argmax=0 != 2 -> 0; + # row 2 strict: argmax=0 != 1 -> 0 + assert ok.tolist() == [1, 0, 0] + + +def test_relaxed_strict_rows_match_argmax(): + logits = torch.tensor([ + [5.0, 4.5, 0.0], + [1.0, 6.0, 0.0], + ]) + drafts = torch.tensor([2, 1]) # row 0 closes -> strict from row 0 + ok = _relaxed(logits, drafts, topk=3, delta=10.0, close_id=2) + assert ok.tolist() == [0, 1]