Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 109 additions & 100 deletions src/obsidian_rag/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import sqlite3
import struct
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence

Expand All @@ -30,7 +31,8 @@ def __init__(self, data_path: str, collection_name: str = "obsidian_notes"):
self.collection_name = collection_name
self.db_path = self.data_path / f"{collection_name}.db"

self.db = sqlite3.connect(str(self.db_path))
self._lock = threading.Lock()
self.db = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.db.enable_load_extension(True)
sqlite_vec.load(self.db)
self.db.enable_load_extension(False)
Expand Down Expand Up @@ -84,54 +86,57 @@ def _ensure_vec_table(self, dim: int) -> None:

def upsert(self, chunk: Chunk, embedding: List[float]) -> None:
"""Add or update a chunk."""
self._ensure_vec_table(len(embedding))
meta = self._prepare_metadata(chunk)

self.db.execute("""
INSERT OR REPLACE INTO chunks (id, file_path, heading, heading_level, type, tags, content)
VALUES (?, ?, ?, ?, ?, ?, ?)
""", (chunk.id, meta["file_path"], meta["heading"], meta["heading_level"],
meta["type"], meta.get("tags", ""), chunk.content))

# sqlite-vec: delete then insert (no native upsert on virtual tables)
self.db.execute("DELETE FROM chunks_vec WHERE id = ?", (chunk.id,))
self.db.execute(
"INSERT INTO chunks_vec (id, embedding) VALUES (?, ?)",
(chunk.id, _serialize_f32(embedding))
)
self.db.commit()

def upsert_batch(self, chunks: List[Chunk], embeddings: Sequence[Sequence[float]]) -> None:
"""Add or update multiple chunks."""
if not chunks:
return
self._ensure_vec_table(len(embeddings[0]))

for chunk, embedding in zip(chunks, embeddings):
with self._lock:
self._ensure_vec_table(len(embedding))
meta = self._prepare_metadata(chunk)

self.db.execute("""
INSERT OR REPLACE INTO chunks (id, file_path, heading, heading_level, type, tags, content)
VALUES (?, ?, ?, ?, ?, ?, ?)
""", (chunk.id, meta["file_path"], meta["heading"], meta["heading_level"],
meta["type"], meta.get("tags", ""), chunk.content))

# sqlite-vec: delete then insert (no native upsert on virtual tables)
self.db.execute("DELETE FROM chunks_vec WHERE id = ?", (chunk.id,))
self.db.execute(
"INSERT INTO chunks_vec (id, embedding) VALUES (?, ?)",
(chunk.id, _serialize_f32(embedding))
)
self.db.commit()

