diff --git a/src/sqlite-ai.c b/src/sqlite-ai.c index f64f6a2..839244c 100644 --- a/src/sqlite-ai.c +++ b/src/sqlite-ai.c @@ -1957,20 +1957,44 @@ static void llm_chat_save (sqlite3_context *context, int argc, sqlite3_value **a // start transaction sqlite_db_write_simple(context, db, "BEGIN;"); - // save chat - const char *sql = "INSERT INTO ai_chat_history (uuid, title, metadata) VALUES (?, ?, ?);"; + // save chat, the ON CONFLICT allows saving multiple times + const char *sql = "INSERT INTO ai_chat_history (uuid, title, metadata) VALUES (?, ?, ?) " + "ON CONFLICT(uuid) DO UPDATE SET " + " title = excluded.title, " + " metadata = excluded.metadata, " + " created_at = CURRENT_TIMESTAMP;"; const char *values[] = {ai->chat.uuid, title, meta}; int types[] = {SQLITE_TEXT, SQLITE_TEXT, SQLITE_TEXT}; int lens[] = {-1, -1, -1}; int rc = sqlite_db_write(context, db, sql, values, types, lens, 3); if (rc != SQLITE_OK) goto abort_save; - - // loop to save messages (the context) + + // get the rowid, cannot use sqlite3_last_insert_rowid for the CONFLICT case char rowid_s[256]; - sqlite3_int64 rowid = sqlite3_last_insert_rowid(db); + sqlite3_stmt *pstmt = NULL; + sql = "SELECT id FROM ai_chat_history WHERE uuid = ?;"; + rc = sqlite3_prepare_v2(db, sql, -1, &pstmt, NULL); + if (rc != SQLITE_OK) goto abort_save; + rc = sqlite3_bind_text(pstmt, 1, ai->chat.uuid, -1, SQLITE_STATIC); + rc = sqlite3_step(pstmt); + if (rc != SQLITE_ROW) { + sqlite3_finalize(pstmt); + goto abort_save; + } + sqlite3_int64 rowid = sqlite3_column_int64(pstmt, 0); + sqlite3_finalize(pstmt); snprintf(rowid_s, sizeof(rowid_s), "%lld", (long long)rowid); + + // delete all messages for this chat id, if any + sql = "DELETE FROM ai_chat_messages WHERE chat_id = ?;"; + const char *values3[] = {rowid_s}; + int types3[] = {SQLITE_INTEGER}; + int lens3[] = {-1}; + rc = sqlite_db_write(context, db, sql, values3, types3, lens3, 1); + if (rc != SQLITE_OK) goto abort_save; + // loop to save messages (the context) sql = "INSERT INTO ai_chat_messages (chat_id, role, content) VALUES (?, ?, ?);"; int types2[] = {SQLITE_INTEGER, SQLITE_TEXT, SQLITE_TEXT}; diff --git a/src/sqlite-ai.h b/src/sqlite-ai.h index 07ae4d3..e14851e 100644 --- a/src/sqlite-ai.h +++ b/src/sqlite-ai.h @@ -24,7 +24,7 @@ extern "C" { #endif -#define SQLITE_AI_VERSION "0.7.58" +#define SQLITE_AI_VERSION "0.7.59" SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi); diff --git a/tests/c/unittest.c b/tests/c/unittest.c index 7574d25..e0ebffe 100644 --- a/tests/c/unittest.c +++ b/tests/c/unittest.c @@ -1008,6 +1008,120 @@ static int test_chat_system_prompt_after_first_response(const test_env *env) { return status; } +static int test_llm_chat_double_save(const test_env *env) { + sqlite3 *db = NULL; + bool model_loaded = false; + bool context_created = false; + bool chat_created = false; + int status = 1; + + if (open_db_and_load(env, &db) != SQLITE_OK) { + goto done; + } + + const char *model = env->model_path ? env->model_path : DEFAULT_MODEL_PATH; + char sqlbuf[512]; + snprintf(sqlbuf, sizeof(sqlbuf), "SELECT llm_model_load('%s');", model); + if (exec_expect_ok(env, db, sqlbuf) != 0) + goto done; + model_loaded = true; + + if (exec_expect_ok(env, db, + "SELECT llm_context_create('context_size=1000');") != 0) + goto done; + context_created = true; + + if (exec_expect_ok(env, db, "SELECT llm_chat_create();") != 0) + goto done; + chat_created = true; + + // First prompt + const char *prompt1 = "First prompt"; + if (exec_expect_ok(env, db, "SELECT llm_chat_respond('First prompt');") != 0) + goto done; + + // First save + if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0) + goto done; + + // Second prompt + const char *prompt2 = "Second prompt"; + if (exec_expect_ok(env, db, "SELECT llm_chat_respond('Second prompt');") != 0) + goto done; + + // Second save + if (exec_expect_ok(env, db, "SELECT llm_chat_save();") != 0) + goto done; + + ai_chat_message_row rows[8]; + int count = 0; + // We expect 4 messages: User1, Assistant1, User2, Assistant2 + if (fetch_ai_chat_messages(env, db, rows, 8, &count) != 0) + goto done; + + if (count != 5) { + fprintf(stderr, + "[test_llm_chat_double_save] expected 4 message rows, got %d\n", + count); + goto done; + } + + // Verify order and roles + if (strcmp(rows[0].role, "system") != 0 || + strcmp(rows[0].content, "") != 0) { + fprintf(stderr, + "[test_llm_chat_double_save] row 0 mismatch (expected system/'%s', " + "got %s/'%s')\n", + "", rows[0].role, rows[0].content); + goto done; + } + if (strcmp(rows[1].role, "user") != 0 || + strcmp(rows[1].content, prompt1) != 0) { + fprintf(stderr, + "[test_llm_chat_double_save] row 0 mismatch (expected user/'%s', " + "got %s/'%s')\n", + prompt1, rows[1].role, rows[1].content); + goto done; + } + if (strcmp(rows[2].role, "assistant") != 0) { + fprintf(stderr, + "[test_llm_chat_double_save] row 1 mismatch (expected assistant, " + "got %s)\n", + rows[2].role); + goto done; + } + if (strcmp(rows[3].role, "user") != 0 || + strcmp(rows[3].content, prompt2) != 0) { + fprintf(stderr, + "[test_llm_chat_double_save] row 2 mismatch (expected user/'%s', " + "got %s/'%s')\n", + prompt2, rows[3].role, rows[3].content); + goto done; + } + if (strcmp(rows[4].role, "assistant") != 0) { + fprintf(stderr, + "[test_llm_chat_double_save] row 3 mismatch (expected assistant, " + "got %s)\n", + rows[4].role); + goto done; + } + + status = 0; + +done: + if (chat_created) + exec_expect_ok(env, db, "SELECT llm_chat_free();"); + if (context_created) + exec_expect_ok(env, db, "SELECT llm_context_free();"); + if (model_loaded) + exec_expect_ok(env, db, "SELECT llm_model_free();"); + if (db) + sqlite3_close(db); + if (status == 0) + status = assert_sqlite_memory_clean("llm_chat_double_save", env); + return status; +} + static const test_case TESTS[] = { {"issue15_llm_chat_without_context", test_issue15_chat_without_context}, {"llm_chat_respond_repeated", test_llm_chat_respond_repeated}, @@ -1026,6 +1140,7 @@ static const test_case TESTS[] = { {"chat_system_prompt_new_chat", test_chat_system_prompt_new_chat}, {"chat_system_prompt_replace_previous_prompt", test_chat_system_prompt_replace_previous_prompt}, {"chat_system_prompt_after_first_response", test_chat_system_prompt_after_first_response}, + {"llm_chat_double_save", test_llm_chat_double_save}, }; int main(int argc, char **argv) {