Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/loud-months-shake.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"workers-ai-provider": patch
---

Add support for routing Workers AI requests through Cloudflare AI Gateway via a new `gateway` option, including cache-control request headers (`cf-aig-skip-cache`, `cf-aig-cache-ttl`, `cf-aig-cache-key`, `cf-aig-metadata`)
4 changes: 3 additions & 1 deletion examples/workers-ai/src/client/components/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { useChat } from "@ai-sdk/react";
import { DefaultChatTransport } from "ai";
import { useState, useRef, useEffect, useMemo } from "react";
import { useConfig } from "../config";
import { useUniqueId } from "../utils/useUniqueId";
import { chatModels } from "./models";

export function Chat() {
Expand Down Expand Up @@ -30,6 +31,7 @@ export function Chat() {

function ChatSession({ model }: { model: string }) {
const { headers } = useConfig();
const chatId = useUniqueId({ model, headers }, "chat");

const transport = useMemo(
() =>
Expand All @@ -41,7 +43,7 @@ function ChatSession({ model }: { model: string }) {
[model, headers],
);

const { messages, sendMessage, status, error } = useChat({ transport });
const { messages, sendMessage, status, error } = useChat({ id: chatId, transport });
const [input, setInput] = useState("");
const isLoading = status === "streaming" || status === "submitted";

Expand Down
17 changes: 17 additions & 0 deletions examples/workers-ai/src/client/utils/useUniqueId.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { useMemo } from "react";

/**
 * Derives a short, deterministic ID from the provided data object.
 * Useful for creating stable identifiers from model + headers combinations
 * (e.g. keying a chat session so it resets when config changes).
 *
 * @param data - Serializable values to fingerprint. NOTE(review): relies on
 *   JSON.stringify key order, so callers should build `data` with a consistent
 *   property order — TODO confirm all call sites do.
 * @param prefix - Prepended to the hash, defaults to "id".
 * @returns A string of the form `<prefix>-<base36 hash>`.
 */
export function useUniqueId(data: Record<string, unknown>, prefix = "id"): string {
	// Serialize BEFORE useMemo: callers typically pass a fresh object literal
	// each render, so depending on `data` by reference would invalidate the
	// memo every render. The serialized string is a stable dependency for
	// identical contents.
	const serialized = JSON.stringify(data);
	return useMemo(() => {
		// djb2-style 32-bit string hash: hash = hash * 31 + charCode.
		let hash = 0;
		for (let i = 0; i < serialized.length; i++) {
			hash = (hash << 5) - hash + serialized.charCodeAt(i);
			hash |= 0; // clamp to a signed 32-bit integer
		}
		return `${prefix}-${Math.abs(hash).toString(36)}`;
	}, [serialized, prefix]);
}
38 changes: 32 additions & 6 deletions packages/workers-ai-provider/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ export function createRun(config: CreateRunConfig): AiRun {
options?: AiOptions & Record<string, unknown>,
): Promise<Response | ReadableStream<Uint8Array> | AiModels[Name]["postProcessedOutputs"]> {
const {
gateway: _gateway,
gateway,
prefix: _prefix,
extraHeaders: _extraHeaders,
returnRawResponse,
Expand Down Expand Up @@ -137,15 +137,37 @@ export function createRun(config: CreateRunConfig): AiRun {
}

const queryString = urlParams.toString();
const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${
queryString ? `?${queryString}` : ""
}`;

const headers = {
// Build URL: use AI Gateway if gateway option is provided, otherwise direct API
const url = gateway?.id
? `https://gateway.ai.cloudflare.com/v1/${accountId}/${gateway.id}/workers-ai/run/${model}${
queryString ? `?${queryString}` : ""
}`
: `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${model}${
queryString ? `?${queryString}` : ""
}`;

// Build headers with optional gateway cache headers
const headers: Record<string, string> = {
Authorization: `Bearer ${apiKey}`,
"Content-Type": "application/json",
};

if (gateway) {
if (gateway.skipCache) {
headers["cf-aig-skip-cache"] = "true";
}
if (typeof gateway.cacheTtl === "number") {
headers["cf-aig-cache-ttl"] = String(gateway.cacheTtl);
}
if (gateway.cacheKey) {
headers["cf-aig-cache-key"] = gateway.cacheKey;
}
if (gateway.metadata) {
headers["cf-aig-metadata"] = JSON.stringify(gateway.metadata);
}
}

const body = JSON.stringify(inputs);

const response = await fetch(url, {
Expand Down Expand Up @@ -186,8 +208,12 @@ export function createRun(config: CreateRunConfig): AiRun {
// endpoint and return a JSON response with empty result instead of SSE.
// Retry without streaming so doStream's graceful degradation path can
// wrap the complete response as a synthetic stream.
// Use the same URL (gateway or direct) as the original request.
const retryResponse = await fetch(url, {
body: JSON.stringify({ ...(inputs as Record<string, unknown>), stream: false }),
body: JSON.stringify({
...(inputs as Record<string, unknown>),
stream: false,
}),
headers,
method: "POST",
signal: signal as AbortSignal | undefined,
Expand Down
188 changes: 187 additions & 1 deletion packages/workers-ai-provider/test/utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import { describe, expect, it } from "vitest";
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import {
processPartialToolCalls,
processToolCalls,
processText,
sanitizeToolCallId,
normalizeMessagesForBinding,
prepareToolsAndToolChoice,
createRun,
} from "../src/utils";

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -404,3 +405,188 @@ describe("processText", () => {
).toBe("From choices");
});
});

// ---------------------------------------------------------------------------
// createRun - gateway support
// ---------------------------------------------------------------------------

describe("createRun", () => {
	const originalFetch = globalThis.fetch;

	// Shared stub: every test in this suite needs the same minimal successful
	// JSON response double wired into the fetch mock. Extracted to avoid
	// repeating the 5-line setup in each test.
	function mockJsonResponse(): void {
		const mockResponse = {
			ok: true,
			json: vi.fn().mockResolvedValue({ result: { response: "Hello" } }),
			headers: new Headers({ "content-type": "application/json" }),
		};
		vi.mocked(globalThis.fetch).mockResolvedValue(mockResponse as unknown as Response);
	}

	beforeEach(() => {
		// Replace global fetch with a spy; restored in afterEach.
		globalThis.fetch = vi.fn();
	});

	afterEach(() => {
		globalThis.fetch = originalFetch;
		vi.restoreAllMocks();
	});

	it("should use direct API URL when no gateway is provided", async () => {
		mockJsonResponse();

		const run = createRun({ accountId: "test-account", apiKey: "test-key" });
		await run("@cf/meta/llama-3.1-8b-instruct" as any, { prompt: "Hi" });

		expect(globalThis.fetch).toHaveBeenCalledWith(
			"https://api.cloudflare.com/client/v4/accounts/test-account/ai/run/@cf/meta/llama-3.1-8b-instruct",
			expect.objectContaining({
				method: "POST",
				headers: {
					Authorization: "Bearer test-key",
					"Content-Type": "application/json",
				},
			}),
		);
	});

	it("should use gateway URL when gateway.id is provided", async () => {
		mockJsonResponse();

		const run = createRun({ accountId: "test-account", apiKey: "test-key" });
		await run(
			"@cf/meta/llama-3.1-8b-instruct" as any,
			{ prompt: "Hi" },
			{ gateway: { id: "my-gateway" } },
		);

		expect(globalThis.fetch).toHaveBeenCalledWith(
			"https://gateway.ai.cloudflare.com/v1/test-account/my-gateway/workers-ai/run/@cf/meta/llama-3.1-8b-instruct",
			expect.objectContaining({
				method: "POST",
				headers: {
					Authorization: "Bearer test-key",
					"Content-Type": "application/json",
				},
			}),
		);
	});

	it("should add cf-aig-skip-cache header when skipCache is true", async () => {
		mockJsonResponse();

		const run = createRun({ accountId: "test-account", apiKey: "test-key" });
		await run(
			"@cf/meta/llama-3.1-8b-instruct" as any,
			{ prompt: "Hi" },
			{ gateway: { id: "my-gateway", skipCache: true } },
		);

		expect(globalThis.fetch).toHaveBeenCalledWith(
			expect.any(String),
			expect.objectContaining({
				headers: expect.objectContaining({
					"cf-aig-skip-cache": "true",
				}),
			}),
		);
	});

	it("should add cf-aig-cache-ttl header when cacheTtl is provided", async () => {
		mockJsonResponse();

		const run = createRun({ accountId: "test-account", apiKey: "test-key" });
		await run(
			"@cf/meta/llama-3.1-8b-instruct" as any,
			{ prompt: "Hi" },
			{ gateway: { id: "my-gateway", cacheTtl: 3600 } },
		);

		expect(globalThis.fetch).toHaveBeenCalledWith(
			expect.any(String),
			expect.objectContaining({
				headers: expect.objectContaining({
					"cf-aig-cache-ttl": "3600",
				}),
			}),
		);
	});

	it("should add cf-aig-cache-key header when cacheKey is provided", async () => {
		mockJsonResponse();

		const run = createRun({ accountId: "test-account", apiKey: "test-key" });
		await run(
			"@cf/meta/llama-3.1-8b-instruct" as any,
			{ prompt: "Hi" },
			{ gateway: { id: "my-gateway", cacheKey: "my-custom-key" } },
		);

		expect(globalThis.fetch).toHaveBeenCalledWith(
			expect.any(String),
			expect.objectContaining({
				headers: expect.objectContaining({
					"cf-aig-cache-key": "my-custom-key",
				}),
			}),
		);
	});

	it("should add cf-aig-metadata header when metadata is provided", async () => {
		mockJsonResponse();

		const run = createRun({ accountId: "test-account", apiKey: "test-key" });
		await run(
			"@cf/meta/llama-3.1-8b-instruct" as any,
			{ prompt: "Hi" },
			{ gateway: { id: "my-gateway", metadata: { user: "test", session: 123 } } },
		);

		expect(globalThis.fetch).toHaveBeenCalledWith(
			expect.any(String),
			expect.objectContaining({
				headers: expect.objectContaining({
					"cf-aig-metadata": '{"user":"test","session":123}',
				}),
			}),
		);
	});

	it("should add all gateway cache headers when all options are provided", async () => {
		mockJsonResponse();

		const run = createRun({ accountId: "test-account", apiKey: "test-key" });
		await run(
			"@cf/meta/llama-3.1-8b-instruct" as any,
			{ prompt: "Hi" },
			{
				gateway: {
					id: "my-gateway",
					skipCache: true,
					cacheTtl: 7200,
					cacheKey: "custom-key",
					metadata: { env: "prod" },
				},
			},
		);

		expect(globalThis.fetch).toHaveBeenCalledWith(
			"https://gateway.ai.cloudflare.com/v1/test-account/my-gateway/workers-ai/run/@cf/meta/llama-3.1-8b-instruct",
			expect.objectContaining({
				// Exact match here (no objectContaining) to assert NO extra
				// headers are sent beyond auth + content-type + cf-aig-*.
				headers: {
					Authorization: "Bearer test-key",
					"Content-Type": "application/json",
					"cf-aig-skip-cache": "true",
					"cf-aig-cache-ttl": "7200",
					"cf-aig-cache-key": "custom-key",
					"cf-aig-metadata": '{"env":"prod"}',
				},
			}),
		);
	});
});
Loading