diff --git a/.changeset/ai-gateway-embeddings.md b/.changeset/ai-gateway-embeddings.md
new file mode 100644
index 000000000..02dd49a4b
--- /dev/null
+++ b/.changeset/ai-gateway-embeddings.md
@@ -0,0 +1,36 @@
+---
+"ai-gateway-provider": minor
+---
+
+Add embedding model support to AI Gateway provider.
+
+This allows users to route embedding requests through AI Gateway, enabling:
+- Caching for embeddings via `cacheTtl` and `cacheKey` options
+- Request logging via `collectLog` option
+- Retry configuration via `retries` option
+- Metadata tracking via `metadata` option
+
+Usage:
+```typescript
+import { createAiGateway } from "ai-gateway-provider";
+import { createOpenAI } from "@ai-sdk/openai";
+import { embed } from "ai";
+
+const aigateway = createAiGateway({
+	accountId: "your-account-id",
+	apiKey: "your-api-key",
+	gateway: "your-gateway",
+});
+
+const openai = createOpenAI({ apiKey: "your-openai-key" });
+
+const result = await embed({
+	model: aigateway.embedding(openai.embedding("text-embedding-3-small")),
+	value: "Hello, world!",
+});
+```
+
+New methods on the AI Gateway provider:
+- `embedding()` - Create an embedding model routed through AI Gateway
+- `textEmbedding()` - Alias for `embedding()`
+- `textEmbeddingModel()` - Alias for `embedding()`
diff --git a/packages/ai-gateway-provider/src/ai-gateway-embedding-model.ts b/packages/ai-gateway-provider/src/ai-gateway-embedding-model.ts
new file mode 100644
index 000000000..c9b203331
--- /dev/null
+++ b/packages/ai-gateway-provider/src/ai-gateway-embedding-model.ts
@@ -0,0 +1,264 @@
+import type { EmbeddingModelV3 } from "@ai-sdk/provider";
+import type { FetchFunction } from "@ai-sdk/provider-utils";
+import { CF_TEMP_TOKEN } from "./auth";
+import { providers } from "./providers";
+import {
+	AiGatewayDoesNotExist,
+	AiGatewayInternalFetchError,
+	AiGatewayUnauthorizedError,
+	parseAiGatewayOptions,
+	streamToObject,
+	type AiGatewaySettings,
+} from "./shared";
+
+/** Embedding model whose optional `config.fetch` hook this wrapper overrides. */
+type InternalEmbeddingModelV3 = EmbeddingModelV3 & {
+	config?: { fetch?: FetchFunction | undefined };
+};
+
+/** JSON error body inspected when the gateway answers with status 400 or 401. */
+type GatewayErrorBody = {
+	success?: boolean;
+	error?: { code: number; message: string }[];
+};
+
+/**
+ * Embedding model that sends the wrapped models' requests through Cloudflare AI Gateway.
+ *
+ * Each wrapped model's `config.fetch` is replaced with an interceptor that records the
+ * outgoing request instead of sending it. The recorded requests are submitted to the
+ * gateway (REST endpoint or binding) as one array, and the gateway response is replayed
+ * through the model at index `cf-aig-step` so that model parses the result itself.
+ *
+ * NOTE(review): `config.fetch` is mutated on the shared model instances, so overlapping
+ * `doEmbed` calls on the same wrapped model look like they could capture each other's
+ * requests — confirm before relying on `supportsParallelCalls`.
+ */
+export class AiGatewayEmbeddingModel implements EmbeddingModelV3 {
+	readonly specificationVersion = "v3";
+
+	readonly models: InternalEmbeddingModelV3[];
+	readonly config: AiGatewaySettings;
+
+	/** Model id reported by the first wrapped model. */
+	get modelId(): string {
+		return this.primaryModel.modelId;
+	}
+
+	/** Provider name reported by the first wrapped model. */
+	get provider(): string {
+		return this.primaryModel.provider;
+	}
+
+	/** `maxEmbeddingsPerCall` of the first wrapped model. */
+	get maxEmbeddingsPerCall(): PromiseLike<number | undefined> | number | undefined {
+		return this.primaryModel.maxEmbeddingsPerCall;
+	}
+
+	/** `supportsParallelCalls` of the first wrapped model. */
+	get supportsParallelCalls(): PromiseLike<boolean> | boolean {
+		return this.primaryModel.supportsParallelCalls;
+	}
+
+	/**
+	 * @param models Models to route through the gateway; the first one backs the getters above.
+	 * @param config Gateway settings: REST credentials or a binding, plus optional gateway options.
+	 */
+	constructor(models: EmbeddingModelV3[], config: AiGatewaySettings) {
+		this.models = models;
+		this.config = config;
+	}
+
+	/**
+	 * First wrapped model.
+	 *
+	 * @throws Error when `models` is empty.
+	 */
+	private get primaryModel(): InternalEmbeddingModelV3 {
+		const [first] = this.models;
+		if (!first) {
+			throw new Error("models cannot be empty array");
+		}
+
+		return first;
+	}
+
+	/**
+	 * Runs the embedding call through AI Gateway.
+	 *
+	 * @param options Call options, passed unchanged to the wrapped models.
+	 * @returns The selected model's result, with `warnings` defaulted to `[]`.
+	 * @throws Error when a wrapped model has no `config.fetch` key, no entry in `providers`
+	 *   matches a captured URL, a captured request has no body, or `cf-aig-step` does not
+	 *   index a wrapped model.
+	 * @throws AiGatewayDoesNotExist on a 400 response with gateway error code 2001.
+	 * @throws AiGatewayUnauthorizedError on a 401 response with gateway error code 2009.
+	 */
+	async doEmbed(
+		options: Parameters<EmbeddingModelV3["doEmbed"]>[0],
+	): Promise<Awaited<ReturnType<EmbeddingModelV3["doEmbed"]>>> {
+		const requests: { url: string; request: Request; modelProvider: string }[] = [];
+
+		// Model configuration and request collection
+		for (const model of this.models) {
+			if (!model.config || !Object.keys(model.config).includes("fetch")) {
+				throw new Error(
+					`Sorry, but provider "${model.provider}" is currently not supported for embeddings, please open an issue in the github repo!`,
+				);
+			}
+
+			// Record the request the provider would have sent, then abort the provider call.
+			model.config.fetch = (url, request) => {
+				requests.push({
+					modelProvider: model.provider,
+					request: request as Request,
+					url: url as string,
+				});
+				throw new AiGatewayInternalFetchError("Stopping provider execution...");
+			};
+
+			try {
+				await model.doEmbed(options);
+			} catch (e) {
+				// Only the interceptor's own sentinel error is expected here.
+				if (!(e instanceof AiGatewayInternalFetchError)) {
+					throw e;
+				}
+			}
+		}
+
+		// Process requests
+		const body = await Promise.all(
+			requests.map(async (req) => {
+				// No `break`: when several provider patterns match, the last one wins.
+				let providerConfig = null;
+				for (const provider of providers) {
+					if (provider.regex.test(req.url)) {
+						providerConfig = provider;
+					}
+				}
+
+				if (!providerConfig) {
+					throw new Error(
+						`Sorry, but provider "${req.modelProvider}" is currently not supported for embeddings, please open an issue in the github repo!`,
+					);
+				}
+
+				if (!req.request.body) {
+					throw new Error("AI Gateway provider received an unexpected empty body");
+				}
+
+				// For AI Gateway BYOK / unified billing requests
+				// delete the fake injected CF_TEMP_TOKEN
+				const authHeader = providerConfig.headerKey ?? "authorization";
+				const authValue =
+					"get" in req.request.headers
+						? req.request.headers.get(authHeader)
+						: req.request.headers[authHeader];
+				if (authValue?.includes(CF_TEMP_TOKEN)) {
+					if ("delete" in req.request.headers) {
+						req.request.headers.delete(authHeader);
+					} else {
+						delete req.request.headers[authHeader];
+					}
+				}
+
+				return {
+					endpoint: providerConfig.transformEndpoint(req.url),
+					headers: req.request.headers,
+					provider: providerConfig.name,
+					query: await streamToObject(req.request.body),
+				};
+			}),
+		);
+
+		// Handle response
+		const headers = parseAiGatewayOptions(this.config.options ?? {});
+		let resp: Response;
+
+		if ("binding" in this.config) {
+			const updatedBody = body.map((obj) => ({
+				...obj,
+				headers: {
+					...(obj.headers ?? {}),
+					...Object.fromEntries(headers.entries()),
+				},
+			}));
+			resp = await this.config.binding.run(updatedBody);
+		} else {
+			headers.set("Content-Type", "application/json");
+			// apiKey is optional; skip the header rather than sending "Bearer undefined".
+			if (this.config.apiKey) {
+				headers.set("cf-aig-authorization", `Bearer ${this.config.apiKey}`);
+			}
+			resp = await fetch(
+				`https://gateway.ai.cloudflare.com/v1/${this.config.accountId}/${this.config.gateway}`,
+				{
+					body: JSON.stringify(body),
+					headers: headers,
+					method: "POST",
+				},
+			);
+		}
+
+		// Error handling
+		await this.throwIfGatewayError(resp);
+
+		const step = Number.parseInt(resp.headers.get("cf-aig-step") ?? "0", 10);
+		const selected = this.models[step];
+		if (!selected) {
+			throw new Error("Unexpected AI Gateway Error");
+		}
+
+		// Replay the gateway response through the selected model so it parses its own format.
+		selected.config = {
+			...selected.config,
+			fetch: async (_url, _req) => resp,
+		};
+
+		const result = await selected.doEmbed(options);
+
+		// Ensure V3 compliance: warnings field is required
+		return {
+			embeddings: result.embeddings,
+			usage: result.usage,
+			providerMetadata: result.providerMetadata,
+			response: result.response,
+			warnings: result.warnings ?? [],
+		};
+	}
+
+	/**
+	 * Maps the gateway's own 400/401 failures to dedicated error classes.
+	 *
+	 * Reads a clone so `resp` stays unread for the provider. Non-JSON bodies and other
+	 * error codes fall through to the provider's own response handling.
+	 */
+	private async throwIfGatewayError(resp: Response): Promise<void> {
+		if (resp.status !== 400 && resp.status !== 401) {
+			return;
+		}
+
+		let payload: GatewayErrorBody | null;
+		try {
+			payload = await resp.clone().json();
+		} catch {
+			// Not JSON, so not a gateway error envelope.
+			return;
+		}
+
+		if (!payload || payload.success !== false) {
+			return;
+		}
+
+		const code = payload.error?.[0]?.code;
+		if (resp.status === 400 && code === 2001) {
+			throw new AiGatewayDoesNotExist("This AI gateway does not exist");
+		}
+		if (resp.status === 401 && code === 2009) {
+			throw new AiGatewayUnauthorizedError(
+				"Your AI Gateway has authentication active, but you didn't provide a valid apiKey",
+			);
+		}
+	}
+}
diff --git a/packages/ai-gateway-provider/src/index.ts b/packages/ai-gateway-provider/src/index.ts
index c9fa82bae..0734008d6 100644
--- a/packages/ai-gateway-provider/src/index.ts
+++ b/packages/ai-gateway-provider/src/index.ts
@@ -1,18 +1,36 @@
-import type { LanguageModelV3 } from "@ai-sdk/provider";
+import type { EmbeddingModelV3, LanguageModelV3 } from "@ai-sdk/provider";
 import type { FetchFunction } from "@ai-sdk/provider-utils";
 import { CF_TEMP_TOKEN } from "./auth";
 import { providers } from "./providers";
