Skip to content

Commit 471a5f4

Browse files
unamedkr and claude authored
feat: text-prefix matching — bypass BPE re-tokenization in chat reuse (#50)
Follow-up to PR #49. The token-level LCP path in tq_generate_continue has a fundamental limitation: model-generated tokens (sample_topp) and text-encoded tokens (tq_encode of the response in the next turn) can diverge due to BPE merge non-roundtripping. This caps per-turn LCP at the prompt boundary (~10 tokens), so longer histories still incur mostly-full reprefill. Fix: tq_generate_chat_text() — text-level prefix matching. How it works: 1. Each session stores the entire prompt+response text from the previous call (cached_text). 2. On a new request, check if the new prompt starts with cached_text byte-for-byte. If yes, the cached state is byte-equivalent valid. 3. Tokenize ONLY the suffix (new_prompt[strlen(cached_text):]) and prefill those tokens at positions [n_cached..n_cached + n_suffix). 4. Run generation. The accumulated output text gets appended to cached_text via a tee callback for the next call. 5. If text prefix doesn't match, fall back to tq_generate_continue (token LCP path). Bug fix bundled: json_find_key("user") was matching the value in {"role":"user"} instead of the top-level "user" key. Result: every request used the "default" session, so multi-session was effectively broken (cross-pollution). The fix scans for "key": (with colon) to disambiguate from value matches. Measured (SmolLM2-135M, single thread, real chat replay): Single user, 10-turn accumulation: PR #48 (token LCP only): turn 10 → 3700 ms PR #49 (above + multi-session): turn 10 → 3700 ms (LCP still capped) This PR (text-prefix path): turn 10 → 739 ms (5x) alice + bob interleaved, 5 turns each (real assistant replay): PR #49: alice 5 = 2412 ms, bob 5 = 2357 ms Now: alice 5 = 498 ms, bob 5 = 462 ms (5x) The growth that remains (~50ms/turn) is the unavoidable O(n) cost of the attention computation over the full context — KV prefill is now truly O(new tokens per turn), not O(full history per turn). Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 361751d commit 471a5f4

2 files changed

Lines changed: 295 additions & 19 deletions

File tree

src/engine/tq_generate.c

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,3 +802,255 @@ int tq_generate_continue(tq_model_t* model,
802802
free(new_tokens);
803803
return generated;
804804
}
805+
806+
/* ============================================================================
 * tq_generate_chat_text — text-prefix matching for chat reuse
 *
 * Solves the BPE re-tokenization issue: when the model generates response
 * tokens via sample_topp, those token IDs may not match what tq_encode()
 * produces from the same response text in the next turn's prompt. The
 * token-level LCP in tq_generate_continue truncates at that boundary.
 *
 * This function tracks the *text* of the last prompt (which includes the
 * model's response from previous turns, accumulated by the caller). On the
 * next call, if the new prompt starts with cached_text byte-for-byte, the
 * entire cached state is valid — we tokenize only the new SUFFIX text and
 * prefill those tokens at positions [n_cached..]. No LCP, no truncation.
 *
 * After generation, *cached_text_io is updated to:
 *     prompt + (generated tokens decoded back to text)
 * so the next call can fast-path again.
 *
 * Caller owns *cached_text_io (must free with free()).
 * Pass cached_text_io == NULL to disable text-prefix tracking and behave
 * exactly like tq_generate_continue.
 * ============================================================================ */

/* Tee accumulator: collects every decoded token piece into a growable,
 * NUL-terminated buffer while forwarding each piece to the caller's
 * original on_token callback. */
typedef struct {
    char* buf;    /* accumulated text; NUL-terminated whenever buf != NULL */
    size_t len;   /* bytes accumulated, excluding the terminator */
    size_t cap;   /* allocated size of buf */
    void (*user_cb)(const char*, void*);  /* wrapped callback (may be NULL) */
    void* user_data;                      /* wrapped callback's user data */
} chat_accum_t;

/* on_token shim installed by tq_generate_chat_text.
 *
 * Appends `tok` to the accumulator, then forwards it to the wrapped user
 * callback.  Fix vs. previous version: on realloc failure the piece is not
 * accumulated, but it is STILL forwarded to the user callback — streaming
 * output must not silently drop tokens just because our bookkeeping buffer
 * could not grow.  A shorter accumulated buffer only makes the next call's
 * cached_text shorter; prefix matching against it remains byte-correct. */
static void chat_accum_callback(const char* tok, void* u) {
    chat_accum_t* ctx = (chat_accum_t*)u;
    if (!tok) return;
    size_t tlen = strlen(tok);
    if (ctx->len + tlen + 1 > ctx->cap) {
        /* Geometric growth with headroom so small pieces don't realloc
         * on every call. */
        size_t new_cap = (ctx->cap + tlen + 64) * 2;
        char* nb = (char*)realloc(ctx->buf, new_cap);
        if (!nb) {
            /* OOM: skip accumulation, but keep the stream flowing. */
            if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
            return;
        }
        ctx->buf = nb;
        ctx->cap = new_cap;
    }
    memcpy(ctx->buf + ctx->len, tok, tlen);
    ctx->len += tlen;
    ctx->buf[ctx->len] = '\0';
    if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
}
853+
854+
int tq_generate_chat_text(tq_model_t* model,
855+
tq_tokenizer_t* tokenizer,
856+
tq_state_t* state,
857+
const char* prompt,
858+
tq_gen_config_t* config,
859+
char** cached_text_io,
860+
int** cached_tokens_io,
861+
int* n_cached_io,
862+
int* cached_capacity_io,
863+
char* output, int output_size) {
864+
if (!model || !state || !config || !cached_tokens_io || !n_cached_io || !cached_capacity_io || !prompt) {
865+
return -1;
866+
}
867+
868+
/* --- 1. Check for text-level prefix match --- */
869+
int matched_text_len = 0;
870+
int prefix_pos = 0; /* tokens already in KV cache that we trust */
871+
872+
if (cached_text_io && *cached_text_io && *n_cached_io > 0) {
873+
size_t cached_len = strlen(*cached_text_io);
874+
if (cached_len > 0 && strncmp(*cached_text_io, prompt, cached_len) == 0) {
875+
matched_text_len = (int)cached_len;
876+
prefix_pos = *n_cached_io;
877+
} else if (getenv("TQ_CHAT_DEBUG")) {
878+
/* Find where they diverge to help diagnose */
879+
size_t diverge = 0;
880+
size_t plen = strlen(prompt);
881+
size_t lim = cached_len < plen ? cached_len : plen;
882+
while (diverge < lim && (*cached_text_io)[diverge] == prompt[diverge]) diverge++;
883+
fprintf(stderr,
884+
"[chat-text] no match: cached_len=%zu prompt_len=%zu diverge_at=%zu\n"
885+
" cached[%zu..]: %.40s\n"
886+
" prompt[%zu..]: %.40s\n",
887+
cached_len, plen, diverge,
888+
diverge, *cached_text_io + diverge,
889+
diverge, prompt + diverge);
890+
}
891+
}
892+
893+
/* Wrap user callback to capture generated text into a buffer for the
894+
* next call's cached_text update. */
895+
chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0,
896+
.user_cb = config->on_token,
897+
.user_data = config->user_data };
898+
void (*orig_cb)(const char*, void*) = config->on_token;
899+
void* orig_ud = config->user_data;
900+
config->on_token = chat_accum_callback;
901+
config->user_data = &accum;
902+
903+
int generated = 0;
904+
905+
if (matched_text_len > 0) {
906+
/* --- Fast path: text prefix matches --- */
907+
const char* suffix = prompt + matched_text_len;
908+
int max_prompt = model->config.max_seq_len > 0
909+
? model->config.max_seq_len : 4096;
910+
int* suffix_toks = (int*)malloc((size_t)max_prompt * sizeof(int));
911+
if (!suffix_toks) {
912+
config->on_token = orig_cb; config->user_data = orig_ud;
913+
return -1;
914+
}
915+
int n_suffix = 0;
916+
if (*suffix != '\0') {
917+
n_suffix = tq_encode(tokenizer, suffix, suffix_toks, max_prompt, 0);
918+
if (n_suffix < 0) n_suffix = 0;
919+
}
920+
921+
/* Sliding window if needed (drop from start of cached) */
922+
int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
923+
if (prefix_pos + n_suffix + reserve + 32 > max_prompt) {
924+
/* Force a full reprefill — simpler than partial cache shift */
925+
free(suffix_toks);
926+
config->on_token = orig_cb; config->user_data = orig_ud;
927+
*n_cached_io = 0;
928+
if (cached_text_io && *cached_text_io) {
929+
free(*cached_text_io); *cached_text_io = NULL;
930+
}
931+
int n2 = tq_generate_continue(model, tokenizer, state, prompt, config,
932+
cached_tokens_io, n_cached_io, cached_capacity_io,
933+
output, output_size);
934+
/* fall-through path captures cached_text below */
935+
generated = n2;
936+
goto update_cache;
937+
}
938+
939+
/* Grow cache buffer */
940+
int needed = prefix_pos + n_suffix + reserve + 16;
941+
if (*cached_capacity_io < needed) {
942+
int new_cap = needed < 4096 ? 4096 : needed;
943+
int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int));
944+
if (!nb) { free(suffix_toks); config->on_token = orig_cb; config->user_data = orig_ud; return -1; }
945+
*cached_tokens_io = nb;
946+
*cached_capacity_io = new_cap;
947+
}
948+
949+
/* Append suffix tokens to cache + prefill at correct positions */
950+
int* cached = *cached_tokens_io;
951+
for (int i = 0; i < n_suffix; i++) {
952+
cached[prefix_pos + i] = suffix_toks[i];
953+
tq_forward(model, state, suffix_toks[i], prefix_pos + i);
954+
}
955+
*n_cached_io = prefix_pos + n_suffix;
956+
free(suffix_toks);
957+
958+
if (getenv("TQ_CHAT_DEBUG")) {
959+
fprintf(stderr, "[chat-text] FAST text_match=%d new_suffix_tokens=%d\n",
960+
matched_text_len, n_suffix);
961+
}
962+
963+
/* --- Run generation loop directly --- */
964+
int vocab_size = model->config.vocab_size;
965+
int n_cached = *n_cached_io;
966+
int pos = n_cached;
967+
int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1;
968+
969+
unsigned long long rng_state = config->rng_seed
970+
? (unsigned long long)config->rng_seed : (unsigned long long)time(NULL);
971+
int next_token = tq_sample_topp(state->logits, vocab_size,
972+
config->temperature, config->top_p,
973+
&rng_state);
974+
975+
int output_pos = 0;
976+
int eos_tokens[] = { 1, 2, 106, 128001, 128006, 128007, 128008, 128009, 248044, 248046 };
977+
int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]);
978+
979+
while (generated < config->max_tokens) {
980+
int is_eos = 0;
981+
for (int e = 0; e < n_eos; e++) {
982+
if (next_token == eos_tokens[e]) { is_eos = 1; break; }
983+
}
984+
if (is_eos) break;
985+
if (pos >= model->config.max_seq_len) break;
986+
987+
const char* piece = tokenizer ? tq_decode(tokenizer, prev_token, next_token) : "";
988+
int should_stop = 0;
989+
if (piece) {
990+
if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") ||
991+
strstr(piece, "<|start_header_id|>")) {
992+
should_stop = 1; piece = "";
993+
}
994+
}
995+
if (should_stop) break;
996+
997+
int piece_len = (int)strlen(piece ? piece : "");
998+
if (config->on_token && piece) config->on_token(piece, config->user_data);
999+
if (output && piece && output_pos + piece_len < output_size - 1) {
1000+
memcpy(output + output_pos, piece, piece_len);
1001+
output_pos += piece_len;
1002+
}
1003+
1004+
if (n_cached < *cached_capacity_io) {
1005+
cached[n_cached++] = next_token;
1006+
*n_cached_io = n_cached;
1007+
}
1008+
1009+
prev_token = next_token;
1010+
tq_forward(model, state, next_token, pos);
1011+
pos++;
1012+
generated++;
1013+
1014+
next_token = tq_sample_topp(state->logits, vocab_size,
1015+
config->temperature, config->top_p,
1016+
&rng_state);
1017+
}
1018+
1019+
if (output && output_size > 0) {
1020+
output[output_pos < output_size ? output_pos : output_size - 1] = '\0';
1021+
}
1022+
} else {
1023+
/* --- Slow path: no text-prefix match, use token LCP fallback --- */
1024+
if (getenv("TQ_CHAT_DEBUG")) {
1025+
fprintf(stderr, "[chat-text] SLOW no text-prefix match, full tokenize\n");
1026+
}
1027+
generated = tq_generate_continue(
1028+
model, tokenizer, state, prompt, config,
1029+
cached_tokens_io, n_cached_io, cached_capacity_io,
1030+
output, output_size);
1031+
}
1032+
1033+
update_cache:
1034+
/* Restore the original callback before returning to caller */
1035+
config->on_token = orig_cb;
1036+
config->user_data = orig_ud;
1037+
1038+
/* Update cached_text = prompt + generated text. The next call can
1039+
* fast-path against this if its prompt starts with this string. */
1040+
if (cached_text_io) {
1041+
size_t plen = strlen(prompt);
1042+
size_t glen = accum.len;
1043+
size_t new_len = plen + glen;
1044+
char* nt = (char*)malloc(new_len + 1);
1045+
if (nt) {
1046+
memcpy(nt, prompt, plen);
1047+
if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen);
1048+
nt[new_len] = '\0';
1049+
if (*cached_text_io) free(*cached_text_io);
1050+
*cached_text_io = nt;
1051+
}
1052+
}
1053+
if (accum.buf) free(accum.buf);
1054+
1055+
return generated;
1056+
}

