diff --git a/src/obsidian_rag/store.py b/src/obsidian_rag/store.py index ee7446f..b01e43d 100644 --- a/src/obsidian_rag/store.py +++ b/src/obsidian_rag/store.py @@ -4,6 +4,7 @@ import sqlite3 import struct +import threading from pathlib import Path from typing import Any, Dict, List, Optional, Sequence @@ -30,7 +31,8 @@ def __init__(self, data_path: str, collection_name: str = "obsidian_notes"): self.collection_name = collection_name self.db_path = self.data_path / f"{collection_name}.db" - self.db = sqlite3.connect(str(self.db_path)) + self._lock = threading.Lock() + self.db = sqlite3.connect(str(self.db_path), check_same_thread=False) self.db.enable_load_extension(True) sqlite_vec.load(self.db) self.db.enable_load_extension(False) @@ -84,54 +86,57 @@ def _ensure_vec_table(self, dim: int) -> None: def upsert(self, chunk: Chunk, embedding: List[float]) -> None: """Add or update a chunk.""" - self._ensure_vec_table(len(embedding)) - meta = self._prepare_metadata(chunk) - - self.db.execute(""" - INSERT OR REPLACE INTO chunks (id, file_path, heading, heading_level, type, tags, content) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, (chunk.id, meta["file_path"], meta["heading"], meta["heading_level"], - meta["type"], meta.get("tags", ""), chunk.content)) - - # sqlite-vec: delete then insert (no native upsert on virtual tables) - self.db.execute("DELETE FROM chunks_vec WHERE id = ?", (chunk.id,)) - self.db.execute( - "INSERT INTO chunks_vec (id, embedding) VALUES (?, ?)", - (chunk.id, _serialize_f32(embedding)) - ) - self.db.commit() - - def upsert_batch(self, chunks: List[Chunk], embeddings: Sequence[Sequence[float]]) -> None: - """Add or update multiple chunks.""" - if not chunks: - return - self._ensure_vec_table(len(embeddings[0])) - - for chunk, embedding in zip(chunks, embeddings): + with self._lock: + self._ensure_vec_table(len(embedding)) meta = self._prepare_metadata(chunk) + self.db.execute(""" INSERT OR REPLACE INTO chunks (id, file_path, heading, heading_level, type, tags, content) VALUES (?, ?, ?, ?, ?, ?, ?) """, (chunk.id, meta["file_path"], meta["heading"], meta["heading_level"], meta["type"], meta.get("tags", ""), chunk.content)) + # sqlite-vec: delete then insert (no native upsert on virtual tables) self.db.execute("DELETE FROM chunks_vec WHERE id = ?", (chunk.id,)) self.db.execute( "INSERT INTO chunks_vec (id, embedding) VALUES (?, ?)", (chunk.id, _serialize_f32(embedding)) ) + self.db.commit() - self.db.commit() + def upsert_batch(self, chunks: List[Chunk], embeddings: Sequence[Sequence[float]]) -> None: + """Add or update multiple chunks.""" + if not chunks: + return + with self._lock: + self._ensure_vec_table(len(embeddings[0])) + + for chunk, embedding in zip(chunks, embeddings): + meta = self._prepare_metadata(chunk) + self.db.execute(""" + INSERT OR REPLACE INTO chunks (id, file_path, heading, heading_level, type, tags, content) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, (chunk.id, meta["file_path"], meta["heading"], meta["heading_level"], + meta["type"], meta.get("tags", ""), chunk.content)) + + self.db.execute("DELETE FROM chunks_vec WHERE id = ?", (chunk.id,)) + self.db.execute( + "INSERT INTO chunks_vec (id, embedding) VALUES (?, ?)", + (chunk.id, _serialize_f32(embedding)) + ) + + self.db.commit() def delete_by_file(self, file_path: str) -> None: """Delete all chunks from a specific file.""" - ids = [row[0] for row in - self.db.execute("SELECT id FROM chunks WHERE file_path = ?", (file_path,)).fetchall()] - if ids: - placeholders = ",".join("?" * len(ids)) - self.db.execute(f"DELETE FROM chunks_vec WHERE id IN ({placeholders})", ids) - self.db.execute(f"DELETE FROM chunks WHERE id IN ({placeholders})", ids) - self.db.commit() + with self._lock: + ids = [row[0] for row in + self.db.execute("SELECT id FROM chunks WHERE file_path = ?", (file_path,)).fetchall()] + if ids: + placeholders = ",".join("?" * len(ids)) + self.db.execute(f"DELETE FROM chunks_vec WHERE id IN ({placeholders})", ids) + self.db.execute(f"DELETE FROM chunks WHERE id IN ({placeholders})", ids) + self.db.commit() def search( self, @@ -140,42 +145,64 @@ def search( where: Optional[Dict] = None ) -> List[Dict]: """Search for similar chunks.""" - if self._dim is None: - return [] - - query_bytes = _serialize_f32(query_embedding) - - if where: - conditions = [] - params: list = [] - for key, value in where.items(): - conditions.append(f"c.{key} = ?") - params.append(value) - - where_clause = " AND ".join(conditions) - fetch_limit = limit * 5 - - rows = self.db.execute(f""" - SELECT c.id, c.file_path, c.heading, c.heading_level, c.type, c.tags, c.content, v.distance - FROM chunks_vec v - JOIN chunks c ON c.id = v.id - WHERE v.embedding MATCH ? AND k = ? - AND {where_clause} - ORDER BY v.distance - LIMIT ? - """, [query_bytes, fetch_limit] + params + [limit]).fetchall() - else: + with self._lock: + if self._dim is None: + return [] + + query_bytes = _serialize_f32(query_embedding) + + if where: + conditions = [] + params: list = [] + for key, value in where.items(): + conditions.append(f"c.{key} = ?") + params.append(value) + + where_clause = " AND ".join(conditions) + fetch_limit = limit * 5 + + rows = self.db.execute(f""" + SELECT c.id, c.file_path, c.heading, c.heading_level, c.type, c.tags, c.content, v.distance + FROM chunks_vec v + JOIN chunks c ON c.id = v.id + WHERE v.embedding MATCH ? AND k = ? + AND {where_clause} + ORDER BY v.distance + LIMIT ? + """, [query_bytes, fetch_limit] + params + [limit]).fetchall() + else: + rows = self.db.execute(""" + SELECT c.id, c.file_path, c.heading, c.heading_level, c.type, c.tags, c.content, v.distance + FROM chunks_vec v + JOIN chunks c ON c.id = v.id + WHERE v.embedding MATCH ? AND k = ? + ORDER BY v.distance + """, [query_bytes, limit]).fetchall() + + results = [] + for row in rows: + results.append({ + "id": row[0], + "metadata": { + "file_path": row[1], + "heading": row[2] or "", + "heading_level": row[3], + "type": row[4] or "note", + "tags": row[5] or "", + }, + "content": row[6], + "distance": row[7], + }) + return results + + def get_by_file(self, file_path: str) -> List[Dict]: + """Get all chunks for a file path (direct lookup, no vector search).""" + with self._lock: rows = self.db.execute(""" - SELECT c.id, c.file_path, c.heading, c.heading_level, c.type, c.tags, c.content, v.distance - FROM chunks_vec v - JOIN chunks c ON c.id = v.id - WHERE v.embedding MATCH ? AND k = ? - ORDER BY v.distance - """, [query_bytes, limit]).fetchall() - - results = [] - for row in rows: - results.append({ + SELECT id, file_path, heading, heading_level, type, tags, content + FROM chunks WHERE file_path = ? + """, (file_path,)).fetchall() + return [{ "id": row[0], "metadata": { "file_path": row[1], @@ -185,44 +212,26 @@ def search( "tags": row[5] or "", }, "content": row[6], - "distance": row[7], - }) - return results - - def get_by_file(self, file_path: str) -> List[Dict]: - """Get all chunks for a file path (direct lookup, no vector search).""" - rows = self.db.execute(""" - SELECT id, file_path, heading, heading_level, type, tags, content - FROM chunks WHERE file_path = ? - """, (file_path,)).fetchall() - return [{ - "id": row[0], - "metadata": { - "file_path": row[1], - "heading": row[2] or "", - "heading_level": row[3], - "type": row[4] or "note", - "tags": row[5] or "", - }, - "content": row[6], - } for row in rows] + } for row in rows] def get_stats(self) -> dict: """Get collection statistics.""" - count = self.db.execute("SELECT COUNT(*) FROM chunks").fetchone()[0] - return { - "collection": self.collection_name, - "count": count, - "data_path": str(self.data_path), - } + with self._lock: + count = self.db.execute("SELECT COUNT(*) FROM chunks").fetchone()[0] + return { + "collection": self.collection_name, + "count": count, + "data_path": str(self.data_path), + } def clear(self) -> None: """Clear all data.""" - self.db.execute("DELETE FROM chunks") - if self._dim is not None: - self.db.execute("DROP TABLE IF EXISTS chunks_vec") - self._dim = None - self.db.commit() + with self._lock: + self.db.execute("DELETE FROM chunks") + if self._dim is not None: + self.db.execute("DROP TABLE IF EXISTS chunks_vec") + self._dim = None + self.db.commit() def _prepare_metadata(self, chunk: Chunk) -> Dict[str, Any]: """Prepare metadata for storage.""" diff --git a/tests/test_store.py b/tests/test_store.py index 3eea72a..c6d0aeb 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -1,6 +1,7 @@ """Tests for VectorStore - defines the contract for the sqlite-vec backend.""" import tempfile +import threading from pathlib import Path import pytest @@ -224,3 +225,49 @@ def test_get_stats(self, store): assert stats["count"] == 0 assert "data_path" in stats assert "collection" in stats + + +class TestThreadSafety: + def test_upsert_from_different_thread(self, store): + """VectorStore operations work when called from a non-creator thread.""" + chunk = make_chunk("c1", "Hello from thread", "notes/thread.md") + embedding = [1.0, 0.0, 0.0, 0.0] + error = None + + def worker(): + nonlocal error + try: + store.upsert(chunk, embedding) + except Exception as e: + error = e + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert error is None, f"Cross-thread upsert failed: {error}" + results = store.search([1.0, 0.0, 0.0, 0.0], limit=1) + assert len(results) == 1 + assert results[0]["content"] == "Hello from thread" + + def test_concurrent_upserts_from_multiple_threads(self, store): + """Multiple threads can upsert without corruption.""" + errors = [] + num_threads = 5 + + def worker(i): + try: + chunk = make_chunk(f"c{i}", f"Note {i}", f"{i}.md") + embedding = [float(i == j) for j in range(DIM)] + store.upsert(chunk, embedding) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [], f"Concurrent upserts failed: {errors}" + assert store.get_stats()["count"] == num_threads