+import { AiGatewayEmbeddingModel } from "./ai-gateway-embedding-model";
+import {
+	AiGatewayDoesNotExist,
+	AiGatewayInternalFetchError,
+	AiGatewayUnauthorizedError,
+	parseAiGatewayOptions,
+	streamToObject,
+	type AiGatewayAPISettings,
+	type AiGatewayBindingSettings,
+	type AiGatewayOptions,
+	type AiGatewayRetries,
+	type AiGatewaySettings,
+} from "./shared";
+
+// Re-export errors and types from shared
+export {
+	AiGatewayDoesNotExist,
+	AiGatewayInternalFetchError,
+	AiGatewayUnauthorizedError,
+	parseAiGatewayOptions,
+	type AiGatewayAPISettings,
+	type AiGatewayBindingSettings,
+	type AiGatewayOptions,
+	type AiGatewayRetries,
+	type AiGatewaySettings,
+};

-export class AiGatewayInternalFetchError extends Error {}
-
-export class AiGatewayDoesNotExist extends Error {}
-
-export class AiGatewayUnauthorizedError extends Error {}
-
-async function streamToObject(stream: ReadableStream) {
-	const response = new Response(stream);
-	return await response.json();
-}
+// Re-export embedding model
+export { AiGatewayEmbeddingModel };

 type InternalLanguageModelV3 = LanguageModelV3 & {
 	config?: { fetch?: FetchFunction | undefined };
@@ -218,91 +236,34 @@ export interface AiGateway {
 	(models: LanguageModelV3 | LanguageModelV3[]): LanguageModelV3;

 	chat(models: LanguageModelV3 | LanguageModelV3[]): LanguageModelV3;
+
+	embedding(models: EmbeddingModelV3 | EmbeddingModelV3[]): EmbeddingModelV3;
+
+	textEmbedding(models: EmbeddingModelV3 |
EmbeddingModelV3[]): EmbeddingModelV3;
+
+	textEmbeddingModel(models: EmbeddingModelV3 | EmbeddingModelV3[]): EmbeddingModelV3;
 }

-export type AiGatewayReties = {
-	maxAttempts?: 1 | 2 | 3 | 4 | 5;
-	retryDelayMs?: number;
-	backoff?: "constant" | "linear" | "exponential";
-};
-export type AiGatewayOptions = {
-	cacheKey?: string;
-	cacheTtl?: number;
-	skipCache?: boolean;
-	metadata?: Record<string, number | string | boolean | null | bigint>;
-	collectLog?: boolean;
-	eventId?: string;
-	requestTimeoutMs?: number;
-	retries?: AiGatewayReties;
-};
-export type AiGatewayAPISettings = {
-	gateway: string;
-	accountId: string;
-	apiKey?: string;
-	options?: AiGatewayOptions;
-};
-export type AiGatewayBindingSettings = {
-	binding: {
-		run(data: unknown): Promise<Response>;
-	};
-	options?: AiGatewayOptions;
-};
-export type AiGatewaySettings = AiGatewayAPISettings | AiGatewayBindingSettings;
+/**
+ * @deprecated Use `AiGatewayRetries` instead
+ */
+export type AiGatewayReties = AiGatewayRetries;

 export function createAiGateway(options: AiGatewaySettings): AiGateway {
 	const createChatModel = (models: LanguageModelV3 | LanguageModelV3[]) => {
 		return new AiGatewayChatLanguageModel(Array.isArray(models) ? models : [models], options);
 	};

+	const createEmbeddingModel = (models: EmbeddingModelV3 | EmbeddingModelV3[]) => {
+		return new AiGatewayEmbeddingModel(Array.isArray(models) ? models : [models], options);
+	};
+
 	const provider = (models: LanguageModelV3 | LanguageModelV3[]) => createChatModel(models);

 	provider.chat = createChatModel;
+	provider.embedding = createEmbeddingModel;
+	provider.textEmbedding = createEmbeddingModel;
+	provider.textEmbeddingModel = createEmbeddingModel;

 	return provider;
 }
-
-export function parseAiGatewayOptions(options: AiGatewayOptions): Headers {
-	const headers = new Headers();
-
-	if (options.skipCache === true) {
-		headers.set("cf-skip-cache", "true");
-	}
-
-	if (options.cacheTtl) {
-		headers.set("cf-cache-ttl", options.cacheTtl.toString());
-	}
-
-	if (options.metadata) {
-		headers.set("cf-aig-metadata", JSON.stringify(options.metadata));
-	}
-
-	if (options.cacheKey) {
-		headers.set("cf-aig-cache-key", options.cacheKey);
-	}
-
-	if (options.collectLog !== undefined) {
-		headers.set("cf-aig-collect-log", options.collectLog === true ? "true" : "false");
-	}
-
-	if (options.eventId !== undefined) {
-		headers.set("cf-aig-event-id", options.eventId);
-	}
-
-	if (options.requestTimeoutMs !== undefined) {
-		headers.set("cf-aig-request-timeout", options.requestTimeoutMs.toString());
-	}
-
-	if (options.retries !== undefined) {
-		if (options.retries.maxAttempts !== undefined) {
-			headers.set("cf-aig-max-attempts", options.retries.maxAttempts.toString());
-		}
-		if (options.retries.retryDelayMs !== undefined) {
-			headers.set("cf-aig-retry-delay", options.retries.retryDelayMs.toString());
-		}
-		if (options.retries.backoff !== undefined) {
-			headers.set("cf-aig-backoff", options.retries.backoff);
-		}
-	}
-
-	return headers;
-}
diff --git a/packages/ai-gateway-provider/src/shared.ts b/packages/ai-gateway-provider/src/shared.ts
new file mode 100644
index 000000000..cef6ff4de
--- /dev/null
+++ b/packages/ai-gateway-provider/src/shared.ts
@@ -0,0 +1,106 @@
+/** Sentinel thrown by the intercepting `fetch` to stop a provider call once its request is captured. */
+export class AiGatewayInternalFetchError extends Error {}
+
+/** Raised for a 400 gateway response whose first error has code 2001. */
+export class AiGatewayDoesNotExist extends Error {}
+
+/** Raised for a 401 gateway response whose first error has code 2009. */
+export class AiGatewayUnauthorizedError extends Error {}
+
+/** Reads `stream` to the end and parses the content as JSON. */
+export async function streamToObject(stream: ReadableStream) {
+	const response = new Response(stream);
+	return await response.json();
+}
+
+/** Retry settings, sent as the `cf-aig-max-attempts`, `cf-aig-retry-delay` and `cf-aig-backoff` headers. */
+export type AiGatewayRetries = {
+	maxAttempts?: 1 | 2 | 3 | 4 | 5;
+	retryDelayMs?: number;
+	backoff?: "constant" | "linear" | "exponential";
+};
+
+/** Gateway request options; `parseAiGatewayOptions` turns them into request headers. */
+export type AiGatewayOptions = {
+	cacheKey?: string;
+	cacheTtl?: number;
+	skipCache?: boolean;
+	metadata?: Record<string, number | string | boolean | null | bigint>;
+	collectLog?: boolean;
+	eventId?: string;
+	requestTimeoutMs?: number;
+	retries?: AiGatewayRetries;
+};
+
+/** Settings for calling the gateway over its REST endpoint. */
+export type AiGatewayAPISettings = {
+	gateway: string;
+	accountId: string;
+	apiKey?: string;
+	options?: AiGatewayOptions;
+};
+
+/** Settings for calling the gateway through a binding object that exposes `run`. */
+export type AiGatewayBindingSettings = {
+	binding: {
+		run(data: unknown): Promise<Response>;
+	};
+	options?: AiGatewayOptions;
+};
+
+export type AiGatewaySettings = AiGatewayAPISettings | AiGatewayBindingSettings;
+
+/**
+ * Builds the gateway request headers for `options`.
+ *
+ * Only options that are set produce a header. `skipCache` is emitted only when `true`, and
+ * falsy `cacheTtl`, `metadata` and `cacheKey` values (for example `0` or `""`) are skipped.
+ *
+ * @param options Gateway options to encode.
+ * @returns A new `Headers` instance holding the encoded options.
+ */
+export function parseAiGatewayOptions(options: AiGatewayOptions): Headers {
+	const headers = new Headers();
+
+	if (options.skipCache === true) {
+		headers.set("cf-skip-cache", "true");
+	}
+
+	if (options.cacheTtl) {
+		headers.set("cf-cache-ttl", options.cacheTtl.toString());
+	}
+
+	if (options.metadata) {
+		headers.set("cf-aig-metadata", JSON.stringify(options.metadata));
+	}
+
+	if (options.cacheKey) {
+		headers.set("cf-aig-cache-key", options.cacheKey);
+	}
+
+	if (options.collectLog !== undefined) {
+		headers.set("cf-aig-collect-log", options.collectLog === true ? "true" : "false");
+	}
+
+	if (options.eventId !== undefined) {
+		headers.set("cf-aig-event-id", options.eventId);
+	}
+
+	if (options.requestTimeoutMs !== undefined) {
+		headers.set("cf-aig-request-timeout", options.requestTimeoutMs.toString());
+	}
+
+	if (options.retries !== undefined) {
+		if (options.retries.maxAttempts !== undefined) {
+			headers.set("cf-aig-max-attempts", options.retries.maxAttempts.toString());
+		}
+		if (options.retries.retryDelayMs !== undefined) {
+			headers.set("cf-aig-retry-delay", options.retries.retryDelayMs.toString());
+		}
+		if (options.retries.backoff !== undefined) {
+			headers.set("cf-aig-backoff", options.retries.backoff);
+		}
+	}
+
+	return headers;
+}
diff --git a/packages/ai-gateway-provider/test/embeddings.test.ts b/packages/ai-gateway-provider/test/embeddings.test.ts
new file mode 100644
index 000000000..09bd32941
--- /dev/null
+++ b/packages/ai-gateway-provider/test/embeddings.test.ts
@@ -0,0 +1,291 @@
+import { createOpenAI } from "@ai-sdk/openai";
+import { embed, embedMany } from "ai";
+import { HttpResponse, http } from "msw";
+import { setupServer } from "msw/node";
+import { afterAll, afterEach, beforeAll, describe, expect, it } from "vitest";
+import { AiGatewayDoesNotExist, AiGatewayUnauthorizedError, createAiGateway } from "../src";
+
+const TEST_ACCOUNT_ID = "test-account-id";
+const TEST_API_KEY = "test-api-key";
+const TEST_GATEWAY = "my-gateway";
+
+const embedResponse = [0.1, 0.2, 0.3, 0.4, 0.5];
+const embedHandler = http.post(
+	`https://gateway.ai.cloudflare.com/v1/${TEST_ACCOUNT_ID}/${TEST_GATEWAY}`,
+	async () => {
+		return HttpResponse.json({
+			object: "list",
+			data: [
+				{
+					object: "embedding",
+					index: 0,
+					embedding: embedResponse,
+				},
+			],
+			model: "text-embedding-3-small",
+			usage: {
+				prompt_tokens: 10,
+				total_tokens: 10,
+			},
+		});
+	},
+);
+
+const embedManyResponse = [
+	[0.1, 0.2, 0.3, 0.4, 0.5],
+	[0.2, 0.3, 0.4, 0.5, 0.6],
+	[0.3, 0.4, 0.5, 0.6, 0.7],
+];
+const embedManyHandler = http.post(
+	`https://gateway.ai.cloudflare.com/v1/${TEST_ACCOUNT_ID}/${TEST_GATEWAY}`,
+	async () => {
+		return HttpResponse.json({
+			object: "list",
+			data: embedManyResponse.map((embedding, index) => ({
+				object: "embedding",
+				index,
+				embedding,
+			})),
+			model: "text-embedding-3-small",
+			usage: {
+				prompt_tokens: 30,
+				total_tokens: 30,
+			},
+		});
+	},
+);
+
+const server = setupServer(embedHandler); // default handler: replies with the single-embedding payload
+
+describe("Embedding Tests", () => {
+	beforeAll(() => server.listen());
+	afterEach(() => server.resetHandlers()); // drops handlers added with server.use() inside a test
+	afterAll(() => server.close());
+
+	it("should embed a single value", async () => {
+		const aigateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+		});
+		const openai = createOpenAI({ apiKey: TEST_API_KEY });
+
+		const result = await embed({
+			model: aigateway.embedding(openai.embedding("text-embedding-3-small")),
+			value: "Hello, world!",
+		});
+		expect(result.embedding).toEqual(embedResponse);
+	});
+
+	it("should embed multiple values", async () => {
+		server.use(embedManyHandler); // override the default handler for this test only
+
+		const aigateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+		});
+		const openai = createOpenAI({ apiKey: TEST_API_KEY });
+
+		const result = await embedMany({
+			model: aigateway.embedding(openai.embedding("text-embedding-3-small")),
+			values: ["Hello", "World", "Test"],
+		});
+		expect(result.embeddings).toEqual(embedManyResponse);
+	});
+
+	it("should work with textEmbedding alias", async () => {
+		const aigateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+		});
+		const openai = createOpenAI({ apiKey: TEST_API_KEY });
+
+		const result = await embed({
+			model: aigateway.textEmbedding(openai.textEmbedding("text-embedding-3-small")),
+			value: "Hello, world!",
+		});
+		expect(result.embedding).toEqual(embedResponse);
+	});
+
+	it("should work with textEmbeddingModel alias", async () => {
+		const aigateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+		});
+		const openai = createOpenAI({ apiKey: TEST_API_KEY });
+
+		const result = await embed({
+			model: aigateway.textEmbeddingModel(openai.textEmbeddingModel("text-embedding-3-small")),
+			value: "Hello, world!",
+		});
+		expect(result.embedding).toEqual(embedResponse);
+	});
+});
+
+describe("Embedding with Gateway Options", () => {
+	beforeAll(() => server.listen());
+	afterEach(() => server.resetHandlers());
+	afterAll(() => server.close());
+
+	it("should pass gateway options through headers", async () => {
+		let capturedHeaders: Headers | null = null; // assigned by the handler registered below
+
+		server.use(
+			http.post(
+				`https://gateway.ai.cloudflare.com/v1/${TEST_ACCOUNT_ID}/${TEST_GATEWAY}`,
+				async ({ request }) => {
+					capturedHeaders = new Headers(request.headers); // copy of the headers the gateway endpoint received
+					return HttpResponse.json({
+						object: "list",
+						data: [
+							{
+								object: "embedding",
+								index: 0,
+								embedding: embedResponse,
+							},
+						],
+						model: "text-embedding-3-small",
+						usage: {
+							prompt_tokens: 10,
+							total_tokens: 10,
+						},
+					});
+				},
+			),
+		);
+
+		const aigateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+			options: {
+				cacheTtl: 3600,
+				skipCache: true,
+				collectLog: true,
+				metadata: { userId: "test-user" },
+			},
+		});
+		const openai = createOpenAI({ apiKey: TEST_API_KEY });
+
+		await embed({
+			model: aigateway.embedding(openai.embedding("text-embedding-3-small")),
+			value: "Hello, world!",
+		});
+
+		expect(capturedHeaders).not.toBeNull();
+		expect(capturedHeaders?.get("cf-cache-ttl")).toBe("3600");
+		expect(capturedHeaders?.get("cf-skip-cache")).toBe("true");
+		expect(capturedHeaders?.get("cf-aig-collect-log")).toBe("true");
+		expect(capturedHeaders?.get("cf-aig-metadata")).toBe('{"userId":"test-user"}'); // metadata is sent JSON-stringified
+	});
+});
+
+describe("Embedding with Binding", () => { // no msw lifecycle hooks in this suite: the binding's run() supplies the Response
+	it("should work with a binding", async () => {
+		const mockBinding = {
+			run: async () =>
+				new Response(
+					JSON.stringify({
+						object: "list",
+						data: [
+							{
+								object: "embedding",
+								index: 0,
+								embedding: embedResponse,
+							},
+						],
+						model: "text-embedding-3-small",
+						usage: {
+							prompt_tokens: 10,
+							total_tokens: 10,
+						},
+					}),
+					{
+						headers: { "cf-aig-step": "0" }, // step index read by doEmbed to pick the wrapped model; "0" is also its fallback
+					},
+				),
+		};
+
+		const aigateway = createAiGateway({
+			binding: mockBinding,
+		});
+		const openai = createOpenAI({ apiKey: TEST_API_KEY });
+
+		const result = await embed({
+			model: aigateway.embedding(openai.embedding("text-embedding-3-small")),
+			value: "Hello, world!",
+		});
+		expect(result.embedding).toEqual(embedResponse);
+	});
+});
+
+describe("Embedding Error Handling", () => {
+	beforeAll(() => server.listen());
+	afterEach(() => server.resetHandlers());
+	afterAll(() => server.close());
+
+	it("should throw AiGatewayDoesNotExist for 400 status with code 2001", async () => {
+		server.use(
+			http.post(
+				`https://gateway.ai.cloudflare.com/v1/${TEST_ACCOUNT_ID}/${TEST_GATEWAY}`,
+				async () => {
+					return HttpResponse.json(
+						{
+							success: false,
+							error: [{ code: 2001, message: "Gateway not found" }],
+						},
+						{ status: 400 }, // status 400 + code 2001 is the pair doEmbed maps to AiGatewayDoesNotExist
+					);
+				},
+			),
+		);
+
+		const aigateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+		});
+		const openai = createOpenAI({ apiKey: TEST_API_KEY });
+
+		await expect(
+			embed({
+				model: aigateway.embedding(openai.embedding("text-embedding-3-small")),
+				value: "Hello, world!",
+			}),
+		).rejects.toThrow(AiGatewayDoesNotExist);
+	});
+
+	it("should throw AiGatewayUnauthorizedError for 401 status with code 2009", async () => {
+		server.use(
+			http.post(
+				`https://gateway.ai.cloudflare.com/v1/${TEST_ACCOUNT_ID}/${TEST_GATEWAY}`,
+				async () => {
+					return HttpResponse.json(
+						{
+							success: false,
+							error: [{ code: 2009, message: "Unauthorized" }],
+						},
+						{ status: 401 }, // status 401 + code 2009 is the pair doEmbed maps to AiGatewayUnauthorizedError
+					);
+				},
+			),
+		);
+
+		const aigateway = createAiGateway({
+			accountId: TEST_ACCOUNT_ID,
+			apiKey: TEST_API_KEY,
+			gateway: TEST_GATEWAY,
+		});
+		const openai = createOpenAI({ apiKey: TEST_API_KEY });
+
+		await expect(
+			embed({
+				model: aigateway.embedding(openai.embedding("text-embedding-3-small")),
+				value: "Hello, world!",
+			}),
+		).rejects.toThrow(AiGatewayUnauthorizedError);
+	});
+});