diff --git a/ui/__tests__/rag-context.test.ts b/ui/__tests__/rag-context.test.ts new file mode 100644 index 0000000..20da319 --- /dev/null +++ b/ui/__tests__/rag-context.test.ts @@ -0,0 +1,92 @@ +import { + DEFAULT_RAG_CONTEXT_CHAR_LIMIT, + formatRagSource, + parsePositiveInteger, + parseRagContextCharLimit, + prepareRagContext, +} from '@/utils/server/rag-context'; + +import { describe, expect, it } from 'vitest'; + +describe('RAG context preparation', () => { + it('formats retrieved sources with metadata and distance', () => { + expect( + formatRagSource( + ' The method uses a controlled cohort. ', + { title: 'Paper A', page: 4 }, + 2, + 0.12891, + ), + ).toBe( + 'Source 2) Title: Paper A, Page: 4, Distance: 0.1289, Content: The method uses a controlled cohort.\n', + ); + }); + + it('deduplicates repeated chunks before building the prompt context', () => { + const result = prepareRagContext({ + documents: [ + ['Same abstract text', 'Same abstract text', 'Different result'], + ], + metadatas: [ + [ + { title: 'Paper', page: 1 }, + { title: 'Paper', page: 1 }, + { title: 'Paper', page: 2 }, + ], + ], + distances: [[0.1, 0.2, 0.3]], + }); + + expect(result.sourceCount).toBe(2); + expect(result.omittedSourceCount).toBe(1); + expect(result.context).toContain('Source 1) Title: Paper, Page: 1'); + expect(result.context).toContain('Source 2) Title: Paper, Page: 2'); + }); + + it('keeps the generated context under the configured character budget', () => { + const result = prepareRagContext( + { + documents: [ + [ + 'A'.repeat(DEFAULT_RAG_CONTEXT_CHAR_LIMIT), + 'This second source should be omitted after the context is full.', + ], + ], + metadatas: [ + [ + { title: 'Long Paper', page: 10 }, + { title: 'Other Paper', page: 11 }, + ], + ], + }, + 260, + ); + + expect(result.context.length).toBeLessThanOrEqual(260); + expect(result.sourceCount).toBe(1); + expect(result.omittedSourceCount).toBe(1); + expect(result.context).toContain('...'); + }); + + it('returns an explicit empty-context message when no documents match', () => { + const result = prepareRagContext({ documents: [[]], metadatas: [[]] }); + + expect(result.sourceCount).toBe(0); + expect(result.context).toContain('No matching uploaded-document context'); + }); + + it('parses bounded positive integer configuration values', () => { + expect(parsePositiveInteger('12', 8, 20)).toBe(12); + expect(parsePositiveInteger(100, 8, 20)).toBe(20); + expect(parsePositiveInteger('bad', 8, 20)).toBe(8); + expect(parsePositiveInteger(0, 8, 20)).toBe(8); + }); + + it('parses the configurable RAG context character budget', () => { + expect(parseRagContextCharLimit('16000')).toBe(16000); + expect(parseRagContextCharLimit('bad')).toBe( + DEFAULT_RAG_CONTEXT_CHAR_LIMIT, + ); + expect(parseRagContextCharLimit(100000)).toBe(50000); + }); +}); diff --git a/ui/pages/api/fetch-documents.ts b/ui/pages/api/fetch-documents.ts index 9304e48..5137917 100644 --- a/ui/pages/api/fetch-documents.ts +++ b/ui/pages/api/fetch-documents.ts @@ -1,25 +1,64 @@ -import type { NextApiRequest, NextApiResponse } from "next"; -import { ChromaClient, TransformersEmbeddingFunction } from "chromadb"; +import type { NextApiRequest, NextApiResponse } from 'next'; -export default async function handler(req: NextApiRequest, res: NextApiResponse) { +import { + DEFAULT_RAG_RETRIEVAL_RESULTS, + parsePositiveInteger, + parseRagContextCharLimit, + prepareRagContext, +} from '@/utils/server/rag-context'; + +import { ChromaClient, TransformersEmbeddingFunction } from 'chromadb'; + +export default async function handler( + req: NextApiRequest, + res: NextApiResponse, +) { try { + if (req.method !== 'POST') { + return res.status(405).end(); + } + const client = new ChromaClient({ - path: "http://chroma-server:8000", + path: process.env.CHROMA_PATH || 'http://chroma-server:8000', }); const query = req.body.input; + if (typeof query !== 'string' || query.trim().length === 0) { + return res + .status(400) + .json({ error: 'A non-empty input query is required' }); + } + + const nResults = parsePositiveInteger( + req.body.nResults, + DEFAULT_RAG_RETRIEVAL_RESULTS, + 20, + ); + const contextCharLimit = parseRagContextCharLimit( + req.body.contextCharLimit ?? process.env.RAG_CONTEXT_CHAR_LIMIT, + ); + const embedder = new TransformersEmbeddingFunction(); - const collection = await client.getOrCreateCollection({ name: "default-collection", embeddingFunction: embedder }); + const collection = await client.getOrCreateCollection({ + name: 'default-collection', + embeddingFunction: embedder, + }); - // query the collection - const results = await collection.query({ - nResults: 4, - queryTexts: [query] - }) + // query the collection + const results = await collection.query({ + nResults, + queryTexts: [query.trim()], + }); - res.status(200).json(results); + const preparedContext = prepareRagContext(results, contextCharLimit); + + res.status(200).json({ + ...results, + _prepared: preparedContext, + ...preparedContext, + }); } catch (error) { if (error instanceof Error) { console.error('Error message:', error.message); @@ -29,4 +68,4 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) } res.status(500).json({ error: 'An unexpected error occurred :(' }); } -} \ No newline at end of file +} diff --git a/ui/pages/api/rag-chat.ts b/ui/pages/api/rag-chat.ts index ce84d67..0c35d0f 100644 --- a/ui/pages/api/rag-chat.ts +++ b/ui/pages/api/rag-chat.ts @@ -1,6 +1,6 @@ import { DEFAULT_SYSTEM_PROMPT, DEFAULT_TEMPERATURE } from '@/utils/app/const'; import { OpenAIError, OpenAIStream } from '@/utils/server'; -import { codeBlock, oneLine } from 'common-tags' +import { prepareRagContext } from '@/utils/server/rag-context'; import { ChatBody, Message } from '@/types/chat'; @@ -9,46 +9,48 @@ import wasm from '../../node_modules/@dqbd/tiktoken/lite/tiktoken_bg.wasm?module import tiktokenModel from '@dqbd/tiktoken/encoders/cl100k_base.json'; import { Tiktoken, init } from '@dqbd/tiktoken/lite/init'; +import { codeBlock, oneLine } from 'common-tags'; export const config = { runtime: 'edge', }; // Function to fetch and format documents -async function fetchAndFormatDocuments(lastMessageContent: string) { +async function fetchAndFormatDocuments( + lastMessageContent: string, + requestOrigin: string, +) { try { - console.log("fetching documents") - const response = await fetch('http://localhost:3000/api/fetch-documents', { + console.log('fetching documents'); + const fetchDocumentsUrl = new URL('/api/fetch-documents', requestOrigin); + const response = await fetch(fetchDocumentsUrl, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ input: lastMessageContent }), }); - + if (!response.ok) { throw new Error(`Error fetching documents: ${response.statusText}`); } const data = await response.json(); - const result = data.metadatas[0].map((metadata: any, index: number) => { - return `Source ${index + 1}) Title: ${metadata.title}, Page: ${metadata.page}, Content: ${data.documents[0][index]}\n`; - }).join(''); + const result = + typeof data._prepared?.context === 'string' + ? data._prepared.context + : typeof data.context === 'string' + ? data.context + : prepareRagContext(data).context; console.log(result); return result; - } catch (error) { console.error('Error fetching and formatting documents:', error); - throw error; // You may want to throw a more specific error object here + return 'No matching uploaded-document context was found for this question.'; } } - - - - const handler = async (req: Request): Promise => { - try { const { model, messages, key, prompt, temperature } = (await req.json()) as ChatBody; @@ -85,8 +87,11 @@ const handler = async (req: Request): Promise => { const lastMessage = messages[messages.length - 1]; - const relevantDocuments = await fetchAndFormatDocuments(lastMessage.content); - + const relevantDocuments = await fetchAndFormatDocuments( + lastMessage.content, + new URL(req.url).origin, + ); + let temperatureToUse = temperature; if (temperatureToUse == null) { temperatureToUse = DEFAULT_TEMPERATURE; @@ -97,22 +102,20 @@ const handler = async (req: Request): Promise => { let tokenCount = prompt_tokens.length; let messagesToSend: Message[] = []; - encoding.free(); console.log(model, promptToSend, temperatureToUse, key, messagesToSend); - - messagesToSend = [ + messagesToSend = [ { - role: "user", + role: 'user', content: codeBlock` Here is the relevant documentation: ${relevantDocuments} `, }, { - role: "user", + role: 'user', content: codeBlock` ${oneLine` Answer my next question using only the above documentation. @@ -135,14 +138,13 @@ const handler = async (req: Request): Promise => { `, }, { - role: "user", + role: 'user', content: codeBlock` Here is my question: ${oneLine`${lastMessage.content}`} `, }, - ] - + ]; const stream = await OpenAIStream( model, diff --git a/ui/utils/server/rag-context.ts b/ui/utils/server/rag-context.ts new file mode 100644 index 0000000..cc8dc4d --- /dev/null +++ b/ui/utils/server/rag-context.ts @@ -0,0 +1,168 @@ +type RagMetadata = { + title?: string; + page?: number | string; + source?: string; +}; + +type ChromaQueryResult = { + documents?: Array>; + metadatas?: Array>; + distances?: Array> | null; +}; + +export type PreparedRagContext = { + context: string; + sourceCount: number; + omittedSourceCount: number; +}; + +export const DEFAULT_RAG_RETRIEVAL_RESULTS = 8; +export const DEFAULT_RAG_CONTEXT_CHAR_LIMIT = 12000; +export const MAX_RAG_CONTEXT_CHAR_LIMIT = 50000; + +const normalizeWhitespace = (value: string) => + value.replace(/\s+/g, ' ').trim(); + +const normalizeForDedupe = (value: string) => + normalizeWhitespace(value).toLowerCase(); + +const truncateAtWordBoundary = (value: string, maxLength: number) => { + if (value.length <= maxLength) { + return value; + } + + const suffix = '...'; + const sliced = value.slice(0, Math.max(0, maxLength - suffix.length)); + const lastSpace = sliced.lastIndexOf(' '); + + if (lastSpace < 80) { + return `${sliced.trim()}${suffix}`; + } + + return `${sliced.slice(0, lastSpace).trim()}${suffix}`; +}; + +export const parsePositiveInteger = ( + value: unknown, + fallback: number, + max: number, +) => { + const parsed = typeof value === 'number' ? value : Number(value); + + if (!Number.isFinite(parsed) || parsed < 1) { + return fallback; + } + + return Math.min(Math.floor(parsed), max); +}; + +export const parseRagContextCharLimit = (value: unknown) => + parsePositiveInteger( + value, + DEFAULT_RAG_CONTEXT_CHAR_LIMIT, + MAX_RAG_CONTEXT_CHAR_LIMIT, + ); + +export const formatRagSource = ( + document: string, + metadata: RagMetadata | null | undefined, + index: number, + distance?: number | null, +) => { + const title = normalizeWhitespace( + metadata?.title || metadata?.source || 'Untitled source', + ); + const page = metadata?.page == null ? 'unknown' : String(metadata.page); + const content = normalizeWhitespace(document); + const distanceLabel = + typeof distance === 'number' && Number.isFinite(distance) + ? `, Distance: ${distance.toFixed(4)}` + : ''; + + return `Source ${index}) Title: ${title}, Page: ${page}${distanceLabel}, Content: ${content}\n`; +}; + +export const prepareRagContext = ( + result: ChromaQueryResult, + maxCharacters = DEFAULT_RAG_CONTEXT_CHAR_LIMIT, +): PreparedRagContext => { + const documents = result.documents?.[0] ?? []; + const metadatas = result.metadatas?.[0] ?? []; + const distances = result.distances?.[0] ?? []; + const seen = new Set(); + const sources: string[] = []; + let usedCharacters = 0; + let omittedSourceCount = 0; + + documents.forEach((document, resultIndex) => { + if (!document) { + return; + } + + const normalizedDocument = normalizeForDedupe(document); + + if (!normalizedDocument || seen.has(normalizedDocument)) { + omittedSourceCount += 1; + return; + } + + seen.add(normalizedDocument); + + const remainingCharacters = maxCharacters - usedCharacters; + + if (remainingCharacters <= 0) { + omittedSourceCount += 1; + return; + } + + const formatted = formatRagSource( + document, + metadatas[resultIndex], + sources.length + 1, + distances[resultIndex], + ); + + if (formatted.length > remainingCharacters) { + const metadata = metadatas[resultIndex]; + const title = normalizeWhitespace( + metadata?.title || metadata?.source || 'Untitled source', + ); + const page = metadata?.page == null ? 'unknown' : String(metadata.page); + const prefix = `Source ${ + sources.length + 1 + }) Title: ${title}, Page: ${page}, Content: `; + const contentBudget = remainingCharacters - prefix.length - 1; + + if (contentBudget > 120) { + const truncated = `${prefix}${truncateAtWordBoundary( + normalizeWhitespace(document), + contentBudget, + )}\n`; + sources.push(truncated); + usedCharacters += truncated.length; + } else { + omittedSourceCount += 1; + } + + return; + } + + sources.push(formatted); + usedCharacters += formatted.length; + }); + + if (sources.length === 0) { + return { + context: + 'No matching uploaded-document context was found for this question.', + sourceCount: 0, + omittedSourceCount, + }; + } + + return { + context: sources.join(''), + sourceCount: sources.length, + omittedSourceCount, + }; +};