self.db.commit()
def upsert_batch(self, chunks: List[Chunk], embeddings: Sequence[Sequence[float]]) -> None:
    """Add or update multiple chunks as one all-or-nothing write.

    Args:
        chunks: Chunks to store; an empty list is a no-op.
        embeddings: One embedding per chunk, in the same order as ``chunks``.

    Raises:
        ValueError: If ``chunks`` and ``embeddings`` differ in length.
    """
    if not chunks:
        return
    # zip() would silently drop the unmatched tail, so reject a mismatch up front.
    if len(chunks) != len(embeddings):
        raise ValueError(
            f"upsert_batch got {len(chunks)} chunks but {len(embeddings)} embeddings"
        )

    with self._lock:
        # The first embedding's length is passed as the vector dimension.
        self._ensure_vec_table(len(embeddings[0]))

        try:
            for chunk, embedding in zip(chunks, embeddings):
                meta = self._prepare_metadata(chunk)
                self.db.execute("""
                    INSERT OR REPLACE INTO chunks (id, file_path, heading, heading_level, type, tags, content)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (chunk.id, meta["file_path"], meta["heading"], meta["heading_level"],
                      meta["type"], meta.get("tags", ""), chunk.content))

                # sqlite-vec: delete then insert (no native upsert on virtual tables)
                self.db.execute("DELETE FROM chunks_vec WHERE id = ?", (chunk.id,))
                self.db.execute(
                    "INSERT INTO chunks_vec (id, embedding) VALUES (?, ?)",
                    (chunk.id, _serialize_f32(embedding))
                )

            self.db.commit()
        except Exception:
            # Without a rollback the rows written so far stay in the open
            # transaction and would be persisted by the next commit() call.
            self.db.rollback()
            raise

def delete_by_file(self, file_path: str) -> None:
"""Delete all chunks from a specific file."""
ids = [row[0] for row in
self.db.execute("SELECT id FROM chunks WHERE file_path = ?", (file_path,)).fetchall()]
if ids:
placeholders = ",".join("?" * len(ids))
self.db.execute(f"DELETE FROM chunks_vec WHERE id IN ({placeholders})", ids)
self.db.execute(f"DELETE FROM chunks WHERE id IN ({placeholders})", ids)
self.db.commit()
with self._lock:
ids = [row[0] for row in
self.db.execute("SELECT id FROM chunks WHERE file_path = ?", (file_path,)).fetchall()]
if ids:
placeholders = ",".join("?" * len(ids))
self.db.execute(f"DELETE FROM chunks_vec WHERE id IN ({placeholders})", ids)
self.db.execute(f"DELETE FROM chunks WHERE id IN ({placeholders})", ids)
self.db.commit()

def search(
self,
Expand All @@ -140,42 +145,64 @@ def search(
where: Optional[Dict] = None
) -> List[Dict]:
"""Search for similar chunks."""
if self._dim is None:
return []

query_bytes = _serialize_f32(query_embedding)

if where:
conditions = []
params: list = []
for key, value in where.items():
conditions.append(f"c.{key} = ?")
params.append(value)

where_clause = " AND ".join(conditions)
fetch_limit = limit * 5

rows = self.db.execute(f"""
SELECT c.id, c.file_path, c.heading, c.heading_level, c.type, c.tags, c.content, v.distance
FROM chunks_vec v
JOIN chunks c ON c.id = v.id
WHERE v.embedding MATCH ? AND k = ?
AND {where_clause}
ORDER BY v.distance
LIMIT ?
""", [query_bytes, fetch_limit] + params + [limit]).fetchall()
else:
with self._lock:
if self._dim is None:
return []

query_bytes = _serialize_f32(query_embedding)

if where:
conditions = []
params: list = []
for key, value in where.items():
conditions.append(f"c.{key} = ?")
params.append(value)

where_clause = " AND ".join(conditions)
fetch_limit = limit * 5

rows = self.db.execute(f"""
SELECT c.id, c.file_path, c.heading, c.heading_level, c.type, c.tags, c.content, v.distance
FROM chunks_vec v
JOIN chunks c ON c.id = v.id
WHERE v.embedding MATCH ? AND k = ?
AND {where_clause}
ORDER BY v.distance
LIMIT ?
""", [query_bytes, fetch_limit] + params + [limit]).fetchall()
else:
rows = self.db.execute("""
SELECT c.id, c.file_path, c.heading, c.heading_level, c.type, c.tags, c.content, v.distance
FROM chunks_vec v
JOIN chunks c ON c.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance
""", [query_bytes, limit]).fetchall()

results = []
for row in rows:
results.append({
"id": row[0],
"metadata": {
"file_path": row[1],
"heading": row[2] or "",
"heading_level": row[3],
"type": row[4] or "note",
"tags": row[5] or "",
},
"content": row[6],
"distance": row[7],
})
return results

def get_by_file(self, file_path: str) -> List[Dict]:
"""Get all chunks for a file path (direct lookup, no vector search)."""
with self._lock:
rows = self.db.execute("""
SELECT c.id, c.file_path, c.heading, c.heading_level, c.type, c.tags, c.content, v.distance
FROM chunks_vec v
JOIN chunks c ON c.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY v.distance
""", [query_bytes, limit]).fetchall()

results = []
for row in rows:
results.append({
SELECT id, file_path, heading, heading_level, type, tags, content
FROM chunks WHERE file_path = ?
""", (file_path,)).fetchall()
return [{
"id": row[0],
"metadata": {
"file_path": row[1],
Expand All @@ -185,44 +212,26 @@ def search(
"tags": row[5] or "",
},
"content": row[6],
"distance": row[7],
})
return results

def get_by_file(self, file_path: str) -> List[Dict]:
"""Get all chunks for a file path (direct lookup, no vector search)."""
rows = self.db.execute("""
SELECT id, file_path, heading, heading_level, type, tags, content
FROM chunks WHERE file_path = ?
""", (file_path,)).fetchall()
return [{
"id": row[0],
"metadata": {
"file_path": row[1],
"heading": row[2] or "",
"heading_level": row[3],
"type": row[4] or "note",
"tags": row[5] or "",
},
"content": row[6],
} for row in rows]
} for row in rows]

