diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 1cebdd501..022d01748 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -103,7 +103,7 @@ Some providers support multiple API formats (OpenAI chat, Anthropic messages, em |----------|----------| | `chat` | OpenAI-compatible chat completions | | `messages` | Anthropic Claude Messages API | -| `embeddings` | OpenAI-compatible embeddings | +| `embeddings` | OpenAI-compatible embeddings (Gemini providers auto-transformed) | | `image` | Image generation (DALL-E, etc.) | | `transcriptions` | Speech-to-text (Whisper) | | `speech` | Text-to-speech | diff --git a/docs/openapi/components/schemas/AliasConfig.yaml b/docs/openapi/components/schemas/AliasConfig.yaml index d38b0e1db..cb72c4911 100644 --- a/docs/openapi/components/schemas/AliasConfig.yaml +++ b/docs/openapi/components/schemas/AliasConfig.yaml @@ -67,7 +67,8 @@ properties: - **text** — Text generation (LLM). Supports all text wire protocols: chat completions, messages, gemini, responses, ollama. - - **embeddings** — Vector embeddings. + - **embeddings** — Vector embeddings, including Gemini backends via + automatic request/response transformation. - **transcriptions** — Audio to text. - **speech** — Text to audio. - **image** — Image generation/editing. diff --git a/docs/openapi/paths/v1_embeddings.yaml b/docs/openapi/paths/v1_embeddings.yaml index da19fb68b..c3c677e21 100644 --- a/docs/openapi/paths/v1_embeddings.yaml +++ b/docs/openapi/paths/v1_embeddings.yaml @@ -1,10 +1,16 @@ post: tags: - Inference — Embeddings - summary: OpenAI-compatible embeddings (always pass-through) + summary: OpenAI-compatible embeddings with provider-aware transformation description: | - Requires a model configured with `type: embeddings`. The request and - response are forwarded verbatim. + Requires a model configured with `type: embeddings`. Accepts requests in + OpenAI embeddings format. For OpenAI-compatible providers, requests and + responses are forwarded directly. For Gemini providers, Plexus transforms + the request to the Gemini `embedContent`/`batchEmbedContents` format and + normalises the response back to OpenAI format. + + Supported provider types: `openai`, `chat` (pass-through), `gemini` + (transformed). requestBody: required: true content: diff --git a/packages/backend/src/routes/inference/embeddings.ts b/packages/backend/src/routes/inference/embeddings.ts index cee1765b7..daf067f48 100644 --- a/packages/backend/src/routes/inference/embeddings.ts +++ b/packages/backend/src/routes/inference/embeddings.ts @@ -1,7 +1,8 @@ import { FastifyInstance } from 'fastify'; import { logger } from '../../utils/logger'; import { Dispatcher } from '../../services/dispatcher'; -import { EmbeddingsTransformer } from '../../transformers'; +import { OpenAIEmbeddingsTransformer } from '../../transformers/embeddings'; +import { UnifiedEmbeddingsRequest } from '../../types/unified'; import { UsageStorageService } from '../../services/usage-storage'; import { UsageRecord } from '../../types/usage'; import { getClientIp } from '../../utils/ip'; @@ -54,11 +55,17 @@ export async function registerEmbeddingsRoute( logger.silly('Incoming Embeddings Request', body); - const transformer = new EmbeddingsTransformer(); - let unifiedRequest = await transformer.parseRequest(body); - unifiedRequest.incomingApiType = 'embeddings'; - unifiedRequest.originalBody = body; - unifiedRequest.requestId = requestId; + const transformer = new OpenAIEmbeddingsTransformer(); + let unifiedRequest: UnifiedEmbeddingsRequest = { + model: body.model, + input: body.input, + encoding_format: body.encoding_format, + dimensions: body.dimensions, + user: body.user, + incomingApiType: 'embeddings', + originalBody: body, + requestId, + }; unifiedRequest = attachKeyAccessPolicy(request, unifiedRequest); DebugManager.getInstance().startLog(requestId, body, sanitizeHeaders(request.headers as any)); @@ -78,6 +85,8 @@ export async function registerEmbeddingsRoute( usageRecord.selectedModelName = unifiedResponse.plexus?.model; usageRecord.canonicalModelName = unifiedResponse.plexus?.canonicalModel; usageRecord.outgoingApiType = unifiedResponse.plexus?.apiType; + usageRecord.attemptCount = unifiedResponse.plexus?.attemptCount ?? 1; + usageRecord.retryHistory = unifiedResponse.plexus?.retryHistory ?? null; usageRecord.isPassthrough = true; // Embeddings are always pass-through (OpenAI format) usageRecord.tokensInput = unifiedResponse.usage?.prompt_tokens ?? 0; usageRecord.tokensOutput = 0; // Embeddings don't have output tokens diff --git a/packages/backend/src/services/__tests__/embeddings-transformer-factory.test.ts b/packages/backend/src/services/__tests__/embeddings-transformer-factory.test.ts new file mode 100644 index 000000000..af61fea11 --- /dev/null +++ b/packages/backend/src/services/__tests__/embeddings-transformer-factory.test.ts @@ -0,0 +1,51 @@ +import { test, expect, describe } from 'vitest'; +import { EmbeddingsTransformerFactory } from '../../services/embeddings-transformer-factory'; + +describe('EmbeddingsTransformerFactory', () => { + test('should return GeminiEmbeddingsTransformer for gemini type', () => { + const t = EmbeddingsTransformerFactory.getTransformer('gemini'); + expect(t.name).toBe('gemini'); + }); + + test('should return OpenAIEmbeddingsTransformer for openai type', () => { + const t = EmbeddingsTransformerFactory.getTransformer('openai'); + expect(t.name).toBe('openai'); + }); + + test('should return OpenAIEmbeddingsTransformer for chat type', () => { + const t = EmbeddingsTransformerFactory.getTransformer('chat'); + expect(t.name).toBe('openai'); + }); + + test('should default to OpenAI for unknown type', () => { + const t = EmbeddingsTransformerFactory.getTransformer('unknown'); + expect(t.name).toBe('openai'); + }); +}); + +describe('resolveTransformer', () => { + test('should resolve Gemini transformer when gemini is in provider types', () => { + const t = EmbeddingsTransformerFactory.resolveTransformer(['chat', 'gemini']); + expect(t.name).toBe('gemini'); + }); + + test('should fall back to OpenAI when no dedicated type matches', () => { + const t = EmbeddingsTransformerFactory.resolveTransformer(['chat', 'openai']); + expect(t.name).toBe('openai'); + }); + + test('should fall back to OpenAI for unknown provider types', () => { + const t = EmbeddingsTransformerFactory.resolveTransformer(['anthropic']); + expect(t.name).toBe('openai'); + }); + + test('should fall back to OpenAI for empty provider types', () => { + const t = EmbeddingsTransformerFactory.resolveTransformer([]); + expect(t.name).toBe('openai'); + }); + + test('should resolve Gemini when gemini appears alongside other types', () => { + const t = EmbeddingsTransformerFactory.resolveTransformer(['ollama', 'gemini', 'chat']); + expect(t.name).toBe('gemini'); + }); +}); diff --git a/packages/backend/src/services/dispatcher.ts b/packages/backend/src/services/dispatcher.ts index c2dfdd21a..b8cf1b8ea 100644 --- a/packages/backend/src/services/dispatcher.ts +++ b/packages/backend/src/services/dispatcher.ts @@ -23,6 +23,7 @@ import { UsageStorageService } from './usage-storage'; import { CooldownParserRegistry } from './cooldown-parsers'; import { getConfig, getProviderTypes } from '../config'; import { applyModelBehaviors } from './model-behaviors'; +import { EmbeddingsTransformerFactory } from './embeddings-transformer-factory'; import { resolveAdapters } from './adapter-resolver'; import type { ResolvedAdapter } from '../types/provider-adapter'; import { getModels } from '@earendil-works/pi-ai'; @@ -62,12 +63,25 @@ interface RetryHistoryLikeEntry { type ResolveTimeoutMs = (timeoutMs?: number | null) => number; +/** + * Request-level API types (e.g. embeddings, transcriptions) share base URLs + * with their provider-level counterparts (e.g. chat, gemini). This map defines + * which provider-level URL keys to try when no exact or default URL is configured. + */ +const API_TYPE_ALIASES: Record = { + embeddings: ['chat', 'gemini'], + transcriptions: ['chat', 'gemini'], + speech: ['chat', 'gemini'], + images: ['chat', 'gemini'], +}; + /** * Strips trailing /v1beta* path segments from Gemini base URLs. * Gemini's transformer adds /v1beta to the path, so we need to ensure * the base URL doesn't include it to avoid duplication like /v1beta/v1beta/... * Only strips beta versions (e.g. /v1beta, /v1beta1) — plain /v1 is valid for other APIs. */ + function stripTrailingApiVersion(url: string): string { return url.replace(/\/(v\d+beta\d*)$/i, ''); } @@ -1684,26 +1698,33 @@ export class Dispatcher { rawBaseUrl = defaultUrl; logger.debug(`Dispatcher: Using default base URL.`); } else { - // If we can't find a specific URL for this type, and no default, fall back to the first one? - // Or throw error. - const firstKey = Object.keys(urlMap)[0]; - - if (firstKey) { - const firstUrl = urlMap[firstKey]; - if (firstUrl) { - rawBaseUrl = firstUrl; - logger.warn( - `No specific base URL found for api type '${targetApiType}'. using '${firstKey}' as fallback.` - ); + // Resolve via API_TYPE_ALIASES before falling back to the first key. + const aliases = API_TYPE_ALIASES[typeKey]; + const aliasKey = aliases?.find((a) => urlMap[a]); + + if (aliasKey) { + rawBaseUrl = urlMap[aliasKey]!; + logger.debug(`Dispatcher: Using '${aliasKey}' base URL for api type '${targetApiType}'.`); + } else { + const firstKey = Object.keys(urlMap)[0]; + + if (firstKey) { + const firstUrl = urlMap[firstKey]; + if (firstUrl) { + rawBaseUrl = firstUrl; + logger.warn( + `No specific base URL found for api type '${targetApiType}'. using '${firstKey}' as fallback.` + ); + } else { + throw new Error( + `No base URL configured for api type '${targetApiType}' and no default found.` + ); + } } else { throw new Error( `No base URL configured for api type '${targetApiType}' and no default found.` ); } - } else { - throw new Error( - `No base URL configured for api type '${targetApiType}' and no default found.` - ); } } } @@ -3158,10 +3179,10 @@ export class Dispatcher { /** * Dispatch embeddings request to provider - * Simplified version of dispatch() since embeddings: - * - Don't support streaming - * - Use universal API format (no transformation needed) - * - Always use /embeddings endpoint + * Uses EmbeddingsTransformerFactory for provider-type-aware: + * - URL construction (e.g. Gemini /v1beta/models/{model}:embedContent) + * - Auth headers (e.g. x-goog-api-key for Gemini) + * - Request/response transformation */ async dispatchEmbeddings(request: any): Promise { const config = getConfig(); @@ -3230,32 +3251,40 @@ export class Dispatcher { this.emitRoutingUpdate(request.requestId, route); try { + const providerTypes = getProviderTypes(route.config); + const transformer = EmbeddingsTransformerFactory.resolveTransformer(providerTypes); + const requestWithModel = { ...request, model: route.model }; + const baseUrl = this.resolveBaseUrl(route, 'embeddings'); - const url = `${baseUrl}/embeddings`; + const endpoint = transformer.getEndpoint + ? transformer.getEndpoint(requestWithModel) + : transformer.defaultEndpoint; + const url = `${baseUrl}${endpoint}`; const headers: Record = { 'Content-Type': 'application/json', Accept: 'application/json', }; - if (route.config.api_key) { - headers['Authorization'] = `Bearer ${route.config.api_key}`; + if (transformer.getAuthHeaders) { + transformer.getAuthHeaders(route.config.api_key, headers); + } else { + headers['Authorization'] = `Bearer ${route.config.api_key}`; + } } - if (route.config.headers) { Object.assign(headers, route.config.headers); } - const payload = { - ...request.originalBody, - model: route.model, - }; - + let payload = await transformer.transformRequest(requestWithModel); if (route.config.extraBody) { Object.assign(payload, route.config.extraBody); } - - // Merge alias-level extraBody (overrides provider level) + // Merge model-level extraBody (overrides provider level) + if (route.modelConfig?.extraBody) { + Object.assign(payload, route.modelConfig.extraBody); + } + // Merge alias-level extraBody (overrides provider and model level) if (route.canonicalModel) { const aliasConfig = getConfig().models?.[route.canonicalModel]; if (aliasConfig?.extraBody) { @@ -3326,19 +3355,28 @@ export class Dispatcher { } } - const responseBody = await response.json(); - logger.silly('Embeddings Response Payload', responseBody); + const rawResponseBody = await this.parseJsonResponseBody( + response, + request.requestId, + route, + 'embeddings' + ); + logger.silly('Embeddings Response Payload', rawResponseBody); if (request.requestId) { - DebugManager.getInstance().addRawResponse(request.requestId, responseBody); + DebugManager.getInstance().addRawResponse(request.requestId, rawResponseBody); } - + const transformedResponse = await transformer.transformResponse( + rawResponseBody, + requestWithModel + ); const enrichedResponse: any = { - ...responseBody, + ...transformedResponse, plexus: { provider: route.provider, model: route.model, apiType: 'embeddings', + isPassthrough: true, pricing: route.modelConfig?.pricing, providerDiscount: route.config.discount, canonicalModel: route.canonicalModel, @@ -3347,6 +3385,7 @@ export class Dispatcher { }; await this.recordAttemptMetric(route, request.requestId, true); + CooldownManager.getInstance().markProviderSuccess(route.provider, route.model); this.appendSuccessAttempt(retryHistory, route, 'embeddings'); this.attachAttemptMetadata( enrichedResponse, diff --git a/packages/backend/src/services/embeddings-transformer-factory.ts b/packages/backend/src/services/embeddings-transformer-factory.ts new file mode 100644 index 000000000..deef694cd --- /dev/null +++ b/packages/backend/src/services/embeddings-transformer-factory.ts @@ -0,0 +1,45 @@ +import { logger } from '../utils/logger'; +import { EmbeddingsTransformer } from '../types/embeddings-transformer'; +import { OpenAIEmbeddingsTransformer } from '../transformers/embeddings/openai'; +import { GeminiEmbeddingsTransformer } from '../transformers/embeddings/gemini'; + +/** + * EmbeddingsTransformerFactory + * + * Factory for retrieving the correct embeddings transformer based on the provider's API type. + * Supports 'gemini' (Google) and 'openai'/'chat' (OpenAI-compatible). Unknown types default to OpenAI format. + */ +export class EmbeddingsTransformerFactory { + /** + * Provider types with dedicated (non-OpenAI) embeddings transformers, + * in priority order. Used by resolveTransformer to pick the best match. + */ + static readonly DEDICATED_TYPES = ['gemini'] as const; + + /** + * Resolve the best embeddings transformer for a provider based on its type list. + * Checks dedicated types first (in priority order), then falls back to OpenAI format. + */ + static resolveTransformer(providerTypes: string[]): EmbeddingsTransformer { + const dedicated = providerTypes.find((t) => + (this.DEDICATED_TYPES as readonly string[]).includes(t.toLowerCase()) + ); + return this.getTransformer(dedicated ?? 'openai'); + } + + static getTransformer(providerType: string): EmbeddingsTransformer { + switch (providerType.toLowerCase()) { + case 'gemini': + return new GeminiEmbeddingsTransformer(); + case 'openai': + case 'chat': + default: + if (!['openai', 'chat'].includes(providerType.toLowerCase())) { + logger.warn( + `Unknown embeddings provider type '${providerType}', defaulting to OpenAI format` + ); + } + return new OpenAIEmbeddingsTransformer(); + } + } +} diff --git a/packages/backend/src/services/probe-service.ts b/packages/backend/src/services/probe-service.ts index e534c4eb8..e152dd556 100644 --- a/packages/backend/src/services/probe-service.ts +++ b/packages/backend/src/services/probe-service.ts @@ -170,8 +170,10 @@ export class ProbeService { response = await this.dispatcher.dispatch(unifiedRequest as any); } else if (apiType === 'embeddings') { + const embReq = testRequest as { model: string; input: string | string[] }; response = await this.dispatcher.dispatchEmbeddings({ model: directModelPath, + input: embReq.input, originalBody: testRequest, requestId, incomingApiType: 'embeddings', diff --git a/packages/backend/src/services/router.ts b/packages/backend/src/services/router.ts index ba9ab51fb..6e8f87374 100644 --- a/packages/backend/src/services/router.ts +++ b/packages/backend/src/services/router.ts @@ -175,7 +175,8 @@ async function filterGroupTargets( } if (alias.type === 'embeddings') return true; - return getProviderTypes(providerConfig).includes('embeddings'); + const providerTypes = getProviderTypes(providerConfig); + return providerTypes.includes('embeddings') || providerTypes.includes('gemini'); }); if (embeddingsTargets.length > 0) { diff --git a/packages/backend/src/transformers/__tests__/embeddings-gemini.test.ts b/packages/backend/src/transformers/__tests__/embeddings-gemini.test.ts new file mode 100644 index 000000000..6c58d06bc --- /dev/null +++ b/packages/backend/src/transformers/__tests__/embeddings-gemini.test.ts @@ -0,0 +1,186 @@ +import { test, expect, describe } from 'vitest'; +import { GeminiEmbeddingsTransformer } from '../embeddings/gemini'; + +describe('GeminiEmbeddingsTransformer', () => { + const transformer = new GeminiEmbeddingsTransformer(); + + describe('getEndpoint', () => { + test('should return embedContent for single string input', () => { + const request = { model: 'gemini-embedding-2', input: 'Hello' }; + expect(transformer.getEndpoint!(request as any)).toBe( + '/v1beta/models/gemini-embedding-2:embedContent' + ); + }); + + test('should return batchEmbedContents for array input with >1 items', () => { + const request = { model: 'gemini-embedding-2', input: ['A', 'B'] }; + expect(transformer.getEndpoint!(request as any)).toBe( + '/v1beta/models/gemini-embedding-2:batchEmbedContents' + ); + }); + + test('should prepend models/ prefix if missing', () => { + const request = { model: 'text-embedding-004', input: 'test' }; + expect(transformer.getEndpoint!(request as any)).toContain('models/text-embedding-004'); + }); + + test('should not prepend models/ if already present', () => { + const request = { model: 'models/gemini-embedding-2', input: 'test' }; + expect(transformer.getEndpoint!(request as any)).toContain('models/gemini-embedding-2'); + }); + + test('should not prepend models/ for tunedModels/ prefix', () => { + const request = { model: 'tunedModels/my-model', input: 'test' }; + expect(transformer.getEndpoint!(request as any)).toContain('tunedModels/my-model'); + }); + }); + + describe('getAuthHeaders', () => { + test('should set x-goog-api-key header', () => { + const headers: Record = {}; + transformer.getAuthHeaders!('my-key', headers); + expect(headers['x-goog-api-key']).toBe('my-key'); + }); + }); + + describe('transformRequest', () => { + test('should convert single string input to Gemini format', async () => { + const request = { + model: 'gemini-embedding-2', + input: 'Hello world', + originalBody: {}, + }; + const result = await transformer.transformRequest(request as any); + expect(result.model).toBe('models/gemini-embedding-2'); + expect(result.content.parts[0].text).toBe('Hello world'); + }); + + test('should convert batch input to batchEmbedContents format', async () => { + const request = { + model: 'gemini-embedding-2', + input: ['A', 'B', 'C'], + originalBody: {}, + }; + const result = await transformer.transformRequest(request as any); + expect(result.requests).toHaveLength(3); + expect(result.requests[0].model).toBe('models/gemini-embedding-2'); + expect(result.requests[0].content.parts[0].text).toBe('A'); + }); + + test('should pass through Gemini-specific options for every batch item', async () => { + const request = { + model: 'gemini-embedding-2', + input: ['A', 'B'], + originalBody: { + taskType: 'RETRIEVAL_QUERY', + title: 'Batch Title', + outputDimensionality: 128, + }, + dimensions: 256, + }; + const result = await transformer.transformRequest(request as any); + + expect(result.requests).toHaveLength(2); + expect(result.requests[0]).toMatchObject({ + model: 'models/gemini-embedding-2', + taskType: 'RETRIEVAL_QUERY', + title: 'Batch Title', + outputDimensionality: 128, + }); + expect(result.requests[1]).toMatchObject({ + model: 'models/gemini-embedding-2', + taskType: 'RETRIEVAL_QUERY', + title: 'Batch Title', + outputDimensionality: 128, + }); + }); + + test('should pass through taskType and title from originalBody', async () => { + const request = { + model: 'gemini-embedding-2', + input: 'Hello', + originalBody: { taskType: 'RETRIEVAL_QUERY', title: 'My Doc' }, + dimensions: undefined, + }; + const result = await transformer.transformRequest(request as any); + expect(result.taskType).toBe('RETRIEVAL_QUERY'); + expect(result.title).toBe('My Doc'); + }); + + test('should pass through dimensions as outputDimensionality', async () => { + const request = { + model: 'gemini-embedding-2', + input: 'Hello', + originalBody: {}, + dimensions: 256, + }; + const result = await transformer.transformRequest(request as any); + expect(result.outputDimensionality).toBe(256); + }); + + test('should prefer native outputDimensionality from originalBody when present', async () => { + const request = { + model: 'gemini-embedding-2', + input: 'Hello', + originalBody: { outputDimensionality: 128 }, + dimensions: 256, + }; + const result = await transformer.transformRequest(request as any); + expect(result.outputDimensionality).toBe(128); + }); + + test('should reject empty input arrays', async () => { + await expect( + transformer.transformRequest({ + model: 'gemini-embedding-2', + input: [], + originalBody: {}, + } as any) + ).rejects.toThrow('Gemini embeddings input array must contain at least one item'); + }); + }); + + describe('transformResponse', () => { + test('should transform single EmbedContentResponse', async () => { + const response = { + embedding: { values: [0.1, 0.2, 0.3] }, + usageMetadata: { promptTokenCount: 5, totalTokenCount: 5 }, + }; + const result = await transformer.transformResponse(response); + expect(result.object).toBe('list'); + expect(result.data).toHaveLength(1); + expect(result.data[0]!.embedding).toEqual([0.1, 0.2, 0.3]); + expect(result.usage!.prompt_tokens).toBe(5); + }); + + test('should transform BatchEmbedContentsResponse', async () => { + const response = { + embeddings: [{ values: [0.1] }, { values: [0.2] }], + usageMetadata: { promptTokenCount: 10, totalTokenCount: 10 }, + }; + const result = await transformer.transformResponse(response); + expect(result.data).toHaveLength(2); + expect(result.data[0]!.embedding).toEqual([0.1]); + expect(result.data[1]!.embedding).toEqual([0.2]); + expect(result.usage!.prompt_tokens).toBe(10); + }); + + test('should handle missing usageMetadata gracefully', async () => { + const response = { + embedding: { values: [0.1] }, + }; + const result = await transformer.transformResponse(response); + expect(result.usage).toBeUndefined(); + }); + }); + + describe('properties', () => { + test('should have correct name', () => { + expect(transformer.name).toBe('gemini'); + }); + + test('should have correct default endpoint', () => { + expect(transformer.defaultEndpoint).toBe('/v1beta/models/:model:embedContent'); + }); + }); +}); diff --git a/packages/backend/src/transformers/__tests__/embeddings.test.ts b/packages/backend/src/transformers/__tests__/embeddings.test.ts index c449a2c7a..c5727240a 100644 --- a/packages/backend/src/transformers/__tests__/embeddings.test.ts +++ b/packages/backend/src/transformers/__tests__/embeddings.test.ts @@ -1,75 +1,41 @@ import { test, expect, describe } from 'vitest'; -import { EmbeddingsTransformer } from '../embeddings'; +import { OpenAIEmbeddingsTransformer } from '../embeddings/openai'; -describe('EmbeddingsTransformer', () => { - const transformer = new EmbeddingsTransformer(); - - describe('parseRequest', () => { - test('should parse single text input', async () => { - const input = { - model: 'text-embedding-3-small', - input: 'Hello world', - }; - - const result = await transformer.parseRequest(input); - - expect(result.model).toBe('text-embedding-3-small'); - expect(result.input).toBe('Hello world'); - }); - - test('should parse array of texts', async () => { - const input = { - model: 'text-embedding-3-small', - input: ['Text 1', 'Text 2', 'Text 3'], - }; - - const result = await transformer.parseRequest(input); - - expect(result.model).toBe('text-embedding-3-small'); - expect(result.input).toEqual(['Text 1', 'Text 2', 'Text 3']); - }); - - test('should parse optional parameters', async () => { - const input = { - model: 'text-embedding-3-large', - input: 'Test', - encoding_format: 'float' as const, - dimensions: 256, - user: 'user-123', - }; - - const result = await transformer.parseRequest(input); - - expect(result.encoding_format).toBe('float'); - expect(result.dimensions).toBe(256); - expect(result.user).toBe('user-123'); - }); - }); +describe('OpenAIEmbeddingsTransformer', () => { + const transformer = new OpenAIEmbeddingsTransformer(); describe('transformRequest', () => { - test('should pass through request unchanged', async () => { + test('should pass through originalBody with model override', async () => { const request = { model: 'text-embedding-3-small', input: 'Test text', - encoding_format: 'float' as const, - dimensions: 512, + originalBody: { + model: 'text-embedding-3-small', + input: 'Test text', + encoding_format: 'float', + dimensions: 512, + }, }; - const result = await transformer.transformRequest(request); + const result = await transformer.transformRequest(request as any); - expect(result.model).toBe(request.model); - expect(result.input).toBe(request.input); - expect(result.encoding_format).toBe(request.encoding_format); - expect(result.dimensions).toBe(request.dimensions); + expect(result.model).toBe('text-embedding-3-small'); + expect(result.input).toBe('Test text'); + expect(result.encoding_format).toBe('float'); + expect(result.dimensions).toBe(512); }); test('should handle array input', async () => { const request = { model: 'text-embedding-3-small', input: ['A', 'B', 'C'], + originalBody: { + model: 'text-embedding-3-small', + input: ['A', 'B', 'C'], + }, }; - const result = await transformer.transformRequest(request); + const result = await transformer.transformRequest(request as any); expect(result.input).toEqual(['A', 'B', 'C']); }); @@ -99,7 +65,7 @@ describe('EmbeddingsTransformer', () => { expect(result.data).toHaveLength(1); expect(result.data[0]!.embedding).toEqual([0.1, 0.2, 0.3]); expect(result.model).toBe('text-embedding-3-small'); - expect(result.usage.prompt_tokens).toBe(5); + expect(result.usage!.prompt_tokens).toBe(5); }); test('should transform batch embedding response', async () => { @@ -162,7 +128,7 @@ describe('EmbeddingsTransformer', () => { describe('properties', () => { test('should have correct name', () => { - expect(transformer.name).toBe('embeddings'); + expect(transformer.name).toBe('openai'); }); test('should have correct endpoint', () => { diff --git a/packages/backend/src/transformers/embeddings.ts b/packages/backend/src/transformers/embeddings.ts deleted file mode 100644 index ed1e28029..000000000 --- a/packages/backend/src/transformers/embeddings.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { UnifiedEmbeddingsRequest, UnifiedEmbeddingsResponse } from '../types/unified'; - -/** - * EmbeddingsTransformer - * - * Simple pass-through transformer for embeddings since the API format - * is standardized across all providers (OpenAI, Voyage, Cohere, Google, etc.) - */ -export class EmbeddingsTransformer { - name = 'embeddings'; - defaultEndpoint = '/embeddings'; - - async parseRequest(input: any): Promise { - return { - model: input.model, - input: input.input, - encoding_format: input.encoding_format, - dimensions: input.dimensions, - user: input.user, - }; - } - - async transformRequest(request: UnifiedEmbeddingsRequest): Promise { - // Pass-through - embeddings API is standardized across providers - return { - model: request.model, - input: request.input, - encoding_format: request.encoding_format, - dimensions: request.dimensions, - user: request.user, - }; - } - - async transformResponse(response: any): Promise { - return { - object: 'list', - data: response.data, - model: response.model, - usage: response.usage, - }; - } - - async formatResponse(response: UnifiedEmbeddingsResponse): Promise { - // Pass through - already in correct format - return { - object: response.object, - data: response.data, - model: response.model, - usage: response.usage, - }; - } - - /** - * Embeddings don't support streaming, so this returns undefined - */ - extractUsage(eventData: string) { - return undefined; - } -} diff --git a/packages/backend/src/transformers/embeddings/gemini.ts b/packages/backend/src/transformers/embeddings/gemini.ts new file mode 100644 index 000000000..9241899d9 --- /dev/null +++ b/packages/backend/src/transformers/embeddings/gemini.ts @@ -0,0 +1,129 @@ +import { UnifiedEmbeddingsRequest, UnifiedEmbeddingsResponse } from '../../types/unified'; +import { EmbeddingsTransformer } from '../../types/embeddings-transformer'; + +export class GeminiEmbeddingsTransformer implements EmbeddingsTransformer { + readonly name = 'gemini'; + readonly defaultEndpoint = '/v1beta/models/:model:embedContent'; + + getEndpoint(request: UnifiedEmbeddingsRequest): string { + const model = GeminiEmbeddingsTransformer.prefixModel(request.model); + const input = request.input; + const isBatch = Array.isArray(input) && input.length > 1; + const action = isBatch ? 'batchEmbedContents' : 'embedContent'; + return `/v1beta/${model}:${action}`; + } + + getAuthHeaders(apiKey: string, headers: Record): void { + headers['x-goog-api-key'] = apiKey; + } + + async transformRequest(request: UnifiedEmbeddingsRequest): Promise { + const model = GeminiEmbeddingsTransformer.prefixModel(request.model); + const input = request.input; + const isBatch = Array.isArray(input) && input.length > 1; + + if (isBatch) { + return { + requests: (input as string[]).map((text) => + GeminiEmbeddingsTransformer.buildRequestPayload(model, text, request) + ), + }; + } + + // Single input (string or single-element array) + const text = Array.isArray(input) ? input[0] : input; + if (text === undefined) { + throw new Error('Gemini embeddings input array must contain at least one item'); + } + return GeminiEmbeddingsTransformer.buildRequestPayload(model, text, request); + } + + async transformResponse( + response: any, + request?: UnifiedEmbeddingsRequest + ): Promise { + const modelName = request?.model ?? ''; + if (response.embeddings) { + // BatchEmbedContentsResponse: { embeddings: [{ values, shape }] } + return { + object: 'list', + data: response.embeddings.map((item: any, index: number) => ({ + object: 'embedding' as const, + embedding: item.values, + index, + })), + model: modelName, + usage: response.usageMetadata + ? { + prompt_tokens: response.usageMetadata.promptTokenCount ?? 0, + total_tokens: response.usageMetadata.totalTokenCount ?? 0, + } + : undefined, + }; + } + + // EmbedContentResponse: { embedding: { values, shape } } + return { + object: 'list', + data: [ + { + object: 'embedding' as const, + embedding: response.embedding?.values ?? [], + index: 0, + }, + ], + model: modelName, + usage: response.usageMetadata + ? { + prompt_tokens: response.usageMetadata.promptTokenCount ?? 0, + total_tokens: response.usageMetadata.totalTokenCount ?? 0, + } + : undefined, + }; + } + + async formatResponse(response: UnifiedEmbeddingsResponse): Promise { + return { + object: response.object, + data: response.data, + model: response.model, + usage: response.usage, + }; + } + + extractUsage(_eventData: string): undefined { + return undefined; + } + + private static prefixModel(model: string): string { + if (!model.startsWith('models/') && !model.startsWith('tunedModels/')) { + return `models/${model}`; + } + return model; + } + + private static buildRequestPayload( + model: string, + text: string, + request: UnifiedEmbeddingsRequest + ): Record { + const payload: Record = { + model, + content: { parts: [{ text }] }, + }; + + if (request.originalBody?.taskType !== undefined) { + payload.taskType = request.originalBody.taskType; + } + if (request.originalBody?.title !== undefined) { + payload.title = request.originalBody.title; + } + + const outputDimensionality = request.originalBody?.outputDimensionality ?? request.dimensions; + if (outputDimensionality !== undefined) { + payload.outputDimensionality = outputDimensionality; + } + + return payload; + } +} diff --git a/packages/backend/src/transformers/embeddings/index.ts b/packages/backend/src/transformers/embeddings/index.ts new file mode 100644 index 000000000..9a62fd86c --- /dev/null +++ b/packages/backend/src/transformers/embeddings/index.ts @@ -0,0 +1,2 @@ +export * from './openai'; +export * from './gemini'; diff --git a/packages/backend/src/transformers/embeddings/openai.ts b/packages/backend/src/transformers/embeddings/openai.ts new file mode 100644 index 000000000..fc60279a4 --- /dev/null +++ b/packages/backend/src/transformers/embeddings/openai.ts @@ -0,0 +1,37 @@ +import { UnifiedEmbeddingsRequest, UnifiedEmbeddingsResponse } from '../../types/unified'; +import { EmbeddingsTransformer } from '../../types/embeddings-transformer'; + +export class OpenAIEmbeddingsTransformer implements EmbeddingsTransformer { + readonly name = 'openai'; + readonly defaultEndpoint = '/embeddings'; + + async transformRequest(request: UnifiedEmbeddingsRequest): Promise { + return { + ...request.originalBody, + model: request.model, + }; + } + + async transformResponse(response: any): Promise { + return { + ...response, + object: 'list', + data: response.data, + model: response.model, + usage: response.usage, + }; + } + + async formatResponse(response: UnifiedEmbeddingsResponse): Promise { + return { + object: response.object, + data: response.data, + model: response.model, + usage: response.usage, + }; + } + + extractUsage(_eventData: string): undefined { + return undefined; + } +} diff --git a/packages/backend/src/types/embeddings-transformer.ts b/packages/backend/src/types/embeddings-transformer.ts new file mode 100644 index 000000000..26f86f344 --- /dev/null +++ b/packages/backend/src/types/embeddings-transformer.ts @@ -0,0 +1,29 @@ +import { UnifiedEmbeddingsRequest, UnifiedEmbeddingsResponse } from './unified'; + +/** + * Embeddings transformer interface for provider-type-aware request/response transformation. + * + * Each embeddings provider type implements this interface to handle: + * - URL construction (e.g. Gemini /v1beta/models/{model}:embedContent) + * - Auth headers (e.g. x-goog-api-key for Gemini) + * - Request/response transformation between OpenAI and provider formats + */ +export interface EmbeddingsTransformer { + readonly name: string; + readonly defaultEndpoint: string; + + getEndpoint?(request: UnifiedEmbeddingsRequest): string; + + getAuthHeaders?(apiKey: string, headers: Record): void; + + transformRequest(request: UnifiedEmbeddingsRequest): Promise; + + transformResponse( + response: any, + request?: UnifiedEmbeddingsRequest + ): Promise; + + formatResponse(response: UnifiedEmbeddingsResponse): Promise; + + extractUsage(eventData: string): { prompt_tokens?: number } | undefined; +} diff --git a/packages/backend/src/types/unified.ts b/packages/backend/src/types/unified.ts index d12960e44..dd139c90b 100644 --- a/packages/backend/src/types/unified.ts +++ b/packages/backend/src/types/unified.ts @@ -285,7 +285,7 @@ export interface UnifiedEmbeddingsResponse { index: number; }>; model: string; - usage: { + usage?: { prompt_tokens: number; total_tokens: number; };