Skip to content

Commit 6dec86c

Browse files
fix(community): Add INSERT support to PrismaVectorStore for ParentDocumentRetriever compatibility (#8833)
Previously, PrismaVectorStore only used UPDATE statements when adding vectors, which caused silent failures when used with ParentDocumentRetriever. The retriever creates new child documents that don't exist in the database, so UPDATE statements would succeed but not create any records. Changes: - Add new `addDocumentsWithVectors` method that uses INSERT statements to create records - Modify `addDocuments` to use `addDocumentsWithVectors` instead of `addVectors` - Maintain backward compatibility by keeping the original `addVectors` method unchanged - Add tests to verify the new behavior and ensure no regression This fix ensures PrismaVectorStore works correctly with ParentDocumentRetriever while maintaining compatibility with existing code. Fixes #8833
1 parent 47edf3f commit 6dec86c

File tree

2 files changed

+162
-1
lines changed

2 files changed

+162
-1
lines changed

libs/langchain-community/src/vectorstores/prisma.ts

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ export class PrismaVectorStore<
303303
*/
304304
async addDocuments(documents: Document<TModel>[]) {
305305
const texts = documents.map(({ pageContent }) => pageContent);
306-
return this.addVectors(
306+
return this.addDocumentsWithVectors(
307307
await this.embeddings.embedDocuments(texts),
308308
documents
309309
);
@@ -350,6 +350,58 @@ export class PrismaVectorStore<
350350
);
351351
}
352352

353+
/**
354+
* Adds documents with their corresponding vectors to the store using INSERT statements.
355+
* This method ensures documents are created if they don't exist, making it compatible
356+
* with ParentDocumentRetriever which creates new child documents.
357+
* @param vectors The vectors to add.
358+
* @param documents The documents associated with the vectors.
359+
* @returns A promise that resolves when the documents have been added.
360+
*/
361+
async addDocumentsWithVectors(
362+
vectors: number[][],
363+
documents: Document<TModel>[]
364+
) {
365+
// table name, column name cannot be parametrised
366+
// these fields are thus not escaped by Prisma and can be dangerous if user input is used
367+
const tableNameRaw = this.Prisma.raw(`"${this.tableName}"`);
368+
const vectorColumnRaw = this.Prisma.raw(`"${this.vectorColumnName}"`);
369+
370+
// Build column names for INSERT statement
371+
const columnNames = this.selectColumns.map((col) =>
372+
this.Prisma.raw(`"${col}"`)
373+
);
374+
const allColumns = [...columnNames, vectorColumnRaw];
375+
376+
await this.db.$transaction(
377+
vectors.map((vector, idx) => {
378+
const document = documents[idx];
379+
const vectorString = `[${vector.join(",")}]`;
380+
381+
// Build values for each column
382+
const columnValues = this.selectColumns.map((col) => {
383+
if (col === this.contentColumn) {
384+
return document.pageContent;
385+
}
386+
return document.metadata[col];
387+
});
388+
389+
// Add vector as the last value
390+
const allValues = [
391+
...columnValues,
392+
this.Prisma.sql`${vectorString}::vector`,
393+
];
394+
395+
return this.db.$executeRaw(
396+
this.Prisma.sql`
397+
INSERT INTO ${tableNameRaw} (${this.Prisma.join(allColumns, ", ")})
398+
VALUES (${this.Prisma.join(allValues, ", ")})
399+
`
400+
);
401+
})
402+
);
403+
}
404+
353405
/**
354406
* Performs a similarity search with the specified query.
355407
* @param query The query to use for the similarity search.

libs/langchain-community/src/vectorstores/tests/prisma.test.ts

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
22
import { FakeEmbeddings } from "@langchain/core/utils/testing";
33
import { jest, test, expect } from "@jest/globals";
4+
import { Document } from "@langchain/core/documents";
45
import { PrismaVectorStore } from "../prisma.js";
56

67
class Sql {
@@ -39,6 +40,7 @@ describe("Prisma", () => {
3940
beforeEach(() => {
4041
jest.clearAllMocks();
4142
});
43+
4244
test("passes provided filters with simiaritySearch", async () => {
4345
const embeddings = new FakeEmbeddings();
4446
const store = new PrismaVectorStore(new FakeEmbeddings(), {
@@ -270,4 +272,111 @@ describe("Prisma", () => {
270272
expect(sqlCall).toBeDefined();
271273
});
272274
});
275+
276+
test("addDocumentsWithVectors creates new documents with INSERT", async () => {
277+
const embeddings = new FakeEmbeddings();
278+
const store = new PrismaVectorStore(embeddings, {
279+
db: mockPrismaClient,
280+
prisma: mockPrismaNamespace,
281+
tableName: "test",
282+
vectorColumnName: "vector",
283+
columns: mockColumns,
284+
});
285+
286+
const documents = [
287+
new Document({
288+
pageContent: "test content 1",
289+
metadata: { id: "doc1", custom: "value1" },
290+
}),
291+
new Document({
292+
pageContent: "test content 2",
293+
metadata: { id: "doc2", custom: "value2" },
294+
}),
295+
];
296+
297+
const vectors = [
298+
[1, 2, 3],
299+
[4, 5, 6],
300+
];
301+
302+
// Mock the transaction to capture the SQL statements
303+
$transaction.mockImplementation((queries) => {
304+
// Verify that INSERT statements are being used
305+
expect(queries).toHaveLength(2);
306+
return Promise.resolve();
307+
});
308+
309+
await store.addDocumentsWithVectors(vectors, documents);
310+
311+
expect($transaction).toHaveBeenCalledTimes(1);
312+
expect($executeRaw).toHaveBeenCalledTimes(2);
313+
});
314+
315+
test("addDocuments uses addDocumentsWithVectors instead of addVectors", async () => {
316+
const embeddings = new FakeEmbeddings();
317+
const store = new PrismaVectorStore(embeddings, {
318+
db: mockPrismaClient,
319+
prisma: mockPrismaNamespace,
320+
tableName: "test",
321+
vectorColumnName: "vector",
322+
columns: mockColumns,
323+
});
324+
325+
const documents = [
326+
new Document({
327+
pageContent: "test content",
328+
metadata: { id: "doc1" },
329+
}),
330+
];
331+
332+
// Spy on both methods
333+
const addDocumentsWithVectorsSpy = jest
334+
.spyOn(store, "addDocumentsWithVectors")
335+
.mockResolvedValue();
336+
const addVectorsSpy = jest.spyOn(store, "addVectors").mockResolvedValue();
337+
338+
await store.addDocuments(documents);
339+
340+
// Verify addDocumentsWithVectors was called
341+
expect(addDocumentsWithVectorsSpy).toHaveBeenCalledTimes(1);
342+
// Verify addVectors was NOT called
343+
expect(addVectorsSpy).not.toHaveBeenCalled();
344+
});
345+
346+
test("addVectors still uses UPDATE statements for backward compatibility", async () => {
347+
const embeddings = new FakeEmbeddings();
348+
const store = new PrismaVectorStore(embeddings, {
349+
db: mockPrismaClient,
350+
prisma: mockPrismaNamespace,
351+
tableName: "test",
352+
vectorColumnName: "vector",
353+
columns: mockColumns,
354+
});
355+
356+
const documents = [
357+
new Document({
358+
pageContent: "test content",
359+
metadata: { id: "doc1" },
360+
}),
361+
];
362+
363+
const vectors = [[1, 2, 3]];
364+
365+
// Mock sql function to capture the SQL template
366+
let capturedSql = "";
367+
// @ts-expect-error - we are mocking the sql function
368+
sql.mockImplementation((strings: string[], ...values) => {
369+
capturedSql = strings.join("");
370+
return { strings, values };
371+
});
372+
373+
$transaction.mockResolvedValue([]);
374+
375+
await store.addVectors(vectors, documents);
376+
377+
expect($transaction).toHaveBeenCalledTimes(1);
378+
// Verify UPDATE statement is used
379+
expect(capturedSql).toContain("UPDATE");
380+
expect(capturedSql).not.toContain("INSERT");
381+
});
273382
});

0 commit comments

Comments
 (0)