diff --git a/src/__tests__/config.test.ts b/src/__tests__/config.test.ts
index f3e5f67..c57fba6 100644
--- a/src/__tests__/config.test.ts
+++ b/src/__tests__/config.test.ts
@@ -29,6 +29,11 @@ describe("config", () => {
     expect(getConfig().ai?.gateway).toBe("openrouter");
   });
 
+  it("configure sets ai.gateway to litellm", () => {
+    configure({ ai: { gateway: "litellm" } });
+    expect(getConfig().ai?.gateway).toBe("litellm");
+  });
+
   it("configure merges without overwriting other keys", () => {
     configure({ uploadBasePath: "./uploads" });
     configure({ ai: { gateway: "none" } });
diff --git a/src/__tests__/models.test.ts b/src/__tests__/models.test.ts
new file mode 100644
index 0000000..ce333a3
--- /dev/null
+++ b/src/__tests__/models.test.ts
@@ -0,0 +1,50 @@
+import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
+import { resolveModel } from "../models";
+import { configure, resetConfig } from "../config";
+import { ConfigurationError } from "../errors";
+
+describe("models", () => {
+  const originalEnv = process.env;
+
+  beforeEach(() => {
+    vi.resetModules();
+    process.env = { ...originalEnv };
+    resetConfig();
+  });
+
+  afterEach(() => {
+    process.env = originalEnv;
+  });
+
+  it("throws ConfigurationError if litellm gateway is used but LITELLM_BASE_URL is missing", () => {
+    configure({ ai: { gateway: "litellm" } });
+    delete process.env.LITELLM_BASE_URL;
+
+    expect(() => resolveModel("google/gemini-3-flash")).toThrow(ConfigurationError);
+    expect(() => resolveModel("google/gemini-3-flash")).toThrow(/LITELLM_BASE_URL isn't set/);
+  });
+
+  it("resolves model using litellm gateway when LITELLM_BASE_URL is provided", () => {
+    configure({ ai: { gateway: "litellm" } });
+    process.env.LITELLM_BASE_URL = "http://localhost:4000/v1";
+    // It should not throw and should return a wrapped language model
+    const model = resolveModel("google/gemini-3-flash");
+    expect(model).toBeDefined();
+    // Wrapped AI SDK language models are objects, not plain strings
+    expect(typeof model).toBe("object");
+  });
+
+  it("remaps google prefixes to gemini for litellm", () => {
+    configure({ ai: { gateway: "litellm" } });
+    process.env.LITELLM_BASE_URL = "http://localhost:4000/v1";
+
+    // We can't easily inspect the internal modelId of the wrapped model in a unit test
+    // without mocking the openai package, but we can verify it doesn't crash
+    const model = resolveModel("google/gemini-3-flash");
+    expect(model).toBeDefined();
+
+    // Non-google prefixes should pass through unchanged
+    const bedrockModel = resolveModel("bedrock/claude-3-sonnet");
+    expect(bedrockModel).toBeDefined();
+  });
+});
diff --git a/src/config.ts b/src/config.ts
index 40da346..dfc03ee 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -9,7 +9,7 @@ export type EmailProvider = {
   extractContent: (params: { email: string; prompt: string }) => Promise;
 };
 
-export type AIGateway = "vercel" | "openrouter" | "cloudflare" | "none";
+export type AIGateway = "vercel" | "openrouter" | "cloudflare" | "litellm" | "none";
 
 /**
  * Execution mode for browser automation.
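
For reference, a minimal sketch of how a caller opts into the new "litellm" member of the AIGateway union, using the same configure/getConfig API exercised in config.test.ts above; the import specifier mirrors the tests and would differ for an external consumer:

import { configure, getConfig } from "../config";

// Select the LiteLLM gateway; other config keys are merged, not overwritten.
configure({ ai: { gateway: "litellm" } });

console.log(getConfig().ai?.gateway); // "litellm"
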
diff --git a/src/models.ts b/src/models.ts
index afc6fff..43876c4 100644
--- a/src/models.ts
+++ b/src/models.ts
@@ -1,6 +1,7 @@
 import { AIModelError, ConfigurationError } from "./errors";
 import { createAnthropic } from "@ai-sdk/anthropic";
 import { createGoogleGenerativeAI } from "@ai-sdk/google";
+import { createOpenAI } from "@ai-sdk/openai";
 import { createOpenRouter } from "@openrouter/ai-sdk-provider";
 import { gateway, type LanguageModel } from "ai";
 import { wrapAISDKModel } from "axiom/ai";
@@ -16,12 +17,13 @@ let _anthropic: ReturnType | null = null;
 let _openrouter: ReturnType | null = null;
 let _cloudflareGoogle: ReturnType | null = null;
 let _cloudflareAnthropic: ReturnType | null = null;
+let _litellm: ReturnType | null = null;
 
 function getGoogleProvider() {
   if (!_google) {
     if (!process.env.GOOGLE_GENERATIVE_AI_API_KEY) {
       throw new ConfigurationError(
-        "GOOGLE_GENERATIVE_AI_API_KEY isn't set. Add it to your environment (for example: export GOOGLE_GENERATIVE_AI_API_KEY=your_key), or use a gateway: configure({ ai: { gateway: 'vercel' } }) with AI_GATEWAY_API_KEY, configure({ ai: { gateway: 'openrouter' } }) with OPENROUTER_API_KEY, or configure({ ai: { gateway: 'cloudflare' } }) with CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_AI_GATEWAY, GOOGLE_GENERATIVE_AI_API_KEY, and CLOUDFLARE_AI_GATEWAY_API_KEY. See .env.example for reference.",
+        "GOOGLE_GENERATIVE_AI_API_KEY isn't set. Add it to your environment (for example: export GOOGLE_GENERATIVE_AI_API_KEY=your_key), or use a gateway: configure({ ai: { gateway: 'vercel' } }) with AI_GATEWAY_API_KEY, configure({ ai: { gateway: 'openrouter' } }) with OPENROUTER_API_KEY, configure({ ai: { gateway: 'litellm' } }) with LITELLM_BASE_URL, or configure({ ai: { gateway: 'cloudflare' } }) with CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_AI_GATEWAY, GOOGLE_GENERATIVE_AI_API_KEY, and CLOUDFLARE_AI_GATEWAY_API_KEY. See .env.example for reference.",
       );
     }
     _google = createGoogleGenerativeAI({
@@ -35,7 +37,7 @@ function getAnthropicProvider() {
   if (!_anthropic) {
     if (!process.env.ANTHROPIC_API_KEY) {
       throw new ConfigurationError(
-        "ANTHROPIC_API_KEY isn't set. Add it to your environment (for example: export ANTHROPIC_API_KEY=your_key), or use a gateway: configure({ ai: { gateway: 'vercel' } }) with AI_GATEWAY_API_KEY, configure({ ai: { gateway: 'openrouter' } }) with OPENROUTER_API_KEY, or configure({ ai: { gateway: 'cloudflare' } }) with CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_AI_GATEWAY, ANTHROPIC_API_KEY, and CLOUDFLARE_AI_GATEWAY_API_KEY. See .env.example for reference.",
+        "ANTHROPIC_API_KEY isn't set. Add it to your environment (for example: export ANTHROPIC_API_KEY=your_key), or use a gateway: configure({ ai: { gateway: 'vercel' } }) with AI_GATEWAY_API_KEY, configure({ ai: { gateway: 'openrouter' } }) with OPENROUTER_API_KEY, configure({ ai: { gateway: 'litellm' } }) with LITELLM_BASE_URL, or configure({ ai: { gateway: 'cloudflare' } }) with CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_AI_GATEWAY, ANTHROPIC_API_KEY, and CLOUDFLARE_AI_GATEWAY_API_KEY. See .env.example for reference.",
       );
     }
     _anthropic = createAnthropic({
@@ -59,6 +61,21 @@ function getOpenRouterProvider() {
   return _openrouter;
 }
 
+function getLiteLLMProvider() {
+  if (!_litellm) {
+    if (!process.env.LITELLM_BASE_URL) {
+      throw new ConfigurationError(
+        "LITELLM_BASE_URL isn't set. To use the LiteLLM gateway, add LITELLM_BASE_URL to your environment (e.g. export LITELLM_BASE_URL=http://localhost:4000/v1). You may also need to set LITELLM_API_KEY.",
+      );
+    }
+    _litellm = createOpenAI({
+      apiKey: process.env.LITELLM_API_KEY || "dummy-key",
+      baseURL: process.env.LITELLM_BASE_URL,
+    });
+  }
+  return _litellm;
+}
+
 /**
  * Builds the per-provider Cloudflare AI Gateway base URL and (optional)
  * `cf-aig-authorization` header. We route through Cloudflare's native
@@ -148,6 +165,13 @@ function resolveOpenRouterModelId(modelId: string): string {
   return OPENROUTER_MODEL_ALIASES[modelId] ?? modelId;
 }
 
+function resolveLiteLLMModelId(modelId: string): string {
+  if (modelId.startsWith("google/")) {
+    return modelId.replace("google/", "gemini/");
+  }
+  return modelId;
+}
+
 /**
  * Resolves a canonical model ID to a LanguageModel instance wrapped with Axiom instrumentation.
 * Input format: "provider/model-name" (e.g. "google/gemini-3-flash")
@@ -180,6 +204,10 @@ export function resolveModel(modelId: string): LanguageModel {
     return wrapModel(getOpenRouterProvider()(resolveOpenRouterModelId(modelId)));
   }
 
+  if (gatewayConfig === "litellm") {
+    return wrapModel(getLiteLLMProvider()(resolveLiteLLMModelId(modelId)));
+  }
+
   const [provider, ...rest] = modelId.split("/");
   const modelName = rest.join("/");
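
For reference, a sketch of the end-to-end flow through the new LiteLLM path, assuming a LiteLLM proxy is reachable at the base URL shown; the environment values are illustrative and the imports mirror the test file:

// Environment (illustrative):
//   export LITELLM_BASE_URL=http://localhost:4000/v1
//   export LITELLM_API_KEY=your_key   # optional; "dummy-key" is sent when unset

import { configure } from "../config";
import { resolveModel } from "../models";

configure({ ai: { gateway: "litellm" } });

// resolveLiteLLMModelId remaps the "google/" prefix to "gemini/", so this is
// forwarded to LiteLLM as "gemini/gemini-3-flash"; other prefixes such as
// "bedrock/claude-3-sonnet" pass through unchanged.
const model = resolveModel("google/gemini-3-flash");

Because a LiteLLM proxy exposes an OpenAI-compatible API, pointing createOpenAI at LITELLM_BASE_URL is sufficient; no provider-specific SDK wiring is needed for models served through the proxy.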