Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ dev = [
"duckdb",
"sqlite-vec",
]
# Lightweight test-runner deps for CI (no backend extras — those live in
# `dev`). Installed in CI via [tool.wads.ci.install].extras so the async
# suite (tests/test_async.py) has pytest-asyncio.
test = ["pytest>=7.0", "pytest-cov>=4.0", "pytest-asyncio>=0.23"]
docs = ["sphinx>=6.0", "sphinx-rtd-theme>=1.0"]

[tool.hatch.build.targets.wheel]
Expand Down Expand Up @@ -102,6 +106,12 @@ asyncio_mode = "auto"
[tool.wads.ci]
installer = "uv"

[tool.wads.ci.install]
# Install the `test` extra so CI has the async test runner (pytest-asyncio);
# without it tests/test_async.py errors with "async def functions are not
# natively supported" and every async test fails.
extras = "test"

[tool.wads.ci.testing]
python_versions = ["3.10", "3.12"]
pytest_args = ["-v", "--tb=short"]
Expand Down
44 changes: 44 additions & 0 deletions tests/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,50 @@ def test_bm25_filter_is_applied():
assert [h["id"] for h in hits] == ["a"]


# ---------- BM25Index: build-once / query-many ----------------------------- #


def _toy_collection():
client = vd.connect("memory")
col = client.create_collection("bm25idx", dimension=2)
col["a"] = vd.Document(id="a", text="the quick brown fox", vector=[1.0, 0.0])
col["b"] = vd.Document(id="b", text="lazy dog sleeps", vector=[0.0, 1.0])
col["c"] = vd.Document(id="c", text="quick fox runs", vector=[0.5, 0.5])
return col


def test_bm25_index_matches_one_shot_function():
"""A prebuilt BM25Index gives identical results to bm25_lexical_search."""
col = _toy_collection()
index = vd.BM25Index(col)
for query in ("quick fox", "lazy dog", "runs", "nonexistentterm"):
one_shot = vd.bm25_lexical_search(col, query, limit=10)
from_index = index.search(query, limit=10)
assert [(r["id"], round(r["score"], 9)) for r in from_index] == [
(r["id"], round(r["score"], 9)) for r in one_shot
], f"mismatch for {query!r}"


def test_bm25_index_reusable_across_queries():
"""The same index answers many queries (build-once / query-many)."""
index = vd.BM25Index(_toy_collection())
assert index.search("lazy dog", limit=1)[0]["id"] == "b"
assert index.search("quick fox", limit=2)[0]["id"] in ("a", "c")
assert len(index) == 3


def test_bm25_index_empty_query_and_filter():
col = _toy_collection()
assert vd.BM25Index(col).search("", limit=5) == []
# filter applied once at build time → only matching docs are indexed
col["d"] = vd.Document(
id="d", text="quick fox flies", vector=[0.2, 0.8], metadata={"kind": "z"}
)
index = vd.BM25Index(col, filter={"kind": "z"})
assert len(index) == 1
assert [r["id"] for r in index.search("quick fox", limit=10)] == ["d"]


# ---------- Hybrid contract: parametrized over every reachable backend ----- #


Expand Down
2 changes: 2 additions & 0 deletions vd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def skills_dir() -> _Path:

