diff --git a/src/sqlite-ai.c b/src/sqlite-ai.c
index 8cf0e90..f64f6a2 100644
--- a/src/sqlite-ai.c
+++ b/src/sqlite-ai.c
@@ -199,6 +199,7 @@ typedef enum {
     AI_MODEL_CHAT_TEMPLATE
 } ai_model_setting;
 
+const char *ROLE_SYSTEM = "system";
 const char *ROLE_USER = "user";
 const char *ROLE_ASSISTANT = "assistant";
 
@@ -785,27 +786,58 @@ static bool llm_check_context (sqlite3_context *context) {
 // MARK: - Chat Messages -
 
 bool llm_messages_append (ai_messages *list, const char *role, const char *content) {
-    if (list->count >= list->capacity) {
+    if (role == ROLE_SYSTEM && list->count > 0) {
+        // only one system prompt is allowed, and it must be the first message
+        return false;
+    }
+
+    bool needs_system_message = (list->count == 0 && role != ROLE_SYSTEM);
+    size_t required = list->count + (needs_system_message ? 1 : 0);
+    if (required >= list->capacity) {
         size_t new_cap = list->capacity ? list->capacity * 2 : MIN_ALLOC_MESSAGES;
         llama_chat_message *new_items = sqlite3_realloc64(list->items, new_cap * sizeof(llama_chat_message));
         if (!new_items) return false;
-        
+
         list->items = new_items;
         list->capacity = new_cap;
     }
 
-    bool duplicate_role = ((role != ROLE_USER) && (role != ROLE_ASSISTANT));
+    if (needs_system_message) {
+        // reserve the first slot for an (initially empty) system prompt
+        list->items[list->count].role = ROLE_SYSTEM;
+        list->items[list->count].content = sqlite_strdup("");
+        list->count += 1;
+    }
+
+    bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
     list->items[list->count].role = (duplicate_role) ? sqlite_strdup(role) : role;
     list->items[list->count].content = sqlite_strdup(content);
    list->count += 1;
     return true;
 }
 
+bool llm_messages_set (ai_messages *list, int pos, const char *role, const char *content) {
+    if (pos < 0 || (size_t)pos >= list->count)
+        return false;
+
+    bool duplicate_role = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
+    llama_chat_message *message = &list->items[pos];
+
+    const char *message_role = message->role;
+    if ((message_role != ROLE_SYSTEM) && (message_role != ROLE_USER) && (message_role != ROLE_ASSISTANT))
+        sqlite3_free((char *)message_role);
+    sqlite3_free((char *)message->content);
+
+    message->role = (duplicate_role) ? sqlite_strdup(role) : role;
+    message->content = sqlite_strdup(content);
+    return true;
+}
+
 void llm_messages_free (ai_messages *list) {
     for (size_t i = 0; i < list->count; ++i) {
         // check whether the role is one of the static constants
         const char *role = list->items[i].role;
-        bool role_tofree = ((role != ROLE_USER) && (role != ROLE_ASSISTANT));
+        bool role_tofree = ((role != ROLE_SYSTEM) && (role != ROLE_USER) && (role != ROLE_ASSISTANT));
         if (role_tofree) sqlite3_free((char *)list->items[i].role);
         // content is always heap-allocated, so always free it
         sqlite3_free((char *)list->items[i].content);
@@ -1648,12 +1680,23 @@ static bool llm_chat_run (ai_context *ai, ai_cursor *c, const char *user_prompt)
         return false;
     }
 
+    // skip the empty placeholder system message, if present
+    size_t messages_count = messages->count;
+    const llama_chat_message *messages_items = messages->items;
+    if (messages->count > 0) {
+        const llama_chat_message first_message = messages->items[0];
+        if (first_message.role == ROLE_SYSTEM && first_message.content[0] == '\0') {
+            messages_items = messages->items + 1;
+            messages_count = messages->count - 1;
+        }
+    }
+
     // transform a list of messages (the context) into
     // <|user|>What is AI?<|end|><|assistant|>AI stands for Artificial Intelligence...<|end|><|user|>Can you give an example?<|end|><|assistant|>...
-    int32_t new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
+    int32_t new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
     if (new_len > formatted->capacity) {
         if (buffer_resize(formatted, new_len * 2) == false) return false;
-        new_len = llama_chat_apply_template(template, messages->items, messages->count, true, formatted->data, formatted->capacity);
+        new_len = llama_chat_apply_template(template, messages_items, messages_count, true, formatted->data, formatted->capacity);
     }
     if ((new_len < 0) || (new_len > formatted->capacity)) {
         sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "failed to apply chat template");
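
A note on the hunk above: llama_chat_apply_template returns the number of bytes the fully formatted prompt needs, so a result larger than the buffer means "grow and call again", and a negative result means the template failed. A standalone sketch of that grow-and-retry contract, assuming the llama.cpp API the patch already uses; the helper name, the 1024-byte initial guess, and the malloc-based buffer are hypothetical, not part of this extension:

    #include <stdlib.h>
    #include "llama.h"

    // Returns a malloc'ed formatted prompt (caller frees), or NULL on failure.
    static char *format_chat(const char *tmpl, const struct llama_chat_message *msgs,
                             size_t n_msgs, int32_t *out_len) {
        int32_t cap = 1024;                                   // initial guess
        char *buf = malloc((size_t)cap);
        if (!buf) return NULL;
        int32_t len = llama_chat_apply_template(tmpl, msgs, n_msgs, true, buf, cap);
        if (len > cap) {                                      // first call only measured the size
            char *tmp = realloc(buf, (size_t)len);
            if (!tmp) { free(buf); return NULL; }
            buf = tmp;
            cap = len;
            len = llama_chat_apply_template(tmpl, msgs, n_msgs, true, buf, cap);
        }
        if (len < 0 || len > cap) { free(buf); return NULL; } // template failed
        *out_len = len;
        return buf;
    }
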
@@ -2015,6 +2058,52 @@ static void llm_chat_respond (sqlite3_context *context, int argc, sqlite3_value
     llm_chat_run(ai, NULL, user_prompt);
 }
 
+static void llm_chat_system_prompt (sqlite3_context *context, int argc, sqlite3_value **argv) {
+    if (llm_check_context(context) == false)
+        return;
+
+    ai_context *ai = (ai_context *)sqlite3_user_data(context);
+    if (llm_chat_check_context(ai) == false)
+        return;
+
+    ai_messages *messages = &ai->chat.messages;
+
+    // no argument: return the current system prompt
+    if (argc == 0) {
+        if (messages->count == 0) {
+            sqlite3_result_null(context);
+            return;
+        }
+
+        // only the first message is reserved for the system role
+        llama_chat_message *system_message = &messages->items[0];
+        const char *content = system_message->content;
+        if (system_message->role == ROLE_SYSTEM && content && content[0] != '\0') {
+            sqlite3_result_text(context, content, -1, SQLITE_TRANSIENT);
+        } else {
+            sqlite3_result_null(context);
+        }
+
+        return;
+    }
+
+    bool is_null_prompt = (sqlite3_value_type(argv[0]) == SQLITE_NULL);
+    int types[1];
+    types[0] = is_null_prompt ? SQLITE_NULL : SQLITE_TEXT;
+
+    if (sqlite_sanity_function(context, "llm_chat_system_prompt", argc, argv, 1, types, true, false) == false)
+        return;
+
+    const unsigned char *prompt_text = sqlite3_value_text(argv[0]);
+    const char *system_prompt = prompt_text ? (const char *)prompt_text : "";
+    if (!llm_messages_set(messages, 0, ROLE_SYSTEM, system_prompt)) {
+        if (!llm_messages_append(messages, ROLE_SYSTEM, system_prompt)) {
+            sqlite_common_set_error (ai->context, ai->vtab, SQLITE_ERROR, "failed to set chat system prompt");
+            return;
+        }
+    }
+}
+
 // MARK: - LLM Sampler -
 
 static void llm_sampler_init_greedy (sqlite3_context *context, int argc, sqlite3_value **argv) {
@@ -2853,6 +2942,12 @@ SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_a
     rc = sqlite3_create_function(db, "llm_chat_respond", 1, SQLITE_UTF8, ctx, llm_chat_respond, NULL, NULL);
     if (rc != SQLITE_OK) goto cleanup;
+
+    rc = sqlite3_create_function(db, "llm_chat_system_prompt", 0, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
+    if (rc != SQLITE_OK) goto cleanup;
+
+    rc = sqlite3_create_function(db, "llm_chat_system_prompt", 1, SQLITE_UTF8, ctx, llm_chat_system_prompt, NULL, NULL);
+    if (rc != SQLITE_OK) goto cleanup;
 
     rc = sqlite3_create_module(db, "llm_chat", &llm_chat, ctx);
     if (rc != SQLITE_OK) goto cleanup;
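
Registering llm_chat_system_prompt twice, with nArg 0 and nArg 1, is what makes the single SQL name act as both getter and setter: SQLite dispatches to a registration by argument count, and the shared callback branches on argc. A minimal standalone sketch of the same pattern; the greet function and its strings are hypothetical, not part of this extension:

    #include <sqlite3.h>
    #include <stdio.h>

    // One callback registered twice; SQLite picks the registration matching
    // the call's argument count, and the callback branches on argc.
    static void greet (sqlite3_context *ctx, int argc, sqlite3_value **argv) {
        if (argc == 0) {
            sqlite3_result_text(ctx, "hello", -1, SQLITE_STATIC);
            return;
        }
        const unsigned char *who = sqlite3_value_text(argv[0]);
        char buf[128];
        snprintf(buf, sizeof(buf), "hello %s", who ? (const char *)who : "");
        sqlite3_result_text(ctx, buf, -1, SQLITE_TRANSIENT);
    }

    int main (void) {
        sqlite3 *db = NULL;
        if (sqlite3_open(":memory:", &db) != SQLITE_OK) return 1;
        sqlite3_create_function(db, "greet", 0, SQLITE_UTF8, NULL, greet, NULL, NULL);
        sqlite3_create_function(db, "greet", 1, SQLITE_UTF8, NULL, greet, NULL, NULL);
        // SELECT greet();        -> 'hello'
        // SELECT greet('world'); -> 'hello world'
        sqlite3_close(db);
        return 0;
    }
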
diff --git a/tests/c/unittest.c b/tests/c/unittest.c
index bdc52fb..9756d95 100644
--- a/tests/c/unittest.c
+++ b/tests/c/unittest.c
@@ -157,6 +157,121 @@ static int exec_select_rows(const test_env *env, sqlite3 *db, const char *sql, i
     return 0;
 }
 
+static int exec_query_text(const test_env *env, sqlite3 *db, const char *sql, char *text_out, size_t text_out_len) {
+    if (env->verbose) {
+        printf("[SQL] %s\n", sql);
+    }
+    sqlite3_stmt *stmt = NULL;
+    int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
+    if (rc != SQLITE_OK) {
+        fprintf(stderr, "sqlite3_prepare_v2 failed (%d): %s\n", rc, sqlite3_errmsg(db));
+        if (stmt) sqlite3_finalize(stmt);
+        return 1;
+    }
+    rc = sqlite3_step(stmt);
+    if (rc != SQLITE_ROW) {
+        fprintf(stderr, "Expected a row for query: %s (rc=%d)\n", sql, rc);
+        sqlite3_finalize(stmt);
+        return 1;
+    }
+    const unsigned char *text = sqlite3_column_text(stmt, 0);
+    if (text_out && text_out_len > 0) {
+        if (text) {
+            snprintf(text_out, text_out_len, "%s", (const char *)text);
+        } else {
+            text_out[0] = '\0';
+        }
+    }
+    rc = sqlite3_step(stmt);
+    if (rc != SQLITE_DONE) {
+        fprintf(stderr, "Unexpected extra rows for query: %s\n", sql);
+        sqlite3_finalize(stmt);
+        return 1;
+    }
+    sqlite3_finalize(stmt);
+    return 0;
+}
+
+static int query_system_prompt(const test_env *env, sqlite3 *db, char *buffer, size_t buffer_len, bool *is_null) {
+    const char *sql = "SELECT llm_chat_system_prompt();";
+    if (env->verbose) {
+        printf("[SQL] %s\n", sql);
+    }
+    sqlite3_stmt *stmt = NULL;
+    int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
+    if (rc != SQLITE_OK) {
+        fprintf(stderr, "sqlite3_prepare_v2 failed (%d): %s\n", rc, sqlite3_errmsg(db));
+        if (stmt) sqlite3_finalize(stmt);
+        return 1;
+    }
+    rc = sqlite3_step(stmt);
+    if (rc != SQLITE_ROW) {
+        fprintf(stderr, "Expected a row for query: %s (rc=%d)\n", sql, rc);
+        sqlite3_finalize(stmt);
+        return 1;
+    }
+    if (sqlite3_column_type(stmt, 0) == SQLITE_NULL) {
+        if (is_null) *is_null = true;
+        if (buffer && buffer_len) buffer[0] = '\0';
+    } else {
+        if (is_null) *is_null = false;
+        const unsigned char *text = sqlite3_column_text(stmt, 0);
+        if (buffer && buffer_len) {
+            snprintf(buffer, buffer_len, "%s", text ? (const char *)text : "");
+        }
+    }
+    sqlite3_finalize(stmt);
+    return 0;
+}
+
+typedef struct {
+    int id;
+    int chat_id;
+    char role[32];
+    char content[4096];
+} ai_chat_message_row;
+
+static int fetch_ai_chat_messages(const test_env *env, sqlite3 *db, ai_chat_message_row *rows, size_t max_rows, int *count_out) {
+    const char *sql = "SELECT * FROM ai_chat_messages ORDER BY id ASC;";
+    if (env->verbose) {
+        printf("[SQL] %s\n", sql);
+    }
+
+    sqlite3_stmt *stmt = NULL;
+    int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL);
+    if (rc != SQLITE_OK) {
+        fprintf(stderr, "sqlite3_prepare_v2 failed (%d): %s\n", rc, sqlite3_errmsg(db));
+        if (stmt) sqlite3_finalize(stmt);
+        return 1;
+    }
+
+    int count = 0;
+    while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
+        if (rows && (size_t)count < max_rows) {
+            rows[count].id = sqlite3_column_int(stmt, 0);
+            rows[count].chat_id = sqlite3_column_int(stmt, 1);
+            const unsigned char *role = sqlite3_column_text(stmt, 2);
+            const unsigned char *content = sqlite3_column_text(stmt, 3);
+            snprintf(rows[count].role, sizeof(rows[count].role), "%s", role ? (const char *)role : "");
+            snprintf(rows[count].content, sizeof(rows[count].content), "%s", content ? (const char *)content : "");
+        }
+        count++;
+    }
+    if (rc != SQLITE_DONE) {
+        fprintf(stderr, "sqlite3_step failed (%d): %s\n", rc, sqlite3_errmsg(db));
+        sqlite3_finalize(stmt);
+        return 1;
+    }
+    sqlite3_finalize(stmt);
+
+    if (rows && (size_t)count > max_rows) {
+        fprintf(stderr, "Expected at most %zu messages but found %d\n", max_rows, count);
+        return 1;
+    }
+    if (count_out) *count_out = count;
+    return 0;
+}
+
 // ---------------------------------------------------------------------
 // Tests
 // ---------------------------------------------------------------------
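
One assumption worth flagging in the helpers above: fetch_ai_chat_messages reads ai_chat_messages with SELECT * and positional column indexes, so it silently depends on that table exposing (id, chat_id, role, content) in exactly that order. A short sketch of how the helpers compose inside a test, mirroring the expectations the new tests below assert; setup and teardown are elided:

    // sketch only: assumes env/db are set up and llm_chat_save() already ran
    bool is_null = false;
    char prompt[4096];
    if (query_system_prompt(env, db, prompt, sizeof(prompt), &is_null) != 0) return 1;
    // is_null == true means no system prompt is set on the current chat

    ai_chat_message_row rows[8];
    int count = 0;
    if (fetch_ai_chat_messages(env, db, rows, 8, &count) != 0) return 1;
    if (count > 0 && strcmp(rows[0].role, "system") == 0) {
        // the reserved first row carries the persisted system prompt
    }
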
@@ -292,10 +407,225 @@ static int test_llm_chat_vtab(const test_env *env) {
     return 1;
 }
 
+static int query_chat_response(const test_env *env, sqlite3 *db, const char *question, char *response, size_t response_len) {
+    char sqlbuf[512];
+    snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_respond('%s');", question);
+    return exec_query_text(env, db, sqlbuf, response, response_len);
+}
+
+static int test_chat_system_prompt_new_chat(const test_env *env) {
+    sqlite3 *db = NULL;
+    bool model_loaded = false;
+    bool context_created = false;
+    bool chat_created = false;
+    int status = 1;
+
+    if (open_db_and_load(env, &db) != SQLITE_OK) {
+        goto done;
+    }
+
+    const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH;
+    char sqlbuf[512];
+    snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model);
+    if (exec_expect_ok(env, db, sqlbuf) != 0) goto done;
+    model_loaded = true;
+
+    if (exec_expect_ok(env, db, "SELECT llm_context_create('context_size=1000');") != 0) goto done;
+    context_created = true;
+
+    if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done;
+    chat_created = true;
+
+    const char *system_prompt = "Always reply with lowercase answers.";
+    snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_system_prompt('%s');", system_prompt);
+    if (exec_expect_ok(env, db, sqlbuf) != 0) goto done;
+
+    bool is_null = false;
+    char buffer[4096];
+    if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done;
+    if (is_null || strcmp(buffer, system_prompt) != 0) {
+        fprintf(stderr, "[test_chat_system_prompt_new_chat] expected '%s' but got: %s\n", system_prompt, buffer);
+        goto done;
+    }
+
+    if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0) goto done;
+
+    ai_chat_message_row rows[4];
+    int count = 0;
+    if (fetch_ai_chat_messages(env, db, rows, 4, &count) != 0) goto done;
+    if (count != 1) {
+        fprintf(stderr, "[test_chat_system_prompt_new_chat] expected 1 message row, got %d\n", count);
+        goto done;
+    }
+    if (strcmp(rows[0].role, "system") != 0) {
+        fprintf(stderr, "[test_chat_system_prompt_new_chat] expected system role, got %s\n", rows[0].role);
+        goto done;
+    }
+    if (strcmp(rows[0].content, system_prompt) != 0) {
+        fprintf(stderr, "[test_chat_system_prompt_new_chat] expected content '%s' but got '%s'\n", system_prompt, rows[0].content);
+        goto done;
+    }
+
+    status = 0;
+
+done:
+    if (chat_created) exec_expect_ok(env, db, "SELECT llm_chat_free();");
+    if (context_created) exec_expect_ok(env, db, "SELECT llm_context_free();");
+    if (model_loaded) exec_expect_ok(env, db, "SELECT llm_model_free();");
+    if (db) sqlite3_close(db);
+    return status;
+}
+
+static int test_chat_system_prompt_replace_previous_prompt(const test_env *env) {
+    sqlite3 *db = NULL;
+    bool model_loaded = false;
+    bool context_created = false;
+    bool chat_created = false;
+    int status = 1;
+
+    if (open_db_and_load(env, &db) != SQLITE_OK) {
+        goto done;
+    }
+
+    const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH;
+    char sqlbuf[512];
+    snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model);
+    if (exec_expect_ok(env, db, sqlbuf) != 0) goto done;
+    model_loaded = true;
+
+    if (exec_expect_ok(env, db, "SELECT llm_context_create('context_size=1000');") != 0) goto done;
+    context_created = true;
+
+    if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done;
+    chat_created = true;
+
+    const char *first_prompt = "Always confirm questions.";
+    snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_system_prompt('%s');", first_prompt);
+    if (exec_expect_ok(env, db, sqlbuf) != 0) goto done;
+
+    const char *replacement_prompt = "Always decline questions.";
+    snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_system_prompt('%s');", replacement_prompt);
+    if (exec_expect_ok(env, db, sqlbuf) != 0) goto done;
+
+    bool is_null = false;
+    char buffer[4096];
+    if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done;
+    if (is_null || strcmp(buffer, replacement_prompt) != 0) {
+        fprintf(stderr, "[test_chat_system_prompt_replace_previous_prompt] expected '%s' but got: %s\n", replacement_prompt, buffer);
+        goto done;
+    }
+
+    if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0) goto done;
+
+    ai_chat_message_row rows[4];
+    int count = 0;
+    if (fetch_ai_chat_messages(env, db, rows, 4, &count) != 0) goto done;
+    if (count != 1) {
+        fprintf(stderr, "[test_chat_system_prompt_replace_previous_prompt] expected 1 message row, got %d\n", count);
+        goto done;
+    }
+    if (strcmp(rows[0].content, replacement_prompt) != 0) {
+        fprintf(stderr, "[test_chat_system_prompt_replace_previous_prompt] expected '%s' but got '%s'\n", replacement_prompt, rows[0].content);
+        goto done;
+    }
+
+    status = 0;
+
+done:
+    if (chat_created) exec_expect_ok(env, db, "SELECT llm_chat_free();");
+    if (context_created) exec_expect_ok(env, db, "SELECT llm_context_free();");
+    if (model_loaded) exec_expect_ok(env, db, "SELECT llm_model_free();");
+    if (db) sqlite3_close(db);
+    return status;
+}
+
+static int test_chat_system_prompt_after_first_response(const test_env *env) {
+    sqlite3 *db = NULL;
+    bool model_loaded = false;
+    bool context_created = false;
+    bool chat_created = false;
+    int status = 1;
+
+    if (open_db_and_load(env, &db) != SQLITE_OK) {
+        goto done;
+    }
+
+    const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH;
+    char sqlbuf[512];
+    snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model);
+    if (exec_expect_ok(env, db, sqlbuf) != 0) goto done;
+    model_loaded = true;
+
+    if (exec_expect_ok(env, db, "SELECT llm_context_create('context_size=1000');") != 0) goto done;
+    context_created = true;
+
+    if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) goto done;
+    chat_created = true;
+
+    const char *user_question = "Reply to this ping.";
+    char response[4096];
+    if (query_chat_response(env, db, user_question, response, sizeof(response)) != 0) goto done;
+    if (response[0] == '\0') {
+        fprintf(stderr, "[test_chat_system_prompt_after_first_response] expected model response for '%s'\n", user_question);
+        goto done;
+    }
+
+    const char *system_prompt = "Only answer with short confirmations.";
+    snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_chat_system_prompt('%s');", system_prompt);
+    if (exec_expect_ok(env, db, sqlbuf) != 0) goto done;
+
+    bool is_null = false;
+    char buffer[4096];
+    if (query_system_prompt(env, db, buffer, sizeof(buffer), &is_null) != 0) goto done;
+    if (is_null || strcmp(buffer, system_prompt) != 0) {
+        fprintf(stderr, "[test_chat_system_prompt_after_first_response] expected '%s' but got: %s\n", system_prompt, buffer);
+        goto done;
+    }
+
+    if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0) goto done;
+
+    ai_chat_message_row rows[8];
+    int count = 0;
+    if (fetch_ai_chat_messages(env, db, rows, 8, &count) != 0) goto done;
+    if (count < 3) {
+        fprintf(stderr, "[test_chat_system_prompt_after_first_response] expected at least 3 rows, got %d\n", count);
+        goto done;
+    }
+    if (!(rows[0].id < rows[1].id && rows[1].id < rows[2].id)) {
+        fprintf(stderr, "[test_chat_system_prompt_after_first_response] expected ascending ids but found %d, %d, %d\n",
+                rows[0].id, rows[1].id, rows[2].id);
+        goto done;
+    }
+    if (strcmp(rows[0].role, "system") != 0 || strcmp(rows[0].content, system_prompt) != 0) {
+        fprintf(stderr, "[test_chat_system_prompt_after_first_response] system row mismatch (%s, %s)\n", rows[0].role, rows[0].content);
+        goto done;
+    }
+    if (strcmp(rows[1].role, "user") != 0 || strcmp(rows[1].content, user_question) != 0) {
+        fprintf(stderr, "[test_chat_system_prompt_after_first_response] user row mismatch (%s, %s)\n", rows[1].role, rows[1].content);
+        goto done;
+    }
+    if (strcmp(rows[2].role, "assistant") != 0 || rows[2].content[0] == '\0') {
+        fprintf(stderr, "[test_chat_system_prompt_after_first_response] assistant row mismatch (%s, %s)\n", rows[2].role, rows[2].content);
+        goto done;
+    }
+
+    status = 0;
+
+done:
+    if (chat_created) exec_expect_ok(env, db, "SELECT llm_chat_free();");
+    if (context_created) exec_expect_ok(env, db, "SELECT llm_context_free();");
+    if (model_loaded) exec_expect_ok(env, db, "SELECT llm_model_free();");
+    if (db) sqlite3_close(db);
+    return status;
+}
+
 static const test_case TESTS[] = {
     {"issue15_llm_chat_without_context", test_issue15_chat_without_context},
     {"llm_chat_respond_repeated", test_llm_chat_respond_repeated},
     {"llm_chat_vtab", test_llm_chat_vtab},
+    {"chat_system_prompt_new_chat", test_chat_system_prompt_new_chat},
+    {"chat_system_prompt_replace_previous_prompt", test_chat_system_prompt_replace_previous_prompt},
+    {"chat_system_prompt_after_first_response", test_chat_system_prompt_after_first_response},
 };
 
 int main(int argc, char **argv) {