def get_stats(self) -> dict:
"""Get collection statistics."""
count = self.db.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
return {
"collection": self.collection_name,
"count": count,
"data_path": str(self.data_path),
}
with self._lock:
count = self.db.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
return {
"collection": self.collection_name,
"count": count,
"data_path": str(self.data_path),
}

def clear(self) -> None:
"""Clear all data."""
self.db.execute("DELETE FROM chunks")
if self._dim is not None:
self.db.execute("DROP TABLE IF EXISTS chunks_vec")
self._dim = None
self.db.commit()
with self._lock:
self.db.execute("DELETE FROM chunks")
if self._dim is not None:
self.db.execute("DROP TABLE IF EXISTS chunks_vec")
self._dim = None
self.db.commit()

def _prepare_metadata(self, chunk: Chunk) -> Dict[str, Any]:
"""Prepare metadata for storage."""
Expand Down
47 changes: 47 additions & 0 deletions tests/test_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for VectorStore - defines the contract for the sqlite-vec backend."""

import tempfile
import threading
from pathlib import Path

import pytest
Expand Down Expand Up @@ -224,3 +225,49 @@ def test_get_stats(self, store):
assert stats["count"] == 0
assert "data_path" in stats
assert "collection" in stats


class TestThreadSafety:
    """Exercise the store from threads other than the one running the test."""

    def test_upsert_from_different_thread(self, store):
        """An upsert issued on a worker thread succeeds and is searchable afterwards."""
        chunk = make_chunk("c1", "Hello from thread", "notes/thread.md")
        embedding = [1.0, 0.0, 0.0, 0.0]
        error = None

        def worker():
            # Capture any failure so it can be asserted on after the join.
            nonlocal error
            try:
                store.upsert(chunk, embedding)
            except Exception as exc:
                error = exc

        background = threading.Thread(target=worker)
        background.start()
        background.join()

        assert error is None, f"Cross-thread upsert failed: {error}"
        hits = store.search([1.0, 0.0, 0.0, 0.0], limit=1)
        assert len(hits) == 1
        assert hits[0]["content"] == "Hello from thread"

    def test_concurrent_upserts_from_multiple_threads(self, store):
        """Several threads upserting distinct chunks all succeed and every chunk is counted."""
        errors = []
        num_threads = 5

        def worker(i):
            try:
                store.upsert(
                    make_chunk(f"c{i}", f"Note {i}", f"{i}.md"),
                    [float(i == axis) for axis in range(DIM)],
                )
            except Exception as exc:
                errors.append(exc)

        pool = []
        for index in range(num_threads):
            pool.append(threading.Thread(target=worker, args=(index,)))

        # Start every thread before joining any, so the upserts can overlap.
        for thread in pool:
            thread.start()
        for thread in pool:
            thread.join()

        assert errors == [], f"Concurrent upserts failed: {errors}"
        assert store.get_stats()["count"] == num_threads