diff --git a/pyproject.toml b/pyproject.toml index bbefc36..037d2c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,10 @@ dev = [ "duckdb", "sqlite-vec", ] +# Lightweight test-runner deps for CI (no backend extras — those live in +# `dev`). Installed in CI via [tool.wads.ci.install].extras so the async +# suite (tests/test_async.py) has pytest-asyncio. +test = ["pytest>=7.0", "pytest-cov>=4.0", "pytest-asyncio>=0.23"] docs = ["sphinx>=6.0", "sphinx-rtd-theme>=1.0"] [tool.hatch.build.targets.wheel] @@ -102,6 +106,12 @@ asyncio_mode = "auto" [tool.wads.ci] installer = "uv" +[tool.wads.ci.install] +# Install the `test` extra so CI has the async test runner (pytest-asyncio); +# without it tests/test_async.py errors with "async def functions are not +# natively supported" and every async test fails. +extras = "test" + [tool.wads.ci.testing] python_versions = ["3.10", "3.12"] pytest_args = ["-v", "--tb=short"] diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 23c7b66..5d85866 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -75,6 +75,50 @@ def test_bm25_filter_is_applied(): assert [h["id"] for h in hits] == ["a"] +# ---------- BM25Index: build-once / query-many ----------------------------- # + + +def _toy_collection(): + client = vd.connect("memory") + col = client.create_collection("bm25idx", dimension=2) + col["a"] = vd.Document(id="a", text="the quick brown fox", vector=[1.0, 0.0]) + col["b"] = vd.Document(id="b", text="lazy dog sleeps", vector=[0.0, 1.0]) + col["c"] = vd.Document(id="c", text="quick fox runs", vector=[0.5, 0.5]) + return col + + +def test_bm25_index_matches_one_shot_function(): + """A prebuilt BM25Index gives identical results to bm25_lexical_search.""" + col = _toy_collection() + index = vd.BM25Index(col) + for query in ("quick fox", "lazy dog", "runs", "nonexistentterm"): + one_shot = vd.bm25_lexical_search(col, query, limit=10) + from_index = index.search(query, limit=10) + assert [(r["id"], round(r["score"], 9)) for r in from_index] == [ + (r["id"], round(r["score"], 9)) for r in one_shot + ], f"mismatch for {query!r}" + + +def test_bm25_index_reusable_across_queries(): + """The same index answers many queries (build-once / query-many).""" + index = vd.BM25Index(_toy_collection()) + assert index.search("lazy dog", limit=1)[0]["id"] == "b" + assert index.search("quick fox", limit=2)[0]["id"] in ("a", "c") + assert len(index) == 3 + + +def test_bm25_index_empty_query_and_filter(): + col = _toy_collection() + assert vd.BM25Index(col).search("", limit=5) == [] + # filter applied once at build time → only matching docs are indexed + col["d"] = vd.Document( + id="d", text="quick fox flies", vector=[0.2, 0.8], metadata={"kind": "z"} + ) + index = vd.BM25Index(col, filter={"kind": "z"}) + assert len(index) == 1 + assert [r["id"] for r in index.search("quick fox", limit=10)] == ["d"] + + # ---------- Hybrid contract: parametrized over every reachable backend ----- # diff --git a/vd/__init__.py b/vd/__init__.py index 8ff185e..f96e388 100644 --- a/vd/__init__.py +++ b/vd/__init__.py @@ -173,6 +173,7 @@ def skills_dir() -> _Path: # ----- advanced search ----------------------------------------------------- # from vd.search import ( # noqa: E402 + BM25Index, bm25_lexical_search, deduplicate_results, hybrid_search, @@ -307,6 +308,7 @@ def skills_dir() -> _Path: "deduplicate_results", "hybrid_search", "bm25_lexical_search", + "BM25Index", # configuration "load_config", "save_config", diff --git a/vd/search.py b/vd/search.py index b88bb88..f60ded7 100644 --- a/vd/search.py +++ b/vd/search.py @@ -368,6 +368,132 @@ def _tokenize(text: str) -> list[str]: return _TOKEN_RE.findall(text.lower()) +class BM25Index: + """A reusable Okapi BM25 index over a vd collection's stored ``text``. + + Builds the **query-independent** term statistics — per-document token + lists, document frequencies, document lengths, and the mean length — once + in :meth:`__init__`, then answers many queries against them via + :meth:`search`. This is the build-once / query-many companion to + :func:`bm25_lexical_search` (which builds a throwaway index for a single + query): for batch evaluation or any repeated querying of the same + collection it turns an O(N · Q) workload (re-tokenizing every document on + every query) into O(N + scoring · Q). + + Construction is **O(N)** in the collection size; each :meth:`search` is + O(matching documents). Fine for prototypes and collections up to ~100k + documents; for larger workloads switch to a backend with a native text + index (weaviate, elasticsearch, redis, …). + + Parameters + ---------- + collection : Collection + Any vd Collection (or mapping-like ``id -> obj`` exposing ``.text`` and + ``.metadata``). Documents whose ``text`` is empty contribute zero score + and are dropped at build time. + filter : dict, optional + Canonical ``vd`` metadata filter, applied **once** at build time (via + :func:`vd.filters.matches_filter`) so the index covers only the + matching documents and its statistics reflect that subset. + tokenize : Callable[[str], list[str]], optional + Tokenizer (default: lowercased ``\\w+`` tokens). Pass a custom one for + stemming, CJK, etc. + + Examples + -------- + >>> import vd + >>> c = vd.connect('memory').create_collection('t', dimension=2) + >>> c['a'] = vd.Document(id='a', text='the quick brown fox', vector=[1.0, 0.0]) + >>> c['b'] = vd.Document(id='b', text='lazy dog sleeps', vector=[0.0, 1.0]) + >>> index = vd.BM25Index(c) + >>> index.search('quick fox', limit=1)[0]['id'] + 'a' + """ + + def __init__( + self, + collection: Collection, + *, + filter: Optional[dict] = None, + tokenize: Callable[[str], list[str]] = _tokenize, + ): + from vd.filters import matches_filter + + self._tokenize = tokenize + # Pass 1 (query-independent): tokenize once, compute document + # frequencies and lengths. + docs: list[tuple[str, str, dict, list[str]]] = [] + df: dict[str, int] = {} + for doc_id in collection: + doc = collection[doc_id] + if filter is not None and not matches_filter(doc.metadata or {}, filter): + continue + tokens = tokenize(doc.text or "") + if not tokens: + continue + docs.append((doc_id, doc.text, dict(doc.metadata or {}), tokens)) + for term in set(tokens): + df[term] = df.get(term, 0) + 1 + self._docs = docs + self._df = df + self._n_docs = len(docs) + self._avg_len = ( + sum(len(toks) for _, _, _, toks in docs) / self._n_docs if docs else 0.0 + ) + + def __len__(self) -> int: + """Number of non-empty, filter-surviving documents in the index.""" + return self._n_docs + + def search( + self, + query_text: str, + *, + limit: int = 10, + k1: float = 1.5, + b: float = 0.75, + ) -> list[SearchResult]: + """Okapi BM25 scores for ``query_text`` over the indexed documents. + + Returns result dicts in the same shape as :meth:`Collection.search` — + ``{"id", "text", "score", "metadata"}`` — sorted by descending score. + ``k1`` / ``b`` are the standard Okapi hyperparameters (scoring-time, so + one index can be queried with different settings). + """ + query_tokens = self._tokenize(query_text) + if not query_tokens or not self._docs: + return [] + + # Pass 2 (query-dependent): idf for the query terms, then Okapi scoring. + n_docs, avg_len, df = self._n_docs, self._avg_len, self._df + idf = { + term: math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5)) + for term in set(query_tokens) + if term in df + } + + scored: list[SearchResult] = [] + for doc_id, text, metadata, tokens in self._docs: + doc_len = len(tokens) + tf: dict[str, int] = {} + for tok in tokens: + if tok in idf: + tf[tok] = tf.get(tok, 0) + 1 + if not tf: + continue + score = 0.0 + for term, freq in tf.items(): + numer = freq * (k1 + 1) + denom = freq + k1 * (1 - b + b * doc_len / avg_len) + score += idf[term] * numer / denom + scored.append( + {"id": doc_id, "text": text, "score": score, "metadata": metadata} + ) + + scored.sort(key=lambda r: r["score"], reverse=True) + return scored[:limit] + + def bm25_lexical_search( collection: Collection, query_text: str, @@ -380,16 +506,18 @@ def bm25_lexical_search( """ Brute-force BM25 lexical search over a vd collection's stored ``text``. - Iterates every document in ``collection``, tokenizes its ``text`` field, - and computes Okapi BM25 scores against ``query_text``. Used as the default - lexical side of :func:`hybrid_search` when a collection does not implement - :class:`SupportsHybrid`. + Builds a throwaway :class:`BM25Index` over ``collection`` and runs a single + query against it. Used as the default lexical side of :func:`hybrid_search` + when a collection does not implement :class:`SupportsHybrid`. Cost is **O(N)** in the collection size — fine for prototypes and - collections up to ~100k documents. For larger workloads, either switch to - a backend with native hybrid search (weaviate, elasticsearch, redis, …) - or pass a custom ``lexical_search`` callable to :func:`hybrid_search` that - consults a real text index. + collections up to ~100k documents. **For repeated queries over the same + collection, build a** :class:`BM25Index` **once and call** + :meth:`BM25Index.search` **per query** instead of calling this function in a + loop — the term statistics are then computed once rather than on every call. + For larger workloads, switch to a backend with native hybrid search + (weaviate, elasticsearch, redis, …) or pass a custom ``lexical_search`` + callable to :func:`hybrid_search` that consults a real text index. Parameters ---------- @@ -422,60 +550,9 @@ def bm25_lexical_search( >>> hits[0]['id'] 'a' """ - from vd.filters import matches_filter - - query_tokens = _tokenize(query_text) - if not query_tokens: - return [] - - # Pass 1: tokenize once, compute document frequencies and lengths. - docs: list[tuple[str, str, dict, list[str]]] = [] - df: dict[str, int] = {} - for doc_id in collection: - doc = collection[doc_id] - if filter is not None and not matches_filter(doc.metadata or {}, filter): - continue - tokens = _tokenize(doc.text or "") - if not tokens: - continue - docs.append((doc_id, doc.text, dict(doc.metadata or {}), tokens)) - for term in set(tokens): - df[term] = df.get(term, 0) + 1 - - if not docs: - return [] - - n_docs = len(docs) - avg_len = sum(len(toks) for _, _, _, toks in docs) / n_docs - - # Pass 2: BM25 scoring (Okapi). - query_terms = set(query_tokens) - idf = { - term: math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5)) - for term in query_terms - if term in df - } - - scored: list[SearchResult] = [] - for doc_id, text, metadata, tokens in docs: - score = 0.0 - doc_len = len(tokens) - tf: dict[str, int] = {} - for tok in tokens: - if tok in idf: - tf[tok] = tf.get(tok, 0) + 1 - if not tf: - continue - for term, freq in tf.items(): - numer = freq * (k1 + 1) - denom = freq + k1 * (1 - b + b * doc_len / avg_len) - score += idf[term] * numer / denom - scored.append( - {"id": doc_id, "text": text, "score": score, "metadata": metadata} - ) - - scored.sort(key=lambda r: r["score"], reverse=True) - return scored[:limit] + return BM25Index(collection, filter=filter).search( + query_text, limit=limit, k1=k1, b=b + ) def _rrf_fuse(