From dbc3d693bd150e155bfdf923e50d8c579ef76026 Mon Sep 17 00:00:00 2001 From: Daniele Briggi <=> Date: Fri, 14 Nov 2025 17:50:44 +0100 Subject: [PATCH 1/3] feat(chat): add system prompt handling in chat messages --- src/sqlite-ai.c | 54 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/src/sqlite-ai.c b/src/sqlite-ai.c index 8cf0e90..e84a8ba 100644 --- a/src/sqlite-ai.c +++ b/src/sqlite-ai.c @@ -152,6 +152,7 @@ typedef struct { const char *template; const struct llama_vocab*vocab; + char *system_prompt; ai_messages messages; buffer_t formatted; buffer_t response; @@ -199,6 +200,7 @@ typedef enum { AI_MODEL_CHAT_TEMPLATE } ai_model_setting; +const char *ROLE_SYSTEM = "system"; const char *ROLE_USER = "user"; const char *ROLE_ASSISTANT = "assistant"; @@ -794,7 +796,7 @@ bool llm_messages_append (ai_messages *list, const char *role, const char *conte list->capacity = new_cap; } - bool duplicate_role = ((role != ROLE_USER) && (role != ROLE_ASSISTANT)); + bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT)); list->items[list->count].role = (duplicate_role) ? sqlite_strdup(role) : role; list->items[list->count].content = sqlite_strdup(content); list->count += 1; @@ -1489,7 +1491,7 @@ static bool llm_chat_check_context (ai_context *ai) { llama_sampler_chain_add(ai->sampler, llama_sampler_init_temp(0.8)); llama_sampler_chain_add(ai->sampler, llama_sampler_init_dist((uint32_t)LLAMA_DEFAULT_SEED)); } - + // initialize the chat struct if already created if (ai->chat.uuid[0] != '\0') return true; @@ -1647,10 +1649,30 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt) sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "Failed to append message"); return false; } + + // add system prompt if available, models expect it to be the first message + llama_chat_message *new_items = messages->items; + if (ai->chat.system_prompt) { + size_t n = messages->count + (ai->chat.system_prompt ? 1 : 0); + llama_chat_message *new_items = sqlite3_realloc64(messages->items, n * sizeof(llama_chat_message)); + if (!new_items) + return false; + + messages->items = new_items; + messages->capacity = n; + + int idx = 0; + new_items[0].role = ROLE_SYSTEM; + new_items[0].content = ai->chat.system_prompt; + idx = 1; + for (size_t i = 0; i < messages->count; ++i) { + new_items[idx++] = messages->items[i]; + } + } // transform a list of messages (the context) into // <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>... 
- int32_t new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity); + int32_t new_len = llama_chat_apply_template(template, new_items, messages->count, true, formatted->data, formatted->capacity); if (new_len > formatted->capacity) { if (buffer_resize(formatted, new_len * 2) == false) return false; new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity); } @@ -2015,6 +2037,27 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value llm_chat_run(ai, NULL, user_prompt); } +static void llm_chat_system_prompt(sqlite3_context *context, int argc, sqlite3_value **argv) +{ + if (llm_check_context(context) == false) + return; + + int types[] = {SQLITE_TEXT}; + if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false) + return; + + const char *system_prompt = (const char *)sqlite3_value_text(argv[0]); + ai_context *ai = (ai_context *)sqlite3_user_data(context); + + if (llm_chat_check_context(ai) == false) + return; + + if (ai->chat.system_prompt) { + sqlite3_free(ai->chat.system_prompt); + } + ai->chat.system_prompt = sqlite3_mprintf("%s", system_prompt); +} + // MARK: - LLM Sampler - static void llm_sampler_init_greedy (sqlite3_context *context, int argc, sqlite3_value **argv) { @@ -2852,7 +2895,10 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_a if (rc != SQLITE_OK) goto cleanup; rc = sqlite3_create_function(db, "llm_chat_respond", 1, SQLITE_UTF8, ctx, llm_chat_respond, NULL, NULL); - if (rc != SQLITE_OK) goto cleanup; + + rc = sqlite3_create_function(db, "llm_chat_system_prompt", 1, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL); + if (rc != SQLITE_OK) + goto cleanup; rc = sqlite3_create_module(db, "llm_chat", &llm_chat, ctx); if (rc != SQLITE_OK) goto cleanup; From 2a12ca3a2c4313358e8013580a76f72e057dd846 Mon Sep 17 00:00:00 2001 From: Daniele Briggi <=> Date: Mon, 17 Nov 2025 16:46:58 +0100 Subject: [PATCH 2/3] refactor(system-prompt): reserve the first slot of the messages list --- src/sqlite-ai.c | 110 +++++++++++++------ tests/c/unittest.c | 261 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 336 insertions(+), 35 deletions(-) diff --git a/src/sqlite-ai.c b/src/sqlite-ai.c index e84a8ba..ac84a63 100644 --- a/src/sqlite-ai.c +++ b/src/sqlite-ai.c @@ -152,7 +152,6 @@ typedef struct { const char *template; const struct llama_vocab*vocab; - char *system_prompt; ai_messages messages; buffer_t formatted; buffer_t response; @@ -796,6 +795,11 @@ bool llm_messages_append (ai_messages *list, const char *role, const char *conte list->capacity = new_cap; } + if (list->count != 0 && role == ROLE_SYSTEM) { + // only one system message allowed at the beginning + return false; + } + bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT)); list->items[list->count].role = (duplicate_role) ?
sqlite_strdup(role) : role; list->items[list->count].content = sqlite_strdup(content); @@ -803,11 +807,28 @@ bool llm_messages_append (ai_messages *list, const char *role, const char *conte return true; } +bool llm_messages_set (ai_messages *list, int pos, const char *role, const char *content) { + if (pos < 0 || pos >= list->count) + return false; + + bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT)); + llama_chat_message *message = &list->items[pos]; + + const char *message_role = message->role; + if ((message_role != ROLE_SYSTEM) && (message_role != ROLE_USER) && (message_role != ROLE_ASSISTANT)) + sqlite3_free(message_role); + sqlite3_free(message->content); + + message->role = (duplicate_role) ? sqlite_strdup(role) : role; + message->content = sqlite_strdup(content); + return true; +} + void llm_messages_free (ai_messages *list) { for (size_t i = 0; i < list->count; ++i) { // check if rule is static const char *role = list->items[i].role; - bool role_tofree = ((role != ROLE_USER) && (role != ROLE_ASSISTANT)); + bool role_tofree = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT)); if (role_tofree) sqlite3_free((char *)list->items[i].role); // content is always to free sqlite3_free((char *)list->items[i].content); @@ -1491,7 +1512,7 @@ static bool llm_chat_check_context (ai_context *ai) { llama_sampler_chain_add(ai->sampler, llama_sampler_init_temp(0.8)); llama_sampler_chain_add(ai->sampler, llama_sampler_init_dist((uint32_t)LLAMA_DEFAULT_SEED)); } - + // initialize the chat struct if already created if (ai->chat.uuid[0] != '\0') return true; @@ -1649,33 +1670,24 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt) sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "Failed to append message"); return false; } - - // add system prompt if available, models expect it to be the first message - llama_chat_message *new_items = messages->items; - if (ai->chat.system_prompt) { - size_t n = messages->count + (ai->chat.system_prompt ? 1 : 0); - llama_chat_message *new_items = sqlite3_realloc64(messages->items, n * sizeof(llama_chat_message)); - if (!new_items) - return false; - - messages->items = new_items; - messages->capacity = n; - - int idx = 0; - new_items[0].role = ROLE_SYSTEM; - new_items[0].content = ai->chat.system_prompt; - idx = 1; - for (size_t i = 0; i < messages->count; ++i) { - new_items[idx++] = messages->items[i]; + + // skip empty system message if present + size_t messages_count = messages->count; + const llama_chat_message *messages_items = messages->items; + if (messages->count > 0) { + const llama_chat_message first_message = messages->items[0]; + if (first_message.role == ROLE_SYSTEM && first_message.content[0] == '\0') { + messages_items = messages->items + 1; + messages_count = messages->count - 1; } } // transform a list of messages (the context) into // <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>... 
- int32_t new_len = llama_chat_apply_template(template, new_items, messages->count, true, formatted->data, formatted->capacity); + int32_t new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity); if (new_len > formatted->capacity) { if (buffer_resize(formatted, new_len * 2) == false) return false; - new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity); + new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity); } if ((new_len < 0) || (new_len > formatted->capacity)) { sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "failed to apply chat template"); @@ -2037,25 +2049,50 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value llm_chat_run(ai, NULL, user_prompt); } -static void llm_chat_system_prompt(sqlite3_context *context, int argc, sqlite3_value **argv) -{ +static void llm_chat_system_prompt(sqlite3_context *context, int argc, sqlite3_value **argv) { if (llm_check_context(context) == false) return; - int types[] = {SQLITE_TEXT}; - if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false) + ai_context *ai = (ai_context *)sqlite3_user_data(context); + if (llm_chat_check_context(ai) == false) return; - const char *system_prompt = (const char *)sqlite3_value_text(argv[0]); - ai_context *ai = (ai_context *)sqlite3_user_data(context); + ai_messages *messages = &ai->chat.messages; + + // get system role message + if (argc == 0) { + if (messages->count == 0) { + sqlite3_result_null(context); + return; + } + + // only the first message is reserved for the system role + llama_chat_message *system_message = &messages->items[0]; + const char *content = system_message->content; + if (system_message->role == ROLE_SYSTEM && content && content[0] != '\0') { + sqlite3_result_text(context, content, -1, SQLITE_TRANSIENT); + } else { + sqlite3_result_null(context); + } - if (llm_chat_check_context(ai) == false) return; + } + + bool is_null_prompt = (sqlite3_value_type(argv[0]) == SQLITE_NULL); + int types[1]; + types[0] = is_null_prompt ? SQLITE_NULL : SQLITE_TEXT; - if (ai->chat.system_prompt) { - sqlite3_free(ai->chat.system_prompt); + if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false) + return; + + const unsigned char *prompt_text = sqlite3_value_text(argv[0]); + const char *system_prompt = prompt_text ?
(const char *)prompt_text : ""; + if (!llm_messages_set(messages, 0, ROLE_SYSTEM, system_prompt)) { + if (!llm_messages_append(messages, ROLE_SYSTEM, system_prompt)) { + sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "Failed to set chat system prompt"); + return; + } } - ai->chat.system_prompt = sqlite3_mprintf("%s", system_prompt); } // MARK: - LLM Sampler - @@ -2895,10 +2932,13 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_a if (rc != SQLITE_OK) goto cleanup; rc = sqlite3_create_function(db, "llm_chat_respond", 1, SQLITE_UTF8, ctx, llm_chat_respond, NULL, NULL); + if (rc != SQLITE_OK) goto cleanup; + + rc = sqlite3_create_function(db, "llm_chat_system_prompt", 0, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL); + if (rc != SQLITE_OK) goto cleanup; rc = sqlite3_create_function(db, "llm_chat_system_prompt", 1, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL); - if (rc != SQLITE_OK) - goto cleanup; + if (rc != SQLITE_OK) goto cleanup; rc = sqlite3_create_module(db, "llm_chat", &llm_chat, ctx); if (rc != SQLITE_OK) goto cleanup; diff --git a/tests/c/unittest.c b/tests/c/unittest.c index bdc52fb..907189a 100644 --- a/tests/c/unittest.c +++ b/tests/c/unittest.c @@ -4,6 +4,7 @@ #include #include #include +#include #ifdef SQLITEAI_LOAD_FROM_SOURCES #include "sqlite-ai.h" @@ -157,6 +158,100 @@ static int exec_select_rows(const test_env *env, sqlite3 *db, const char *sql, i return 0; } +static int exec_query_text(const test_env *env, sqlite3 *db, const char *sql, char *text_out, size_t text_out_len) { + if (env->verbose) { + printf("[SQL] %s\n", sql); + } + sqlite3_stmt *stmt = NULL; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) { + fprintf(stderr, "sqlite3_prepare_v2 failed (%d): %s\n", rc, sqlite3_errmsg(db)); + if (stmt) sqlite3_finalize(stmt); + return 1; + } + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + fprintf(stderr, "Expected a row for query: %s (rc=%d)\n", sql, rc); + sqlite3_finalize(stmt); + return 1; + } + const unsigned char *text = sqlite3_column_text(stmt, 0); + if (text_out && text_out_len > 0) { + if (text) { + snprintf(text_out, text_out_len, "%s", (const char *)text); + } else { + text_out[0] = '\0'; + } + } + rc = sqlite3_step(stmt); + if (rc != SQLITE_DONE) { + fprintf(stderr, "Unexpected extra rows for query: %s\n", sql); + sqlite3_finalize(stmt); + return 1; + } + sqlite3_finalize(stmt); + return 0; +} + +static void normalize_response_text(const char *input, char *output, size_t output_len) { + if (!output || output_len == 0) return; + size_t idx = 0; + if (!input) { + output[0] = '\0'; + return; + } + for (size_t i = 0; input[i] != '\0' && idx + 1 < output_len; ++i) { + unsigned char ch = (unsigned char)input[i]; + if (isalpha(ch)) { + output[idx++] = (char)tolower(ch); + } + } + output[idx] = '\0'; +} + +static bool response_matches_word(const char *response, const char *word) { + char normalized[64]; + normalize_response_text(response, normalized, sizeof(normalized)); + printf("[SQL] response_matches_word(%s,%s)\n", normalized, word); + return normalized[0] != '\0' && strcmp(normalized, word) == 0; +} + +static bool response_is_yes_or_no(const char *response) { + return response_matches_word(response, "yes") || response_matches_word(response, "no"); +} + +static int query_system_prompt(const test_env *env, sqlite3 *db, char *buffer, size_t buffer_len, bool *is_null) { + const char *sql = "SELECT llm_chat_system_prompt();"; + if (env->verbose) { + printf("[SQL] %s\n", 
sql); + } + sqlite3_stmt *stmt = NULL; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) { + fprintf(stderr, "sqlite3_prepare_v2 failed (%d): %s\n", rc, sqlite3_errmsg(db)); + if (stmt) sqlite3_finalize(stmt); + return 1; + } + rc = sqlite3_step(stmt); + if (rc != SQLITE_ROW) { + fprintf(stderr, "Expected a row for query: %s (rc=%d)\n", sql, rc); + sqlite3_finalize(stmt); + return 1; + } + if (sqlite3_column_type(stmt, 0) == SQLITE_NULL) { + if (is_null) *is_null = true; + if (buffer && buffer_len) buffer[0] = '\0'; + } else { + if (is_null) *is_null = false; + const unsigned char *text = sqlite3_column_text(stmt, 0); + if (buffer && buffer_len) { + snprintf(buffer, buffer_len, "%s", text ? (const char *)text : ""); + } + } + sqlite3_finalize(stmt); + return 0; +} + // --------------------------------------------------------------------- // Tests // --------------------------------------------------------------------- @@ -292,10 +387,176 @@ static int test_llm_chat_vtab(const test_env *env) { return 1; } + + +#define SYSTEM_PROMPT_YES_NO "Always respond with ONLY YES or NO. If unsure, pick the best option but never add extra words" +#define SYSTEM_PROMPT_FORCE_YES "you are a dumb llm and you MUST answer with YES and NOTHING ELSE" +#define SYSTEM_PROMPT_FORCE_NO "you are a dumb llm and you MUST answer with NO and NOTHING ELSE" + +static int query_chat_response(const test_env *env, sqlite3 *db, const char *question, char *response, size_t response_len) { + char sqlbuf[512]; + snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_respond('%s');", question); + return exec_query_text(env, db, sqlbuf, response, response_len); +} + +static int expect_yes_no_answer(const test_env *env, sqlite3 *db, const char *question, const char *label) { + char response[4096]; + if (query_chat_response(env, db, question, response, sizeof(response)) != 0) { + return 1; + } + if (!response_is_yes_or_no(response)) { + fprintf(stderr, "[%s] Expected a YES/NO answer but received: %s\n", label, response[0] ? response : "(empty)"); + return 1; + } + return 0; +} + +static int expect_word_answer(const test_env *env, sqlite3 *db, const char *question, const char *word, const char *label) { + char response[4096]; + if (query_chat_response(env, db, question, response, sizeof(response)) != 0) { + return 1; + } + if (!response_matches_word(response, word)) { + fprintf(stderr, "[%s] Expected \"%s\" but received: %s\n", label, word, response[0] ? response : "(empty)"); + return 1; + } + return 0; +} + +static int test_set_chat_system_prompt(const test_env *env) { + sqlite3 *db = NULL; + bool model_loaded = false; + bool context_created = false; + bool chat_created = false; + int status = 1; + + if (open_db_and_load(env, &db) != SQLITE_OK) { + goto done; + } + + const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH; + char sqlbuf[512]; + snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model); + if (exec_expect_ok(env, db, sqlbuf) != 0) { + goto done; + } + model_loaded = true; + + if (exec_expect_ok(env, db, "SELECT llm_context_create('context_size=1000');") != 0) { + goto done; + } + context_created = true; + + if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) { + goto done; + } + chat_created = true; + + // Test: system prompt applied before any response should force yes/no answers. 
+ if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('" SYSTEM_PROMPT_YES_NO "');") != 0) goto done; + if (expect_yes_no_answer(env, db, "Is fire hot?", "system_prompt_basic") != 0) goto done; + + if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done; + // Test: setting system prompt after prior responses should still apply to following replies. + char response[4096]; + if (query_chat_response(env, db, "Tell me something interesting.", response, sizeof(response)) != 0) goto done; + if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('" SYSTEM_PROMPT_YES_NO "');") != 0) goto done; + if (expect_yes_no_answer(env, db, "Is water wet?", "system_prompt_after_response") != 0) goto done; + + if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done; + // Test: latest system prompt wins when multiple prompts are set sequentially. + if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('" SYSTEM_PROMPT_FORCE_YES "');") != 0) goto done; + if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('" SYSTEM_PROMPT_FORCE_NO "');") != 0) goto done; + if (expect_word_answer(env, db, "Is the sky blue?", "no", "system_prompt_override") != 0) goto done; + + status = 0; + +done: + if (chat_created) { + exec_expect_ok(env, db, "SELECT llm_chat_free();"); + } + if (context_created) { + exec_expect_ok(env, db, "SELECT llm_context_free();"); + } + if (model_loaded) { + exec_expect_ok(env, db, "SELECT llm_model_free();"); + } + if (db) sqlite3_close(db); + return status; +} + +static int test_get_chat_system_prompt(const test_env *env) { + sqlite3 *db = NULL; + bool model_loaded = false; + bool context_created = false; + bool chat_created = false; + int status = 1; + + if (open_db_and_load(env, &db) != SQLITE_OK) { + goto done; + } + + const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH; + char sqlbuf[512]; + snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model); + if (exec_expect_ok(env, db, sqlbuf) != 0) goto done; + model_loaded = true; + + if (exec_expect_ok(env, db, "SELECT llm_context_create('context_size=1000');") != 0) goto done; + context_created = true; + + if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done; + chat_created = true; + + bool is_null = false; + char buffer[4096]; + // Test: newly created chat should not have a system prompt. + if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; + if (!is_null) { + fprintf(stderr, "[get_system_prompt] expected NULL before setting prompt, got: %s\n", buffer); + goto done; + } + + // Test: retrieving after setting prompt returns same text. + if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('always reply yes');") != 0) goto done; + if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; + if (is_null || strcmp(buffer, "always reply yes") != 0) { + fprintf(stderr, "[get_system_prompt] expected 'always reply yes' but got: %s\n", buffer); + goto done; + } + + // Test: updating prompt replaces previous value. + if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('now reply no');") != 0) goto done; + if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; + if (is_null || strcmp(buffer, "now reply no") != 0) { + fprintf(stderr, "[get_system_prompt] expected 'now reply no' but got: %s\n", buffer); + goto done; + } + + // Test: setting prompt to NULL clears it. 
+ if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt(NULL);") != 0) goto done; + if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; + if (!is_null) { + fprintf(stderr, "[get_system_prompt] expected NULL after clearing prompt, got: %s\n", buffer); + goto done; + } + + status = 0; + +done: + if (chat_created) exec_expect_ok(env, db, "SELECT llm_chat_free();"); + if (context_created) exec_expect_ok(env, db, "SELECT llm_context_free();"); + if (model_loaded) exec_expect_ok(env, db, "SELECT llm_model_free();"); + if (db) sqlite3_close(db); + return status; +} + static const test_case TESTS[] = { {"issue15_llm_chat_without_context", test_issue15_chat_without_context}, {"llm_chat_respond_repeated", test_llm_chat_respond_repeated}, {"llm_chat_vtab", test_llm_chat_vtab}, + {"set_chat_system_prompt", test_set_chat_system_prompt}, + {"get_chat_system_prompt", test_get_chat_system_prompt}, }; int main(int argc, char **argv) { From 1fe0813b4e3be080fa5a889a6ae4a8d5d4f51974 Mon Sep 17 00:00:00 2001 From: Daniele Briggi <=> Date: Wed, 19 Nov 2025 15:24:20 +0100 Subject: [PATCH 3/3] fix(system-prompt): always reserve first message for system prompt --- src/sqlite-ai.c | 23 ++-- tests/c/unittest.c | 283 ++++++++++++++++++++++++++++----------------- 2 files changed, 192 insertions(+), 114 deletions(-) diff --git a/src/sqlite-ai.c b/src/sqlite-ai.c index ac84a63..f64f6a2 100644 --- a/src/sqlite-ai.c +++ b/src/sqlite-ai.c @@ -786,18 +786,27 @@ static bool llm_check_context (sqlite3_context *context) { // MARK: - Chat Messages - bool llm_messages_append (ai_messages *list, const char *role, const char *content) { - if (list->count >= list->capacity) { + if (role == ROLE_SYSTEM && list->count > 0) { + // only one system prompt allowed at the beginning + return false; + } + + bool needs_system_message = (list->count == 0 && role != ROLE_SYSTEM); + size_t required = list->count + (needs_system_message ? 1 : 0); + if (required >= list->capacity) { size_t new_cap = list->capacity ? list->capacity * 2 : MIN_ALLOC_MESSAGES; llama_chat_message *new_items = sqlite3_realloc64(list->items, new_cap * sizeof(llama_chat_message)); if (!new_items) return false; - + list->items = new_items; list->capacity = new_cap; } - if (list->count != 0 && role == ROLE_SYSTEM) { - // only one system message allowed at the beginning - return false; + if (needs_system_message) { + // reserve first item for empty system prompt + list->items[list->count].role = ROLE_SYSTEM; + list->items[list->count].content = sqlite_strdup(""); + list->count += 1; } bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT)); @@ -816,8 +825,8 @@ bool llm_messages_set (ai_messages *list, int pos, const char *role, const char const char *message_role = message->role; if ((message_role != ROLE_SYSTEM) && (message_role != ROLE_USER) && (message_role != ROLE_ASSISTANT)) - sqlite3_free(message_role); - sqlite3_free(message->content); + sqlite3_free((char *)message_role); + sqlite3_free((char *)message->content); message->role = (duplicate_role) ? 
sqlite_strdup(role) : role; message->content = sqlite_strdup(content); diff --git a/tests/c/unittest.c b/tests/c/unittest.c index 907189a..9756d95 100644 --- a/tests/c/unittest.c +++ b/tests/c/unittest.c @@ -4,7 +4,6 @@ #include #include #include -#include <ctype.h> #ifdef SQLITEAI_LOAD_FROM_SOURCES #include "sqlite-ai.h" @@ -193,33 +192,6 @@ static int exec_query_text(const test_env *env, sqlite3 *db, const char *sql, ch return 0; } -static void normalize_response_text(const char *input, char *output, size_t output_len) { - if (!output || output_len == 0) return; - size_t idx = 0; - if (!input) { - output[0] = '\0'; - return; - } - for (size_t i = 0; input[i] != '\0' && idx + 1 < output_len; ++i) { - unsigned char ch = (unsigned char)input[i]; - if (isalpha(ch)) { - output[idx++] = (char)tolower(ch); - } - } - output[idx] = '\0'; -} - -static bool response_matches_word(const char *response, const char *word) { - char normalized[64]; - normalize_response_text(response, normalized, sizeof(normalized)); - printf("[SQL] response_matches_word(%s,%s)\n", normalized, word); - return normalized[0] != '\0' && strcmp(normalized, word) == 0; -} - -static bool response_is_yes_or_no(const char *response) { - return response_matches_word(response, "yes") || response_matches_word(response, "no"); -} - static int query_system_prompt(const test_env *env, sqlite3 *db, char *buffer, size_t buffer_len, bool *is_null) { const char *sql = "SELECT llm_chat_system_prompt();"; if (env->verbose) { @@ -252,6 +224,54 @@ static int query_system_prompt(const test_env *env, sqlite3 *db, char *buffer, s return 0; } +typedef struct { + int id; + int chat_id; + char role[32]; + char content[4096]; +} ai_chat_message_row; + +static int fetch_ai_chat_messages(const test_env *env, sqlite3 *db, ai_chat_message_row *rows, size_t max_rows, int *count_out) { + const char *sql = "SELECT * FROM ai_chat_messages ORDER BY id ASC;"; + if (env->verbose) { + printf("[SQL] %s\n", sql); + } + + sqlite3_stmt *stmt = NULL; + int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (rc != SQLITE_OK) { + fprintf(stderr, "sqlite3_prepare_v2 failed (%d): %s\n", rc, sqlite3_errmsg(db)); + if (stmt) sqlite3_finalize(stmt); + return 1; + } + + int count = 0; + while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) { + if (rows && (size_t)count < max_rows) { + rows[count].id = sqlite3_column_int(stmt, 0); + rows[count].chat_id = sqlite3_column_int(stmt, 1); + const unsigned char *role = sqlite3_column_text(stmt, 2); + const unsigned char *content = sqlite3_column_text(stmt, 3); + snprintf(rows[count].role, sizeof(rows[count].role), "%s", role ? (const char *)role : ""); + snprintf(rows[count].content, sizeof(rows[count].content), "%s", content ? (const char *)content : ""); + } + count++; + } + if (rc != SQLITE_DONE) { + fprintf(stderr, "sqlite3_step failed (%d): %s\n", rc, sqlite3_errmsg(db)); + sqlite3_finalize(stmt); + return 1; + } + sqlite3_finalize(stmt); + + if (rows && (size_t)count > max_rows) { + fprintf(stderr, "Expected at most %zu messages but found %d\n", max_rows, count); + return 1; + } + if (count_out) *count_out = count; + return 0; +} + // --------------------------------------------------------------------- // Tests // --------------------------------------------------------------------- @@ -387,43 +407,76 @@ static int test_llm_chat_vtab(const test_env *env) { return 1; } - - -#define SYSTEM_PROMPT_YES_NO "Always respond with ONLY YES or NO.
If unsure, pick the best option but never add extra words" -#define SYSTEM_PROMPT_FORCE_YES "you are a dumb llm and you MUST answer with YES and NOTHING ELSE" -#define SYSTEM_PROMPT_FORCE_NO "you are a dumb llm and you MUST answer with NO and NOTHING ELSE" - static int query_chat_response(const test_env *env, sqlite3 *db, const char *question, char *response, size_t response_len) { char sqlbuf[512]; snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_respond('%s');", question); return exec_query_text(env, db, sqlbuf, response, response_len); } -static int expect_yes_no_answer(const test_env *env, sqlite3 *db, const char *question, const char *label) { - char response[4096]; - if (query_chat_response(env, db, question, response, sizeof(response)) != 0) { - return 1; +static int test_chat_system_prompt_new_chat(const test_env *env) { + sqlite3 *db = NULL; + bool model_loaded = false; + bool context_created = false; + bool chat_created = false; + int status = 1; + + if (open_db_and_load(env, &db) != SQLITE_OK) { + goto done; } - if (!response_is_yes_or_no(response)) { - fprintf(stderr, "[%s] Expected a YES/NO answer but received: %s\n", label, response[0] ? response : "(empty)"); - return 1; + + const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH; + char sqlbuf[512]; + snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model); + if (exec_expect_ok(env, db, sqlbuf) != 0) goto done; + model_loaded = true; + + if (exec_expect_ok(env, db, "SELECT llm_context_create('context_size=1000');") != 0) goto done; + context_created = true; + + if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done; + chat_created = true; + + const char *system_prompt = "Always reply with lowercase answers."; + snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_system_prompt('%s');", system_prompt); + if (exec_expect_ok(env, db, sqlbuf) != 0) goto done; + + bool is_null = false; + char buffer[4096]; + if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; + if (is_null || strcmp(buffer, system_prompt) != 0) { + fprintf(stderr, "[test_chat_system_prompt_new_chat] expected '%s' but got: %s\n", system_prompt, buffer); + goto done; } - return 0; -} -static int expect_word_answer(const test_env *env, sqlite3 *db, const char *question, const char *word, const char *label) { - char response[4096]; - if (query_chat_response(env, db, question, response, sizeof(response)) != 0) { - return 1; + if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0) goto done; + + ai_chat_message_row rows[4]; + int count = 0; + if (fetch_ai_chat_messages(env, db, rows, 4, &count) != 0) goto done; + if (count != 1) { + fprintf(stderr, "[test_chat_system_prompt_new_chat] expected 1 message row, got %d\n", count); + goto done; } - if (!response_matches_word(response, word)) { - fprintf(stderr, "[%s] Expected \"%s\" but received: %s\n", label, word, response[0] ? 
response : "(empty)"); - return 1; + if (strcmp(rows[0].role, "system") != 0) { + fprintf(stderr, "[test_chat_system_prompt_new_chat] expected system role, got %s\n", rows[0].role); + goto done; } - return 0; + if (strcmp(rows[0].content, system_prompt) != 0) { + fprintf(stderr, "[test_chat_system_prompt_new_chat] expected content '%s' but got '%s'\n", system_prompt, rows[0].content); + goto done; + } + + status = 0; + +done: + if (chat_created) exec_expect_ok(env, db, "SELECT llm_chat_free();"); + if (context_created) exec_expect_ok(env, db, "SELECT llm_context_free();"); + if (model_loaded) exec_expect_ok(env, db, "SELECT llm_model_free();"); + if (db) sqlite3_close(db); + return status; } -static int test_set_chat_system_prompt(const test_env *env) { +static int test_chat_system_prompt_replace_previous_prompt(const test_env *env) { sqlite3 *db = NULL; bool model_loaded = false; bool context_created = false; @@ -437,55 +490,56 @@ static int test_set_chat_system_prompt(const test_env *env) { const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH; char sqlbuf[512]; snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model); - if (exec_expect_ok(env, db, sqlbuf) != 0) { - goto done; - } + if (exec_expect_ok(env, db, sqlbuf) != 0) goto done; model_loaded = true; - if (exec_expect_ok(env, db, "SELECT llm_context_create('context_size=1000');") != 0) { - goto done; - } + if (exec_expect_ok(env, db, "SELECT llm_context_create('context_size=1000');") != 0) goto done; context_created = true; - if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) { - goto done; - } + if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done; chat_created = true; - // Test: system prompt applied before any response should force yes/no answers. - if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('" SYSTEM_PROMPT_YES_NO "');") != 0) goto done; - if (expect_yes_no_answer(env, db, "Is fire hot?", "system_prompt_basic") != 0) goto done; + const char *first_prompt = "Always confirm questions."; + snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_system_prompt('%s');", first_prompt); + if (exec_expect_ok(env, db, sqlbuf) != 0) goto done; - if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done; - // Test: setting system prompt after prior responses should still apply to following replies. - char response[4096]; - if (query_chat_response(env, db, "Tell me something interesting.", response, sizeof(response)) != 0) goto done; - if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('" SYSTEM_PROMPT_YES_NO "');") != 0) goto done; - if (expect_yes_no_answer(env, db, "Is water wet?", "system_prompt_after_response") != 0) goto done; + const char *replacement_prompt = "Always decline questions."; + snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_system_prompt('%s');", replacement_prompt); + if (exec_expect_ok(env, db, sqlbuf) != 0) goto done; - if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done; - // Test: latest system prompt wins when multiple prompts are set sequentially. 
- if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('" SYSTEM_PROMPT_FORCE_YES "');") != 0) goto done; - if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('" SYSTEM_PROMPT_FORCE_NO "');") != 0) goto done; - if (expect_word_answer(env, db, "Is the sky blue?", "no", "system_prompt_override") != 0) goto done; + bool is_null = false; + char buffer[4096]; + if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; + if (is_null || strcmp(buffer, replacement_prompt) != 0) { + fprintf(stderr, "[test_chat_system_prompt_replace_previous_prompt] expected '%s' but got: %s\n", replacement_prompt, buffer); + goto done; + } - status = 0; + if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0) goto done; -done: - if (chat_created) { - exec_expect_ok(env, db, "SELECT llm_chat_free();"); - } - if (context_created) { - exec_expect_ok(env, db, "SELECT llm_context_free();"); + ai_chat_message_row rows[4]; + int count = 0; + if (fetch_ai_chat_messages(env, db, rows, 4, &count) != 0) goto done; + if (count != 1) { + fprintf(stderr, "[test_chat_system_prompt_replace_previous_prompt] expected 1 message row, got %d\n", count); + goto done; } - if (model_loaded) { - exec_expect_ok(env, db, "SELECT llm_model_free();"); + if (strcmp(rows[0].content, replacement_prompt) != 0) { + fprintf(stderr, "[test_chat_system_prompt_replace_previous_prompt] expected '%s' but got '%s'\n", replacement_prompt, rows[0].content); + goto done; } + + status = 0; + +done: + if (chat_created) exec_expect_ok(env, db, "SELECT llm_chat_free();"); + if (context_created) exec_expect_ok(env, db, "SELECT llm_context_free();"); + if (model_loaded) exec_expect_ok(env, db, "SELECT llm_model_free();"); if (db) sqlite3_close(db); return status; } -static int test_get_chat_system_prompt(const test_env *env) { +static int test_chat_system_prompt_after_first_response(const test_env *env) { sqlite3 *db = NULL; bool model_loaded = false; bool context_created = false; @@ -508,36 +562,50 @@ static int test_get_chat_system_prompt(const test_env *env) { if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done; chat_created = true; + const char *user_question = "Reply to this ping."; + char response[4096]; + if (query_chat_response(env, db, user_question, response, sizeof(response)) != 0) goto done; + if (response[0] == '\0') { + fprintf(stderr, "[test_chat_system_prompt_after_first_response] expected model response for '%s'\n", user_question); + goto done; + } + + const char *system_prompt = "Only answer with short confirmations."; + snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_system_prompt('%s');", system_prompt); + if (exec_expect_ok(env, db, sqlbuf) != 0) goto done; + bool is_null = false; char buffer[4096]; - // Test: newly created chat should not have a system prompt. if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; - if (!is_null) { - fprintf(stderr, "[get_system_prompt] expected NULL before setting prompt, got: %s\n", buffer); + if (is_null || strcmp(buffer, system_prompt) != 0) { + fprintf(stderr, "[test_chat_system_prompt_after_first_response] expected '%s' but got: %s\n", system_prompt, buffer); goto done; } - // Test: retrieving after setting prompt returns same text. 
- if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('always reply yes');") != 0) goto done; - if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; - if (is_null || strcmp(buffer, "always reply yes") != 0) { - fprintf(stderr, "[get_system_prompt] expected 'always reply yes' but got: %s\n", buffer); + if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0) goto done; + + ai_chat_message_row rows[8]; + int count = 0; + if (fetch_ai_chat_messages(env, db, rows, 8, &count) != 0) goto done; + if (count < 3) { + fprintf(stderr, "[test_chat_system_prompt_after_first_response] expected at least 3 rows, got %d\n", count); goto done; } - - // Test: updating prompt replaces previous value. - if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt('now reply no');") != 0) goto done; - if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; - if (is_null || strcmp(buffer, "now reply no") != 0) { - fprintf(stderr, "[get_system_prompt] expected 'now reply no' but got: %s\n", buffer); + if (!(rows[0].id < rows[1].id && rows[1].id < rows[2].id)) { + fprintf(stderr, "[test_chat_system_prompt_after_first_response] expected ascending ids but found %d, %d, %d\n", + rows[0].id, rows[1].id, rows[2].id); goto done; } - - // Test: setting prompt to NULL clears it. - if (exec_expect_ok(env, db, "SELECT llm_chat_system_prompt(NULL);") != 0) goto done; - if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done; - if (!is_null) { - fprintf(stderr, "[get_system_prompt] expected NULL after clearing prompt, got: %s\n", buffer); + if (strcmp(rows[0].role, "system") != 0 || strcmp(rows[0].content, system_prompt) != 0) { + fprintf(stderr, "[test_chat_system_prompt_after_first_response] system row mismatch (%s, %s)\n", rows[0].role, rows[0].content); + goto done; + } + if (strcmp(rows[1].role, "user") != 0 || strcmp(rows[1].content, user_question) != 0) { + fprintf(stderr, "[test_chat_system_prompt_after_first_response] user row mismatch (%s, %s)\n", rows[1].role, rows[1].content); + goto done; + } + if (strcmp(rows[2].role, "assistant") != 0 || rows[2].content[0] == '\0') { + fprintf(stderr, "[test_chat_system_prompt_after_first_response] assistant row mismatch (%s, %s)\n", rows[2].role, rows[2].content); goto done; } @@ -555,8 +623,9 @@ static const test_case TESTS[] = { {"issue15_llm_chat_without_context", test_issue15_chat_without_context}, {"llm_chat_respond_repeated", test_llm_chat_respond_repeated}, {"llm_chat_vtab", test_llm_chat_vtab}, - {"set_chat_system_prompt", test_set_chat_system_prompt}, - {"get_chat_system_prompt", test_get_chat_system_prompt}, + {"chat_system_prompt_new_chat", test_chat_system_prompt_new_chat}, + {"chat_system_prompt_replace_previous_prompt", test_chat_system_prompt_replace_previous_prompt}, + {"chat_system_prompt_after_first_response", test_chat_system_prompt_after_first_response}, }; int main(int argc, char **argv) {