src/server/tq_server.c

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
#include <stdarg.h>
2020
#include <stdbool.h>
2121

22-
/* Forward decl: defined in src/engine/tq_generate.c.
23-
* Not yet exposed in turboquant.h since it's a chat-mode helper. */
22+
/* Forward decls: defined in src/engine/tq_generate.c.
23+
* Not yet exposed in turboquant.h since they're chat-mode helpers. */
2424
extern int tq_generate_continue(tq_model_t* model,
2525
tq_tokenizer_t* tokenizer,
2626
tq_state_t* state,
@@ -30,6 +30,18 @@ extern int tq_generate_continue(tq_model_t* model,
3030
int* n_cached_io,
3131
int* cached_capacity_io,
3232
char* output, int output_size);
33+
34+
/* Text-prefix matching variant — solves BPE re-tokenization mismatch. */
35+
extern int tq_generate_chat_text(tq_model_t* model,
36+
tq_tokenizer_t* tokenizer,
37+
tq_state_t* state,
38+
const char* prompt,
39+
tq_gen_config_t* config,
40+
char** cached_text_io,
41+
int** cached_tokens_io,
42+
int* n_cached_io,
43+
int* cached_capacity_io,
44+
char* output, int output_size);
3345
#if defined(_MSC_VER)
3446
#include <intrin.h>
3547
typedef volatile long atomic_int;
@@ -95,6 +107,7 @@ typedef struct {
95107
int* cached_tokens;
96108
int n_cached;
97109
int cached_capacity;
110+
char* cached_text; /* prompt + generated, for text-prefix matching */
98111
long last_used; /* monotonic counter for LRU */
99112
} kv_session_t;
100113

@@ -144,6 +157,7 @@ static kv_session_t* get_or_create_session(tq_server_t* server,
144157
/* Free old session contents (if any) */
145158
if (s->kv_state) tq_free_state(s->kv_state);
146159
if (s->cached_tokens) free(s->cached_tokens);
160+
if (s->cached_text) free(s->cached_text);
147161

148162
memset(s, 0, sizeof(*s));
149163
strncpy(s->id, sid, SESSION_ID_MAX - 1);
@@ -237,15 +251,22 @@ static const char* json_extract_string(const char* p, char* buf, int buf_size) {
237251
/* Find a key in JSON and return pointer to its value (past the colon),
 * or NULL if absent. Simple scan — works for flat or lightly nested objects.
 *
 * Only a "key": pattern counts: every "key" occurrence is located and the
 * next non-whitespace char must be ':'. This skips false matches where
 * "key" appears as a *value* (e.g., {"role":"user"} collides with
 * json_find_key("user") if the colon isn't checked).
 *
 * Fix: the snprintf result is now checked. A key longer than the pattern
 * buffer used to be silently truncated, leaving a pattern without its
 * closing quote that could match the wrong span; such keys now return
 * NULL instead. */
static const char* json_find_key(const char* json, const char* key) {
    char pattern[256];
    int written = snprintf(pattern, sizeof(pattern), "\"%s\"", key);
    if (written < 0 || (size_t)written >= sizeof(pattern)) {
        return NULL; /* encoding error or key too long to match safely */
    }
    size_t plen = (size_t)written; /* reuse length; no extra strlen */
    const char* p = json;
    while ((p = strstr(p, pattern)) != NULL) {
        const char* after = json_skip_ws(p + plen);
        if (*after == ':') {
            return json_skip_ws(after + 1);
        }
        p += plen; /* skip past this false match and keep searching */
    }
    return NULL;
}
250271

251272
/* Extract a number (int or float) from current position */
@@ -758,11 +779,12 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
758779
kv_session_t* sess = get_or_create_session(server, req.session_id,
759780
gen_cfg.kv_type,
760781
gen_cfg.value_quant_bits);
761-
tq_generate_continue(server->config.model, server->config.tokenizer,
762-
sess->kv_state, req.prompt, &gen_cfg,
763-
&sess->cached_tokens, &sess->n_cached,
764-
&sess->cached_capacity,
765-
output, sizeof(output));
782+
tq_generate_chat_text(server->config.model, server->config.tokenizer,
783+
sess->kv_state, req.prompt, &gen_cfg,
784+
&sess->cached_text,
785+
&sess->cached_tokens, &sess->n_cached,
786+
&sess->cached_capacity,
787+
output, sizeof(output));
766788

767789
/* Send final chunk with finish_reason */
768790
char final_chunk[SSE_CHUNK_SIZE];
@@ -795,11 +817,12 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod
795817
kv_session_t* sess = get_or_create_session(server, req.session_id,
796818
gen_cfg.kv_type,
797819
gen_cfg.value_quant_bits);
798-
tq_generate_continue(server->config.model, server->config.tokenizer,
799-
sess->kv_state, req.prompt, &gen_cfg,
800-
&sess->cached_tokens, &sess->n_cached,
801-
&sess->cached_capacity,
802-
output, sizeof(output));
820+
tq_generate_chat_text(server->config.model, server->config.tokenizer,
821+
sess->kv_state, req.prompt, &gen_cfg,
822+
&sess->cached_text,
823+
&sess->cached_tokens, &sess->n_cached,
824+
&sess->cached_capacity,
825+
output, sizeof(output));
803826

804827
const char* content = collect.buf ? collect.buf : "";
805828

@@ -1260,6 +1283,7 @@ void tq_server_free(tq_server_t* server) {
12601283
for (int i = 0; i < MAX_SESSIONS; i++) {
12611284
if (server->sessions[i].kv_state) tq_free_state(server->sessions[i].kv_state);
12621285
if (server->sessions[i].cached_tokens) free(server->sessions[i].cached_tokens);
1286+
if (server->sessions[i].cached_text) free(server->sessions[i].cached_text);
12631287
}
12641288
if (g_server == server) g_server = NULL;
12651289
free(server);

0 commit comments

Comments
 (0)