# ----- advanced search ----------------------------------------------------- #
from vd.search import ( # noqa: E402
BM25Index,
bm25_lexical_search,
deduplicate_results,
hybrid_search,
Expand Down Expand Up @@ -307,6 +308,7 @@ def skills_dir() -> _Path:
"deduplicate_results",
"hybrid_search",
"bm25_lexical_search",
"BM25Index",
# configuration
"load_config",
"save_config",
Expand Down
201 changes: 139 additions & 62 deletions vd/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,132 @@ def _tokenize(text: str) -> list[str]:
return _TOKEN_RE.findall(text.lower())


class BM25Index:
"""A reusable Okapi BM25 index over a vd collection's stored ``text``.

Builds the **query-independent** term statistics — per-document token
lists, document frequencies, document lengths, and the mean length — once
in :meth:`__init__`, then answers many queries against them via
:meth:`search`. This is the build-once / query-many companion to
:func:`bm25_lexical_search` (which builds a throwaway index for a single
query): for batch evaluation or any repeated querying of the same
collection it turns an O(N · Q) workload (re-tokenizing every document on
every query) into O(N + scoring · Q).

Construction is **O(N)** in the collection size; each :meth:`search` is
O(matching documents). Fine for prototypes and collections up to ~100k
documents; for larger workloads switch to a backend with a native text
index (weaviate, elasticsearch, redis, …).

Parameters
----------
collection : Collection
Any vd Collection (or mapping-like ``id -> obj`` exposing ``.text`` and
``.metadata``). Documents whose ``text`` is empty contribute zero score
and are dropped at build time.
filter : dict, optional
Canonical ``vd`` metadata filter, applied **once** at build time (via
:func:`vd.filters.matches_filter`) so the index covers only the
matching documents and its statistics reflect that subset.
tokenize : Callable[[str], list[str]], optional
Tokenizer (default: lowercased ``\\w+`` tokens). Pass a custom one for
stemming, CJK, etc.

Examples
--------
>>> import vd
>>> c = vd.connect('memory').create_collection('t', dimension=2)
>>> c['a'] = vd.Document(id='a', text='the quick brown fox', vector=[1.0, 0.0])
>>> c['b'] = vd.Document(id='b', text='lazy dog sleeps', vector=[0.0, 1.0])
>>> index = vd.BM25Index(c)
>>> index.search('quick fox', limit=1)[0]['id']
'a'
"""

def __init__(
self,
collection: Collection,
*,
filter: Optional[dict] = None,
tokenize: Callable[[str], list[str]] = _tokenize,
):
from vd.filters import matches_filter

self._tokenize = tokenize
# Pass 1 (query-independent): tokenize once, compute document
# frequencies and lengths.
docs: list[tuple[str, str, dict, list[str]]] = []
df: dict[str, int] = {}
for doc_id in collection:
doc = collection[doc_id]
if filter is not None and not matches_filter(doc.metadata or {}, filter):
continue
tokens = tokenize(doc.text or "")
if not tokens:
continue
docs.append((doc_id, doc.text, dict(doc.metadata or {}), tokens))
for term in set(tokens):
df[term] = df.get(term, 0) + 1
self._docs = docs
self._df = df
self._n_docs = len(docs)
self._avg_len = (
sum(len(toks) for _, _, _, toks in docs) / self._n_docs if docs else 0.0
)

def __len__(self) -> int:
"""Number of non-empty, filter-surviving documents in the index."""
return self._n_docs

def search(
self,
query_text: str,
*,
limit: int = 10,
k1: float = 1.5,
b: float = 0.75,
) -> list[SearchResult]:
"""Okapi BM25 scores for ``query_text`` over the indexed documents.

Returns result dicts in the same shape as :meth:`Collection.search` —
``{"id", "text", "score", "metadata"}`` — sorted by descending score.
``k1`` / ``b`` are the standard Okapi hyperparameters (scoring-time, so
one index can be queried with different settings).
"""
query_tokens = self._tokenize(query_text)
if not query_tokens or not self._docs:
return []

# Pass 2 (query-dependent): idf for the query terms, then Okapi scoring.
n_docs, avg_len, df = self._n_docs, self._avg_len, self._df
idf = {
term: math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
for term in set(query_tokens)
if term in df
}

scored: list[SearchResult] = []
for doc_id, text, metadata, tokens in self._docs:
doc_len = len(tokens)
tf: dict[str, int] = {}
for tok in tokens:
if tok in idf:
tf[tok] = tf.get(tok, 0) + 1
if not tf:
continue
score = 0.0
for term, freq in tf.items():
numer = freq * (k1 + 1)
denom = freq + k1 * (1 - b + b * doc_len / avg_len)
score += idf[term] * numer / denom
scored.append(
{"id": doc_id, "text": text, "score": score, "metadata": metadata}
)

scored.sort(key=lambda r: r["score"], reverse=True)
return scored[:limit]


def bm25_lexical_search(
collection: Collection,
query_text: str,
Expand All @@ -380,16 +506,18 @@ def bm25_lexical_search(
"""
Brute-force BM25 lexical search over a vd collection's stored ``text``.

Iterates every document in ``collection``, tokenizes its ``text`` field,
and computes Okapi BM25 scores against ``query_text``. Used as the default
lexical side of :func:`hybrid_search` when a collection does not implement
:class:`SupportsHybrid`.
Builds a throwaway :class:`BM25Index` over ``collection`` and runs a single
query against it. Used as the default lexical side of :func:`hybrid_search`
when a collection does not implement :class:`SupportsHybrid`.

Cost is **O(N)** in the collection size — fine for prototypes and
collections up to ~100k documents. For larger workloads, either switch to
a backend with native hybrid search (weaviate, elasticsearch, redis, …)
or pass a custom ``lexical_search`` callable to :func:`hybrid_search` that
consults a real text index.
collections up to ~100k documents. **For repeated queries over the same
collection, build a** :class:`BM25Index` **once and call**
:meth:`BM25Index.search` **per query** instead of calling this function in a
loop — the term statistics are then computed once rather than on every call.
For larger workloads, switch to a backend with native hybrid search
(weaviate, elasticsearch, redis, …) or pass a custom ``lexical_search``
callable to :func:`hybrid_search` that consults a real text index.

Parameters
----------
Expand Down Expand Up @@ -422,60 +550,9 @@ def bm25_lexical_search(
>>> hits[0]['id']
'a'
"""
from vd.filters import matches_filter

query_tokens = _tokenize(query_text)
if not query_tokens:
return []

# Pass 1: tokenize once, compute document frequencies and lengths.
docs: list[tuple[str, str, dict, list[str]]] = []
df: dict[str, int] = {}
for doc_id in collection:
doc = collection[doc_id]
if filter is not None and not matches_filter(doc.metadata or {}, filter):
continue
tokens = _tokenize(doc.text or "")
if not tokens:
continue
docs.append((doc_id, doc.text, dict(doc.metadata or {}), tokens))
for term in set(tokens):
df[term] = df.get(term, 0) + 1

if not docs:
return []

n_docs = len(docs)
avg_len = sum(len(toks) for _, _, _, toks in docs) / n_docs

# Pass 2: BM25 scoring (Okapi).
query_terms = set(query_tokens)
idf = {
term: math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
for term in query_terms
if term in df
}

scored: list[SearchResult] = []
for doc_id, text, metadata, tokens in docs:
score = 0.0
doc_len = len(tokens)
tf: dict[str, int] = {}
for tok in tokens:
if tok in idf:
tf[tok] = tf.get(tok, 0) + 1
if not tf:
continue
for term, freq in tf.items():
numer = freq * (k1 + 1)
denom = freq + k1 * (1 - b + b * doc_len / avg_len)
score += idf[term] * numer / denom
scored.append(
{"id": doc_id, "text": text, "score": score, "metadata": metadata}
)

scored.sort(key=lambda r: r["score"], reverse=True)
return scored[:limit]
return BM25Index(collection, filter=filter).search(
query_text, limit=limit, k1=k1, b=b
)


def _rrf_fuse(
Expand Down
Loading