Skip to content

Commit 49c6605

Browse files
unamedkrclaude
andauthored
feat(wasm): chat KV cache reuse — turn N+1 is near-instant in browser (#51)
PR #50 added text-prefix matching to src/engine/tq_generate.c (used by the HTTP server). This PR ports it to quant.h (single-header) so the WASM browser demo and Python wheel get the same speedup. Three layers: 1. **quant.h**: ported tq_generate_chat_text from src/engine. Added cached_text field to quant_ctx struct. quant_chat() now uses the text-prefix path instead of the token-LCP path. quant_free_ctx() frees cached_text. Pass NULL prompt to reset session (frees cached_text too). 2. **wasm/quant_wasm.c**: - wasm_generate_async / wasm_generate now call quant_chat() instead of quant_generate() (which destroyed the cache via free+recreate of g_ctx every call — biggest reason WASM was slow on multi-turn). - Reuse the existing g_ctx across calls; only update temperature/ top_p/max_tokens fields (kv_compress is immutable post-creation). - New wasm_reset_chat() for starting a new chat session. 3. **wasm/index.html**: - Accumulates ChatML history client-side (chatHistory string). Each turn appends `<|im_start|>user\n${text}<|im_end|>\n <|im_start|>assistant\n` and sends the FULL history to WASM. - The C side's text-prefix matching reuses everything before the new turn — turn N's prefill is O(new user message), not O(full history). - After response, appends model output + <|im_end|>\n so the next turn matches the cached_text byte-for-byte. - Loading message differentiates first turn ("Processing prompt — may take a few seconds") vs subsequent ("Generating..."). 4. **wasm/build.sh**: exports _wasm_reset_chat. Validated end-to-end with the C test (real response replay): turn 1: 206 ms (cold, SLOW path) turn 2: 315 ms (FAST text_match=64) turn 5: 437 ms (FAST text_match=321) turn 10: 637 ms (FAST text_match=750) Every turn after the first hits the FAST text-prefix path. The remaining ~50ms/turn growth is the unavoidable O(n) attention cost. For the WASM browser demo, this means: instead of every turn taking full prefill time (5-10s for a 0.8B model), only turn 1 is slow. Turns 2+ feel instantaneous to the user. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 471a5f4 commit 49c6605

6 files changed

Lines changed: 280 additions & 25 deletions

File tree

quant.h

Lines changed: 229 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1745,6 +1745,10 @@ struct quant_ctx {
17451745
int* cached_tokens;
17461746
int n_cached;
17471747
int cached_capacity;
1748+
/* Text-prefix cache: stores the entire prompt + generated response
1749+
* text from the last call, allowing the next call to bypass BPE
1750+
* re-tokenization issues by matching at the byte level. */
1751+
char* cached_text;
17481752
};
17491753

17501754
// ============================================================================
@@ -15848,6 +15852,225 @@ int tq_generate_continue(tq_model_t* model,
1584815852
return generated;
1584915853
}
1585015854

15855+
/* ============================================================================
15856+
* tq_generate_chat_text — text-prefix matching for chat reuse
15857+
*
15858+
* Solves the BPE re-tokenization issue: when the model generates response
15859+
* tokens via sample_topp, those token IDs may not match what tq_encode()
15860+
* produces from the same response text in the next turn's prompt. The
15861+
* token-level LCP in tq_generate_continue truncates at that boundary.
15862+
*
15863+
* This function tracks the *text* of the last prompt+response. On the next
15864+
* call, if the new prompt starts with cached_text byte-for-byte, the entire
15865+
* cached state is valid — tokenize ONLY the new SUFFIX text and prefill
15866+
* those tokens at positions [n_cached..]. No LCP, no truncation.
15867+
*
15868+
* Pass cached_text_io == NULL to disable text-prefix tracking.
15869+
* ============================================================================ */
15870+
15871+
typedef struct {
15872+
char* buf;
15873+
size_t len;
15874+
size_t cap;
15875+
void (*user_cb)(const char*, void*);
15876+
void* user_data;
15877+
} chat_accum_t;
15878+
15879+
static void chat_accum_callback(const char* tok, void* u) {
15880+
chat_accum_t* ctx = (chat_accum_t*)u;
15881+
if (!tok) return;
15882+
size_t tlen = strlen(tok);
15883+
if (ctx->len + tlen + 1 > ctx->cap) {
15884+
size_t new_cap = (ctx->cap + tlen + 64) * 2;
15885+
char* nb = (char*)realloc(ctx->buf, new_cap);
15886+
if (!nb) return;
15887+
ctx->buf = nb;
15888+
ctx->cap = new_cap;
15889+
}
15890+
memcpy(ctx->buf + ctx->len, tok, tlen);
15891+
ctx->len += tlen;
15892+
ctx->buf[ctx->len] = '\0';
15893+
if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
15894+
}
15895+
15896+
int tq_generate_chat_text(tq_model_t* model,
15897+
tq_tokenizer_t* tokenizer,
15898+
tq_state_t* state,
15899+
const char* prompt,
15900+
tq_gen_config_t* config,
15901+
char** cached_text_io,
15902+
int** cached_tokens_io,
15903+
int* n_cached_io,
15904+
int* cached_capacity_io,
15905+
char* output, int output_size) {
15906+
if (!model || !state || !config || !cached_tokens_io || !n_cached_io || !cached_capacity_io || !prompt) {
15907+
return -1;
15908+
}
15909+
15910+
int matched_text_len = 0;
15911+
int prefix_pos = 0;
15912+
15913+
if (cached_text_io && *cached_text_io && *n_cached_io > 0) {
15914+
size_t cached_len = strlen(*cached_text_io);
15915+
if (cached_len > 0 && strncmp(*cached_text_io, prompt, cached_len) == 0) {
15916+
matched_text_len = (int)cached_len;
15917+
prefix_pos = *n_cached_io;
15918+
}
15919+
}
15920+
15921+
chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0,
15922+
.user_cb = config->on_token,
15923+
.user_data = config->user_data };
15924+
void (*orig_cb)(const char*, void*) = config->on_token;
15925+
void* orig_ud = config->user_data;
15926+
config->on_token = chat_accum_callback;
15927+
config->user_data = &accum;
15928+
15929+
int generated = 0;
15930+
15931+
if (matched_text_len > 0) {
15932+
const char* suffix = prompt + matched_text_len;
15933+
int max_prompt = model->config.max_seq_len > 0
15934+
? model->config.max_seq_len : 4096;
15935+
int* suffix_toks = (int*)malloc((size_t)max_prompt * sizeof(int));
15936+
if (!suffix_toks) {
15937+
config->on_token = orig_cb; config->user_data = orig_ud;
15938+
return -1;
15939+
}
15940+
int n_suffix = 0;
15941+
if (*suffix != '\0') {
15942+
n_suffix = tq_encode(tokenizer, suffix, suffix_toks, max_prompt, 0);
15943+
if (n_suffix < 0) n_suffix = 0;
15944+
}
15945+
15946+
int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
15947+
if (prefix_pos + n_suffix + reserve + 32 > max_prompt) {
15948+
free(suffix_toks);
15949+
config->on_token = orig_cb; config->user_data = orig_ud;
15950+
*n_cached_io = 0;
15951+
if (cached_text_io && *cached_text_io) {
15952+
free(*cached_text_io); *cached_text_io = NULL;
15953+
}
15954+
int n2 = tq_generate_continue(model, tokenizer, state, prompt, config,
15955+
cached_tokens_io, n_cached_io, cached_capacity_io,
15956+
output, output_size);
15957+
generated = n2;
15958+
goto update_cache;
15959+
}
15960+
15961+
int needed = prefix_pos + n_suffix + reserve + 16;
15962+
if (*cached_capacity_io < needed) {
15963+
int new_cap = needed < 4096 ? 4096 : needed;
15964+
int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
15965+
if (!nb) { free(suffix_toks); config->on_token = orig_cb; config->user_data = orig_ud; return -1; }
15966+
*cached_tokens_io = nb;
15967+
*cached_capacity_io = new_cap;
15968+
}
15969+
15970+
int* cached = *cached_tokens_io;
15971+
for (int i = 0; i < n_suffix; i++) {
15972+
cached[prefix_pos + i] = suffix_toks[i];
15973+
tq_forward(model, state, suffix_toks[i], prefix_pos + i);
15974+
}
15975+
*n_cached_io = prefix_pos + n_suffix;
15976+
free(suffix_toks);
15977+
15978+
if (getenv("TQ_CHAT_DEBUG")) {
15979+
fprintf(stderr, "[chat-text] FAST text_match=%d new_suffix_tokens=%d\n",
15980+
matched_text_len, n_suffix);
15981+
}
15982+
15983+
/* Generation loop */
15984+
int vocab_size = model->config.vocab_size;
15985+
int n_cached = *n_cached_io;
15986+
int pos = n_cached;
15987+
int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1;
15988+
15989+
uint64_t rng_state = config->rng_seed
15990+
? (uint64_t)config->rng_seed : (uint64_t)time(NULL);
15991+
int next_token = tq_sample_topp(state->logits, vocab_size,
15992+
config->temperature, config->top_p,
15993+
&rng_state);
15994+
15995+
int output_pos = 0;
15996+
int eos_tokens[] = { 1, 2, 106, 128001, 128006, 128007, 128008, 128009, 248044, 248046 };
15997+
int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]);
15998+
15999+
while (generated < config->max_tokens) {
16000+
int is_eos = 0;
16001+
for (int e = 0; e < n_eos; e++) {
16002+
if (next_token == eos_tokens[e]) { is_eos = 1; break; }
16003+
}
16004+
if (is_eos) break;
16005+
if (pos >= model->config.max_seq_len) break;
16006+
16007+
const char* piece = tokenizer ? tq_decode(tokenizer, prev_token, next_token) : "";
16008+
int should_stop = 0;
16009+
if (piece) {
16010+
if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") ||
16011+
strstr(piece, "<|start_header_id|>")) {
16012+
should_stop = 1; piece = "";
16013+
}
16014+
}
16015+
if (should_stop) break;
16016+
16017+
int piece_len = (int)strlen(piece ? piece : "");
16018+
if (config->on_token && piece) config->on_token(piece, config->user_data);
16019+
if (output && piece && output_pos + piece_len < output_size - 1) {
16020+
memcpy(output + output_pos, piece, piece_len);
16021+
output_pos += piece_len;
16022+
}
16023+
16024+
if (n_cached < *cached_capacity_io) {
16025+
cached[n_cached++] = next_token;
16026+
*n_cached_io = n_cached;
16027+
}
16028+
16029+
prev_token = next_token;
16030+
tq_forward(model, state, next_token, pos);
16031+
pos++;
16032+
generated++;
16033+
16034+
next_token = tq_sample_topp(state->logits, vocab_size,
16035+
config->temperature, config->top_p,
16036+
&rng_state);
16037+
}
16038+
16039+
if (output && output_size > 0) {
16040+
output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
16041+
}
16042+
} else {
16043+
if (getenv("TQ_CHAT_DEBUG")) {
16044+
fprintf(stderr, "[chat-text] SLOW no text-prefix match, full tokenize\n");
16045+
}
16046+
generated = tq_generate_continue(
16047+
model, tokenizer, state, prompt, config,
16048+
cached_tokens_io, n_cached_io, cached_capacity_io,
16049+
output, output_size);
16050+
}
16051+
16052+
update_cache:
16053+
config->on_token = orig_cb;
16054+
config->user_data = orig_ud;
16055+
16056+
if (cached_text_io) {
16057+
size_t plen = strlen(prompt);
16058+
size_t glen = accum.len;
16059+
size_t new_len = plen + glen;
16060+
char* nt = (char*)malloc(new_len + 1);
16061+
if (nt) {
16062+
memcpy(nt, prompt, plen);
16063+
if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen);
16064+
nt[new_len] = '\0';
16065+
if (*cached_text_io) free(*cached_text_io);
16066+
*cached_text_io = nt;
16067+
}
16068+
}
16069+
if (accum.buf) free(accum.buf);
16070+
16071+
return generated;
16072+
}
16073+
1585116074
// ============================================================================
1585216075

