diff --git a/.changeset/plenty-cooks-hide.md b/.changeset/plenty-cooks-hide.md new file mode 100644 index 000000000..8438a8445 --- /dev/null +++ b/.changeset/plenty-cooks-hide.md @@ -0,0 +1,5 @@ +--- +"@cloudflare/tanstack-ai": patch +--- + +The Workers AI adapter for TanStack AI does not require an API key diff --git a/examples/tanstack-ai/worker/index.ts b/examples/tanstack-ai/worker/index.ts index 454ae4efc..39f85bf1f 100644 --- a/examples/tanstack-ai/worker/index.ts +++ b/examples/tanstack-ai/worker/index.ts @@ -155,13 +155,12 @@ function workersAiGatewayConfig(creds: RequestCredentials) { if (creds.useBinding) { return { binding: env.AI.gateway(resolveGatewayId(creds)), - apiKey: env.CLOUDFLARE_API_TOKEN, }; } if (creds.cloudflare) { - return { ...gwRestConfig(creds), apiKey: creds.cloudflare.apiToken }; + return gwRestConfig(creds); } - return { binding: env.AI.gateway(resolveGatewayId(creds)), apiKey: env.CLOUDFLARE_API_TOKEN }; + return { binding: env.AI.gateway(resolveGatewayId(creds)) }; } // --------------------------------------------------------------------------- diff --git a/packages/tanstack-ai/src/utils/create-fetcher.ts b/packages/tanstack-ai/src/utils/create-fetcher.ts index 0ca2c11ed..2ce0dabea 100644 --- a/packages/tanstack-ai/src/utils/create-fetcher.ts +++ b/packages/tanstack-ai/src/utils/create-fetcher.ts @@ -199,7 +199,9 @@ export function createGatewayFetch( }; if (provider === "workers-ai") { - request.endpoint = query.model as string; + if (!request.endpoint.startsWith("run/")) { + request.endpoint = `run/${query.model}`; + } delete query.model; delete query.instructions; } diff --git a/packages/tanstack-ai/test/gateway-fetch.test.ts b/packages/tanstack-ai/test/gateway-fetch.test.ts index 09be90e68..e22ea4329 100644 --- a/packages/tanstack-ai/test/gateway-fetch.test.ts +++ b/packages/tanstack-ai/test/gateway-fetch.test.ts @@ -332,7 +332,7 @@ describe("createGatewayFetch", () => { const request = mockBinding.run.mock.calls[0]![0]; 
expect(request.provider).toBe("workers-ai"); - expect(request.endpoint).toBe("@cf/meta/llama-3.3-70b-instruct-fp8-fast"); + expect(request.endpoint).toBe("run/@cf/meta/llama-3.3-70b-instruct-fp8-fast"); expect(request.query.model).toBeUndefined(); expect(request.query.messages).toEqual([{ role: "user", content: "Hello" }]); }); @@ -357,6 +357,44 @@ describe("createGatewayFetch", () => { expect(request.query.instructions).toBeUndefined(); expect(request.query.messages).toEqual([]); }); + + it("should preserve endpoint when it already starts with run/", async () => { + const config: AiGatewayAdapterConfig = { + binding: mockBinding, + apiKey: "test-key", + }; + const fetcher = createGatewayFetch("workers-ai", config); + + await fetcher("https://gateway.ai.cloudflare.com/v1/run/@cf/meta/llama-3.3-70b-instruct-fp8-fast", { + method: "POST", + body: JSON.stringify({ + model: "run/@cf/meta/llama-3.3-70b-instruct-fp8-fast", + messages: [{ role: "user", content: "Hello" }], + }), + }); + + const request = mockBinding.run.mock.calls[0]![0]; + expect(request.endpoint).toBe("run/@cf/meta/llama-3.3-70b-instruct-fp8-fast"); + }); + + it("should prepend run/ when endpoint does not start with run/", async () => { + const config: AiGatewayAdapterConfig = { + binding: mockBinding, + apiKey: "test-key", + }; + const fetcher = createGatewayFetch("workers-ai", config); + + await fetcher("https://api.openai.com/v1/chat/completions", { + method: "POST", + body: JSON.stringify({ + model: "@cf/meta/llama-3.3-70b-instruct-fp8-fast", + messages: [{ role: "user", content: "Hello" }], + }), + }); + + const request = mockBinding.run.mock.calls[0]![0]; + expect(request.endpoint).toBe("run/@cf/meta/llama-3.3-70b-instruct-fp8-fast"); + }); }); describe("endpoint extraction", () => { diff --git a/packages/tanstack-ai/test/gateway-urls.test.ts b/packages/tanstack-ai/test/gateway-urls.test.ts index 4bea05623..ebaa1595f 100644 --- a/packages/tanstack-ai/test/gateway-urls.test.ts +++ 
b/packages/tanstack-ai/test/gateway-urls.test.ts @@ -61,7 +61,7 @@ describe("Workers AI gateway URL verification", () => { const body = JSON.parse((init as any).body as string); expect(body.provider).toBe("workers-ai"); // createGatewayFetch moves model from query to endpoint for workers-ai - expect(body.endpoint).toBe("@cf/stabilityai/stable-diffusion-xl-base-1.0"); + expect(body.endpoint).toBe("run/@cf/stabilityai/stable-diffusion-xl-base-1.0"); expect(body.query.prompt).toBe("test prompt"); }); @@ -89,7 +89,7 @@ describe("Workers AI gateway URL verification", () => { const body = JSON.parse((init as any).body as string); expect(body.provider).toBe("workers-ai"); - expect(body.endpoint).toBe("@cf/openai/whisper"); + expect(body.endpoint).toBe("run/@cf/openai/whisper"); }); it("TTS adapter sends model name in body and hits gateway URL", async () => { @@ -107,7 +107,7 @@ describe("Workers AI gateway URL verification", () => { const body = JSON.parse((init as any).body as string); expect(body.provider).toBe("workers-ai"); - expect(body.endpoint).toBe("@cf/deepgram/aura-1"); + expect(body.endpoint).toBe("run/@cf/deepgram/aura-1"); expect(body.query.text).toBe("Hello world"); }); @@ -137,7 +137,7 @@ describe("Workers AI gateway URL verification", () => { const body = JSON.parse((init as any).body as string); expect(body.provider).toBe("workers-ai"); - expect(body.endpoint).toBe("@cf/facebook/bart-large-cnn"); + expect(body.endpoint).toBe("run/@cf/facebook/bart-large-cnn"); expect(body.query.input_text).toBe("A long article..."); });