Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions langchain_postgres/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,39 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]:
return [doc for doc, _ in docs_and_scores]


def _validate_lengths_match(
texts: Sequence[str],
embeddings: List[List[float]],
metadatas: Optional[List[dict]],
ids: Optional[List[str]],
) -> None:
"""Raise `ValueError` if `embeddings`/`metadatas`/`ids` lengths disagree
with `texts`.

`PGVector.add_embeddings`/`aadd_embeddings` previously built the insert
payload via `zip(texts, metadatas, embeddings, ids_)` which silently
truncates to the shortest argument and then returns the full `ids_`
list — so a caller passing N texts but receiving only M embeddings
saw N IDs returned but only M rows persisted, with no error.
"""
n = len(texts)
if len(embeddings) != n:
msg = (
f"Got {n} texts but {len(embeddings)} embeddings; "
"every text must have a matching embedding."
)
raise ValueError(msg)
if metadatas is not None and len(metadatas) != n:
msg = (
f"Got {n} texts but {len(metadatas)} metadatas; "
"every text must have a matching metadata entry."
)
raise ValueError(msg)
if ids is not None and len(ids) != n:
msg = f"Got {n} texts but {len(ids)} ids; every text must have a matching id."
raise ValueError(msg)


def _create_vector_extension(conn: Connection) -> None:
statement = sqlalchemy.text(
"SELECT pg_advisory_xact_lock(1573678846307946496);"
Expand Down Expand Up @@ -763,6 +796,7 @@ def add_embeddings(
kwargs: vectorstore specific parameters
"""
assert not self._async_engine, "This method must be called with sync_mode"
_validate_lengths_match(texts, embeddings, metadatas, ids)
if ids is None:
ids_ = [str(uuid.uuid4()) for _ in texts]
else:
Expand All @@ -784,7 +818,7 @@ def add_embeddings(
"cmetadata": metadata or {},
}
for text, metadata, embedding, id in zip(
texts, metadatas, embeddings, ids_
texts, metadatas, embeddings, ids_, strict=True
)
]
stmt = insert(self.EmbeddingStore).values(data)
Expand Down Expand Up @@ -821,6 +855,7 @@ async def aadd_embeddings(
kwargs: vectorstore specific parameters
"""
await self.__apost_init__() # Lazy async init
_validate_lengths_match(texts, embeddings, metadatas, ids)

if ids is None:
ids_ = [str(uuid.uuid4()) for _ in texts]
Expand All @@ -843,7 +878,7 @@ async def aadd_embeddings(
"cmetadata": metadata or {},
}
for text, metadata, embedding, id in zip(
texts, metadatas, embeddings, ids_
texts, metadatas, embeddings, ids_, strict=True
)
]
stmt = insert(self.EmbeddingStore).values(data)
Expand Down
88 changes: 88 additions & 0 deletions tests/unit_tests/v1/test_pgvector_length_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Regression tests pinning that `PGVector.add_embeddings` /
`aadd_embeddings` raise `ValueError` when input lengths disagree.

Previously these methods built the SQL payload with
`zip(texts, metadatas, embeddings, ids_)`, which silently truncates to
the shortest argument and then returns the full `ids_` list — so an
upstream embedder bug that returned fewer embeddings than texts would
yield N IDs to the caller but only M rows in the database, with no
exception. See #300.

These tests bypass `PGVector.__init__` (no Postgres needed) by stubbing
the few attributes the validation path touches.
"""

from __future__ import annotations

import pytest

from langchain_postgres.vectorstores import PGVector


def _make_store() -> PGVector:
    """Return a `PGVector` created without running `__init__`.

    Only the attribute that the sync path reads ahead of the length
    validation is stubbed, so no Postgres connection is needed.
    """
    # `object.__new__` allocates the instance but skips `PGVector.__init__`.
    stub = object.__new__(PGVector)
    # `add_embeddings` asserts `not self._async_engine` before validating.
    stub._async_engine = False
    return stub


def test_add_embeddings_raises_when_embeddings_shorter_than_texts() -> None:
    """Three texts paired with a single embedding must be rejected."""
    store = _make_store()
    texts = ["a", "b", "c"]
    too_few_embeddings = [[0.1, 0.2]]
    with pytest.raises(ValueError, match="3 texts but 1 embeddings"):
        store.add_embeddings(texts=texts, embeddings=too_few_embeddings)


def test_add_embeddings_raises_when_metadatas_length_mismatches() -> None:
    """Two texts with three metadata dicts must be rejected."""
    store = _make_store()
    call_kwargs = {
        "texts": ["a", "b"],
        "embeddings": [[0.1], [0.2]],
        "metadatas": [{}, {}, {}],
    }
    with pytest.raises(ValueError, match="2 texts but 3 metadatas"):
        store.add_embeddings(**call_kwargs)


def test_add_embeddings_raises_when_ids_length_mismatches() -> None:
    """Two texts with only one id must be rejected."""
    store = _make_store()
    call_kwargs = {
        "texts": ["a", "b"],
        "embeddings": [[0.1], [0.2]],
        "ids": ["only-one"],
    }
    with pytest.raises(ValueError, match="2 texts but 1 ids"):
        store.add_embeddings(**call_kwargs)


def test_add_embeddings_does_not_raise_when_lengths_match() -> None:
    """Matching lengths must pass validation and let execution continue.

    No database is involved: the stubbed store is missing whatever the
    post-validation code needs, so getting past validation surfaces as an
    `AttributeError` instead of the `ValueError` raised on a mismatch.
    """
    store = _make_store()
    matching_inputs = {
        "texts": ["a", "b"],
        "embeddings": [[0.1], [0.2]],
        "metadatas": [{}, {}],
        "ids": ["x", "y"],
    }
    # NOTE(review): the AttributeError is expected to come from the sync
    # session setup (`self._make_sync_session()`) that follows validation;
    # that code is not visible here, so confirm against `add_embeddings`.
    with pytest.raises(AttributeError):
        store.add_embeddings(**matching_inputs)


@pytest.mark.asyncio
async def test_aadd_embeddings_raises_when_embeddings_shorter_than_texts() -> None:
    """Async path: three texts with two embeddings must be rejected."""
    store = _make_store()
    # `aadd_embeddings` awaits `__apost_init__` before validating.
    # NOTE(review): setting `_async_init` is presumed to make that call
    # return early without touching an engine; confirm in `PGVector`.
    store._async_init = True
    texts = ["a", "b", "c"]
    too_few_embeddings = [[0.1], [0.2]]
    with pytest.raises(ValueError, match="3 texts but 2 embeddings"):
        await store.aadd_embeddings(texts=texts, embeddings=too_few_embeddings)