1585316076
// ============================================================================
@@ -16182,6 +16405,7 @@ void quant_free_ctx(quant_ctx* ctx) {
1618216405
tq_free_state(ctx->state);
1618316406
tq_free_tokenizer(ctx->tokenizer);
1618416407
if (ctx->cached_tokens) free(ctx->cached_tokens);
16408+
if (ctx->cached_text) free(ctx->cached_text);
1618516409
free(ctx);
1618616410
}
1618716411

@@ -16217,6 +16441,7 @@ int quant_chat(quant_ctx* ctx, const char* prompt,
1621716441
ctx->n_cached = 0;
1621816442
ctx->cached_capacity = 0;
1621916443
ctx->n_ctx_tokens = 0;
16444+
if (ctx->cached_text) { free(ctx->cached_text); ctx->cached_text = NULL; }
1622016445
return 0;
1622116446
}
1622216447

@@ -16231,8 +16456,11 @@ int quant_chat(quant_ctx* ctx, const char* prompt,
1623116456
ctx->config.user_data = user_data;
1623216457

1623316458
char output[65536];
16234-
int n = tq_generate_continue(
16459+
/* Use the text-prefix path so chat replays bypass BPE re-tokenization
16460+
* issues. Falls back to token-LCP path if text prefix doesn't match. */
16461+
int n = tq_generate_chat_text(
1623516462
ctx->model, ctx->tokenizer, ctx->state, prompt, &ctx->config,
16463+
&ctx->cached_text,
1623616464
&ctx->cached_tokens, &ctx->n_cached, &ctx->cached_capacity,
1623716465
output, sizeof(output));
1623816466

wasm/build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ emcc "$SCRIPT_DIR/quant_wasm.c" \
2626
-s ALLOW_MEMORY_GROWTH=1 \
2727
-s MAXIMUM_MEMORY=4GB \
2828
-s INITIAL_MEMORY=256MB \
29-
-s EXPORTED_FUNCTIONS='["_main","_wasm_load_model","_wasm_generate","_wasm_generate_async","_wasm_model_info","_wasm_is_ready","_malloc","_free"]' \
29+
-s EXPORTED_FUNCTIONS='["_main","_wasm_load_model","_wasm_generate","_wasm_generate_async","_wasm_reset_chat","_wasm_model_info","_wasm_is_ready","_malloc","_free"]' \
3030
-s EXPORTED_RUNTIME_METHODS='["UTF8ToString","allocateUTF8","FS","ccall","cwrap"]' \
3131
-s FORCE_FILESYSTEM=1 \
3232
-s MODULARIZE=0 \

wasm/index.html

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,18 @@ <h2>Run an <span>LLM</span> in your browser</h2>
389389
return `<|im_start|>user\n${text}<|im_end|>\n<|im_start|>assistant\n`;
390390
}
391391

392+
/* Multi-turn chat history. Sent on every turn so the model has context.
393+
* The C side's quant_chat() does text-prefix matching: turn N's prefill
394+
* is O(new tokens since last call), not O(full history). */
395+
let chatHistory = '';
396+
397+
function resetChatSession() {
398+
chatHistory = '';
399+
if (typeof Module !== 'undefined' && Module._wasm_reset_chat) {
400+
Module._wasm_reset_chat();
401+
}
402+
}
403+
392404
function stopGeneration() { stopRequested = true; }
393405

394406
async function generate() {
@@ -405,11 +417,14 @@ <h2>Run an <span>LLM</span> in your browser</h2>
405417

406418
addMessage('user', text);
407419
const aDiv = addMessage('assistant', '');
408-
aDiv.innerHTML = '<span class="thinking"><span class="spinner"></span> Processing prompt (may take a few seconds)...</span>';
420+
const isFirstTurn = chatHistory.length === 0;
421+
aDiv.innerHTML = isFirstTurn
422+
? '<span class="thinking"><span class="spinner"></span> Processing prompt (first turn — may take a few seconds)...</span>'
423+
: '<span class="thinking"><span class="spinner"></span> Generating...</span>';
409424
let output = '', count = 0;
410425
const t0 = performance.now();
411426
document.getElementById('statTokens').textContent = '';
412-
document.getElementById('statSpeed').textContent = 'processing prompt...';
427+
document.getElementById('statSpeed').textContent = isFirstTurn ? 'processing prompt...' : 'generating...';
413428

414429
Module.onToken = (tok) => {
415430
output += tok; count++;
@@ -433,11 +448,13 @@ <h2>Run an <span>LLM</span> in your browser</h2>
433448

434449
await new Promise(r => requestAnimationFrame(() => requestAnimationFrame(r)));
435450

436-
const prompt = getChatPrompt(text);
451+
/* Build the full ChatML prompt by appending this turn to the history.
452+
* The C side's quant_chat() does text-prefix matching, so the previous
453+
* turns are reused from the KV cache — only the new user message gets
454+
* prefilled. Turn N's latency: O(new user message), not O(full history). */
455+
chatHistory += `<|im_start|>user\n${text}<|im_end|>\n<|im_start|>assistant\n`;
456+
const prompt = chatHistory;
437457

438-
// Use ccall with async:true — this is the correct way to call
439-
// ASYNCIFY-enabled C functions. Module._fn() direct calls do NOT
440-
// return Promises; only ccall({async:true}) does.
441458
try {
442459
await Module.ccall(
443460
'wasm_generate_async',
@@ -450,6 +467,12 @@ <h2>Run an <span>LLM</span> in your browser</h2>
450467
console.error('generate error:', e);
451468
}
452469

470+
/* Append the model's response to history so the next turn matches
471+
* the cached_text prefix exactly (byte-for-byte). */
472+
if (output) {
473+
chatHistory += `${output}<|im_end|>\n`;
474+
}
475+
453476
if (!output && !count) {
454477
aDiv.innerHTML = '<em style="color:#555">No output. Try a different prompt.</em>';
455478
}

wasm/quant.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

wasm/quant.wasm

426 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)