diff --git a/packages/types/src/providers/xai.ts b/packages/types/src/providers/xai.ts index 755d692f40f..a80e0d73570 100644 --- a/packages/types/src/providers/xai.ts +++ b/packages/types/src/providers/xai.ts @@ -3,10 +3,10 @@ import type { ModelInfo } from "../model.js" // https://docs.x.ai/docs/api-reference export type XAIModelId = keyof typeof xaiModels -export const xaiDefaultModelId: XAIModelId = "grok-4.20-beta-0309-reasoning" +export const xaiDefaultModelId: XAIModelId = "grok-4.20" export const xaiModels = { - "grok-4.20-beta-0309-reasoning": { + "grok-4.20": { maxTokens: 65_536, contextWindow: 2_000_000, supportsImages: true, @@ -15,21 +15,7 @@ export const xaiModels = { outputPrice: 6.0, cacheWritesPrice: 0.5, cacheReadsPrice: 0.5, - description: - "xAI's Grok 4.20 reasoning model with 2M context. Reasoning is internal (not exposed via Chat Completions API).", - includedTools: ["search_replace"], - excludedTools: ["apply_diff"], - }, - "grok-4.20-beta-0309-non-reasoning": { - maxTokens: 65_536, - contextWindow: 2_000_000, - supportsImages: true, - supportsPromptCache: true, - inputPrice: 2.0, - outputPrice: 6.0, - cacheWritesPrice: 0.5, - cacheReadsPrice: 0.5, - description: "xAI's Grok 4.20 non-reasoning model - faster inference with 2M context.", + description: "xAI's flagship Grok 4.20 model with 2M context and reasoning support via Responses API.", includedTools: ["search_replace"], excludedTools: ["apply_diff"], }, diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts index c622c9d4fcf..763d10d0277 100644 --- a/src/api/providers/__tests__/xai.spec.ts +++ b/src/api/providers/__tests__/xai.spec.ts @@ -10,14 +10,16 @@ vitest.mock("@roo-code/telemetry", () => ({ }, })) -const mockCreate = vitest.fn() +const mockResponsesCreate = vitest.fn() vitest.mock("openai", () => { const mockConstructor = vitest.fn() return { __esModule: true, - default: mockConstructor.mockImplementation(() => ({ chat: { completions: { create: mockCreate } } })), + default: mockConstructor.mockImplementation(() => ({ + responses: { create: mockResponsesCreate }, + })), } }) @@ -28,16 +30,30 @@ import { xaiDefaultModelId, xaiModels } from "@roo-code/types" import { XAIHandler } from "../xai" +// Helper to create an async iterable from events +function mockStream(events: any[]) { + return { + [Symbol.asyncIterator]: () => { + let index = 0 + return { + async next() { + if (index < events.length) { + return { done: false, value: events[index++] } + } + return { done: true, value: undefined } + }, + } + }, + } +} + describe("XAIHandler", () => { let handler: XAIHandler beforeEach(() => { - // Reset all mocks vi.clearAllMocks() - mockCreate.mockClear() + mockResponsesCreate.mockClear() mockCaptureException.mockClear() - - // Create handler with mock handler = new XAIHandler({}) }) @@ -50,14 +66,9 @@ describe("XAIHandler", () => { }) it("should use the provided API key", () => { - // Clear mocks before this specific test vi.clearAllMocks() - - // Create a handler with our API key const xaiApiKey = "test-api-key" new XAIHandler({ xaiApiKey }) - - // Verify the OpenAI constructor was called with our API key expect(OpenAI).toHaveBeenCalledWith( expect.objectContaining({ apiKey: xaiApiKey, @@ -71,111 +82,41 @@ describe("XAIHandler", () => { expect(model.info).toEqual(xaiModels[xaiDefaultModelId]) }) - test("should return specified model when valid model is provided", () => { + it("should return specified model when valid model is provided", () => { const testModelId = "grok-3" const handlerWithModel = new XAIHandler({ apiModelId: testModelId }) const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) expect(model.info).toEqual(xaiModels[testModelId]) }) - it("should include reasoning_effort parameter for mini models", async () => { - const miniModelHandler = new XAIHandler({ - apiModelId: "grok-3-mini", - reasoningEffort: "high", - }) + it("should use Responses API (client.responses.create)", async () => { + mockResponsesCreate.mockResolvedValueOnce(mockStream([])) - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - // Start generating a message - const messageGenerator = miniModelHandler.createMessage("test prompt", []) - await messageGenerator.next() // Start the generator + const stream = handler.createMessage("test prompt", []) + await stream.next() - // Check that reasoning_effort was included - expect(mockCreate).toHaveBeenCalledWith( + expect(mockResponsesCreate).toHaveBeenCalledWith( expect.objectContaining({ - reasoning_effort: "high", + model: xaiDefaultModelId, + instructions: "test prompt", + stream: true, + store: false, + include: ["reasoning.encrypted_content"], }), ) }) - it("should not include reasoning_effort parameter for non-mini models", async () => { - const regularModelHandler = new XAIHandler({ - apiModelId: "grok-3", - reasoningEffort: "high", - }) - - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - // Start generating a message - const messageGenerator = regularModelHandler.createMessage("test prompt", []) - await messageGenerator.next() // Start the generator - - // Check call args for reasoning_effort - const calls = mockCreate.mock.calls - const lastCall = calls[calls.length - 1][0] - expect(lastCall).not.toHaveProperty("reasoning_effort") - }) - - it("completePrompt method should return text from OpenAI API", async () => { - const expectedResponse = "This is a test response" - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) - - const result = await handler.completePrompt("test prompt") - expect(result).toBe(expectedResponse) - }) - - it("should handle errors in completePrompt", async () => { - const errorMessage = "API error" - mockCreate.mockRejectedValueOnce(new Error(errorMessage)) - - await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`) - }) - it("createMessage should yield text content from stream", async () => { const testContent = "This is test content" - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: { content: testContent } }], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) + mockResponsesCreate.mockResolvedValueOnce( + mockStream([{ type: "response.output_text.delta", delta: testContent }]), + ) - // Create and consume the stream const stream = handler.createMessage("system prompt", []) const firstChunk = await stream.next() - // Verify the content expect(firstChunk.done).toBe(false) expect(firstChunk.value).toEqual({ type: "text", @@ -186,28 +127,13 @@ describe("XAIHandler", () => { it("createMessage should yield reasoning content from stream", async () => { const testReasoning = "Test reasoning content" - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: { reasoning_content: testReasoning } }], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) + mockResponsesCreate.mockResolvedValueOnce( + mockStream([{ type: "response.reasoning_text.delta", delta: testReasoning }]), + ) - // Create and consume the stream const stream = handler.createMessage("system prompt", []) const firstChunk = await stream.next() - // Verify the reasoning content expect(firstChunk.done).toBe(false) expect(firstChunk.value).toEqual({ type: "reasoning", @@ -215,373 +141,158 @@ describe("XAIHandler", () => { }) }) - it("createMessage should yield usage data from stream", async () => { - // Setup mock for streaming response that includes usage data - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: {} }], // Needs to have choices array to avoid error - usage: { - prompt_tokens: 10, - completion_tokens: 20, - cache_read_input_tokens: 5, - cache_creation_input_tokens: 15, - }, - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) + it("createMessage should yield usage data from response.completed", async () => { + mockResponsesCreate.mockResolvedValueOnce( + mockStream([ + { + type: "response.completed", + response: { + usage: { + input_tokens: 10, + output_tokens: 20, + input_tokens_details: { cached_tokens: 5 }, + output_tokens_details: { reasoning_tokens: 8 }, + }, + }, + }, + ]), + ) - // Create and consume the stream const stream = handler.createMessage("system prompt", []) const firstChunk = await stream.next() - // Verify the usage data expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ - type: "usage", - inputTokens: 10, - outputTokens: 20, - cacheReadTokens: 5, - cacheWriteTokens: 15, - }) + expect(firstChunk.value).toEqual( + expect.objectContaining({ + type: "usage", + inputTokens: 10, + outputTokens: 20, + cacheReadTokens: 5, + reasoningTokens: 8, + }), + ) }) - it("createMessage should pass correct parameters to OpenAI client", async () => { - // Setup a handler with specific model - const modelId = "grok-3" - const modelInfo = xaiModels[modelId] - const handlerWithModel = new XAIHandler({ apiModelId: modelId }) - - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } + it("createMessage should yield tool_call from output_item.done", async () => { + mockResponsesCreate.mockResolvedValueOnce( + mockStream([ + { + type: "response.output_item.done", + item: { + type: "function_call", + call_id: "call_123", + name: "test_tool", + arguments: '{"arg1":"value"}', }, - }), - } - }) - - // System prompt and messages - const systemPrompt = "Test system prompt" - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }] + }, + ]), + ) - // Start generating a message - const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages) - await messageGenerator.next() // Start the generator + const stream = handler.createMessage("system prompt", []) + const firstChunk = await stream.next() - // Check that all parameters were passed correctly - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: modelId, - max_tokens: modelInfo.maxTokens, - temperature: 0, - messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]), - stream: true, - stream_options: { include_usage: true }, - }), - ) + expect(firstChunk.done).toBe(false) + expect(firstChunk.value).toEqual({ + type: "tool_call", + id: "call_123", + name: "test_tool", + arguments: '{"arg1":"value"}', + }) }) - describe("Native Tool Calling", () => { + it("should include tools in Responses API format", async () => { const testTools = [ { type: "function" as const, function: { name: "test_tool", description: "A test tool", - parameters: { - type: "object", - properties: { - arg1: { type: "string", description: "First argument" }, - }, - required: ["arg1"], - }, + parameters: { type: "object", properties: { arg1: { type: "string" } }, required: ["arg1"] }, }, }, ] - it("should include tools in request when model supports native tools and tools are provided (native is default)", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) + mockResponsesCreate.mockResolvedValueOnce(mockStream([])) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) - await messageGenerator.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - type: "function", - function: expect.objectContaining({ - name: "test_tool", - }), - }), - ]), - parallel_tool_calls: true, - }), - ) + const stream = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + tools: testTools, }) + await stream.next() - it("should include tool_choice when provided", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.objectContaining({ + tools: [ + expect.objectContaining({ + type: "function", + name: "test_tool", + description: "A test tool", + strict: true, }), - } - }) - - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, + ], tool_choice: "auto", - }) - await messageGenerator.next() + parallel_tool_calls: true, + }), + ) + }) - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tool_choice: "auto", - }), - ) + it("completePrompt should return text from Responses API", async () => { + const expectedResponse = "This is a test response" + mockResponsesCreate.mockResolvedValueOnce({ + output_text: expectedResponse, }) - it("should always include tools and tool_choice (tools are guaranteed to be present after ALWAYS_AVAILABLE_TOOLS)", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe(expectedResponse) + }) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - }) - await messageGenerator.next() - - // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0] - expect(callArgs).toHaveProperty("tools") - expect(callArgs).toHaveProperty("tool_choice") - expect(callArgs).toHaveProperty("parallel_tool_calls", true) + it("should handle errors in completePrompt", async () => { + const errorMessage = "API error" + mockResponsesCreate.mockRejectedValueOnce(new Error(errorMessage)) + + await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`) + }) + + it("should include reasoning_effort for mini models", async () => { + const miniModelHandler = new XAIHandler({ + apiModelId: "grok-3-mini", + reasoningEffort: "high", }) - it("should yield tool_call_partial chunks during streaming", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_123", - function: { - name: "test_tool", - arguments: '{"arg1":', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - function: { - arguments: '"value"}', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) + mockResponsesCreate.mockResolvedValueOnce(mockStream([])) - const stream = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) + const stream = miniModelHandler.createMessage("test prompt", []) + await stream.next() - const chunks = [] - for await (const chunk of stream) { - chunks.push(chunk) - } + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.objectContaining({ + reasoning: expect.objectContaining({ + reasoning_effort: "high", + }), + }), + ) + }) - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: "call_123", - name: "test_tool", - arguments: '{"arg1":', - }) - - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '"value"}', - }) + it("should not include reasoning for non-mini models", async () => { + const regularHandler = new XAIHandler({ + apiModelId: "grok-3", + reasoningEffort: "high", }) - it("should set parallel_tool_calls based on metadata", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) + mockResponsesCreate.mockResolvedValueOnce(mockStream([])) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - parallelToolCalls: true, - }) - await messageGenerator.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - parallel_tool_calls: true, - }), - ) - }) + const stream = regularHandler.createMessage("test prompt", []) + await stream.next() - it("should yield tool_call_end events when finish_reason is tool_calls", async () => { - // Import NativeToolCallParser to set up state - const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser") - - // Clear any previous state - NativeToolCallParser.clearRawChunkState() - - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_xai_test", - function: { - name: "test_tool", - arguments: '{"arg1":"value"}', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: {}, - finish_reason: "tool_calls", - }, - ], - usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) - - const stream = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) - - const chunks = [] - for await (const chunk of stream) { - // Simulate what Task.ts does: when we receive tool_call_partial, - // process it through NativeToolCallParser to populate rawChunkTracker - if (chunk.type === "tool_call_partial") { - NativeToolCallParser.processRawChunk({ - index: chunk.index, - id: chunk.id, - name: chunk.name, - arguments: chunk.arguments, - }) - } - chunks.push(chunk) - } + const callArgs = mockResponsesCreate.mock.calls[mockResponsesCreate.mock.calls.length - 1][0] + expect(callArgs).not.toHaveProperty("reasoning") + }) - // Should have tool_call_partial and tool_call_end - const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") + it("should handle errors in createMessage", async () => { + const errorMessage = "Stream error" + mockResponsesCreate.mockRejectedValueOnce(new Error(errorMessage)) - expect(partialChunks).toHaveLength(1) - expect(endChunks).toHaveLength(1) - expect(endChunks[0].id).toBe("call_xai_test") - }) + const stream = handler.createMessage("test prompt", []) + await expect(stream.next()).rejects.toThrow(`xAI completion error: ${errorMessage}`) }) }) diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts index 8b973d41c4e..0cd9cb0273b 100644 --- a/src/api/providers/xai.ts +++ b/src/api/providers/xai.ts @@ -4,17 +4,18 @@ import OpenAI from "openai" import { type XAIModelId, xaiDefaultModelId, xaiModels, ApiProviderError } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" -import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser" import type { ApiHandlerOptions } from "../../shared/api" import { ApiStream } from "../transform/stream" -import { convertToOpenAiMessages } from "../transform/openai-format" +import { convertToResponsesApiInput } from "../transform/responses-api-input" +import { processResponsesApiStream, createUsageNormalizer } from "../transform/responses-api-stream" import { getModelParams } from "../transform/model-params" import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" import { handleOpenAIError } from "./utils/openai-error-handler" +import { isMcpTool } from "../../utils/mcp-name" const XAI_DEFAULT_TEMPERATURE = 0 @@ -53,118 +54,108 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler return { id, info, ...params } } + /** + * Convert tools from OpenAI Chat Completions format to Responses API format. + * Chat Completions: { type: "function", function: { name, description, parameters } } + * Responses API: { type: "function", name, description, parameters } + * + * Uses base provider's convertToolSchemaForOpenAI() for schema hardening + * (additionalProperties: false, ensureAllRequired) and handles MCP tools. + */ + private mapResponseTools(tools?: any[]): any[] | undefined { + const converted = this.convertToolsForOpenAI(tools) + if (!converted?.length) { + return undefined + } + return converted + .filter((tool) => tool?.type === "function") + .map((tool) => { + const isMcp = isMcpTool(tool.function.name) + return { + type: "function", + name: tool.function.name, + description: tool.function.description, + parameters: isMcp + ? tool.function.parameters + : this.convertToolSchemaForOpenAI(tool.function.parameters), + strict: !isMcp, + } + }) + } + override async *createMessage( systemPrompt: string, messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { id: modelId, info: modelInfo, reasoning } = this.getModel() - - // Use the OpenAI-compatible API. - const requestOptions = { - model: modelId, - max_tokens: modelInfo.maxTokens, - temperature: this.options.modelTemperature ?? XAI_DEFAULT_TEMPERATURE, - messages: [ - { role: "system", content: systemPrompt }, - ...convertToOpenAiMessages(messages), - ] as OpenAI.Chat.ChatCompletionMessageParam[], - stream: true as const, - stream_options: { include_usage: true }, - ...(reasoning && reasoning), - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - parallel_tool_calls: metadata?.parallelToolCalls ?? true, + const model = this.getModel() + + // Convert directly from Anthropic format to Responses API input format + const input = convertToResponsesApiInput(messages) + const responseTools = this.mapResponseTools(metadata?.tools) + + // Build request options + const requestBody: Record = { + model: model.id, + instructions: systemPrompt, + input: input, + stream: true, + store: false, // Don't store responses server-side for privacy + include: ["reasoning.encrypted_content"], + } + + if (model.maxTokens) { + requestBody.max_output_tokens = model.maxTokens + } + + if (model.temperature !== undefined) { + requestBody.temperature = model.temperature + } + + if (responseTools) { + requestBody.tools = responseTools + // Cast tool_choice since metadata uses Chat Completions types but Responses API has its own type + requestBody.tool_choice = (metadata?.tool_choice ?? "auto") as any + requestBody.parallel_tool_calls = metadata?.parallelToolCalls ?? true } - let stream + // Pass reasoning effort for models that support it (e.g., mini models) + if (model.reasoning) { + requestBody.reasoning = model.reasoning + } + + let stream: AsyncIterable try { - stream = await this.client.chat.completions.create(requestOptions) + stream = (await this.client.responses.create({ + ...requestBody, + stream: true, + } as any)) as unknown as AsyncIterable } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) - const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "createMessage") + const apiError = new ApiProviderError(errorMessage, this.providerName, model.id, "createMessage") TelemetryService.instance.captureException(apiError) throw handleOpenAIError(error, this.providerName) } - for await (const chunk of stream) { - const delta = chunk.choices[0]?.delta - const finishReason = chunk.choices[0]?.finish_reason - - if (delta?.content) { - yield { - type: "text", - text: delta.content, - } - } - - if (delta && "reasoning_content" in delta && delta.reasoning_content) { - yield { - type: "reasoning", - text: delta.reasoning_content as string, - } - } - - // Handle tool calls in stream - emit partial chunks for NativeToolCallParser - if (delta?.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } - - // Process finish_reason to emit tool_call_end events - // This ensures tool calls are finalized even if the stream doesn't properly close - if (finishReason) { - const endEvents = NativeToolCallParser.processFinishReason(finishReason) - for (const event of endEvents) { - yield event - } - } - - if (chunk.usage) { - // Extract detailed token information if available - // First check for prompt_tokens_details structure (real API response) - const promptDetails = "prompt_tokens_details" in chunk.usage ? chunk.usage.prompt_tokens_details : null - const cachedTokens = promptDetails && "cached_tokens" in promptDetails ? promptDetails.cached_tokens : 0 - - // Fall back to direct fields in usage (used in test mocks) - const readTokens = - cachedTokens || - ("cache_read_input_tokens" in chunk.usage ? (chunk.usage as any).cache_read_input_tokens : 0) - const writeTokens = - "cache_creation_input_tokens" in chunk.usage ? (chunk.usage as any).cache_creation_input_tokens : 0 - - yield { - type: "usage", - inputTokens: chunk.usage.prompt_tokens || 0, - outputTokens: chunk.usage.completion_tokens || 0, - cacheReadTokens: readTokens, - cacheWriteTokens: writeTokens, - } - } - } + const normalizeUsage = createUsageNormalizer() + yield* processResponsesApiStream(stream, normalizeUsage) } async completePrompt(prompt: string): Promise { - const { id: modelId, reasoning } = this.getModel() + const model = this.getModel() try { - const response = await this.client.chat.completions.create({ - model: modelId, - messages: [{ role: "user", content: prompt }], - ...(reasoning && reasoning), + const response = await this.client.responses.create({ + model: model.id, + input: [{ role: "user", content: [{ type: "input_text", text: prompt }] }], + store: false, }) - return response.choices[0]?.message.content || "" + // output_text is a convenience field on the Responses API response + return response.output_text || "" } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) - const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "completePrompt") + const apiError = new ApiProviderError(errorMessage, this.providerName, model.id, "completePrompt") TelemetryService.instance.captureException(apiError) throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/transform/__tests__/responses-api-input.spec.ts b/src/api/transform/__tests__/responses-api-input.spec.ts new file mode 100644 index 00000000000..c57b61d8958 --- /dev/null +++ b/src/api/transform/__tests__/responses-api-input.spec.ts @@ -0,0 +1,338 @@ +import type { Anthropic } from "@anthropic-ai/sdk" +import { convertToResponsesApiInput } from "../responses-api-input" + +describe("convertToResponsesApiInput", () => { + it("should return empty array for empty messages", () => { + expect(convertToResponsesApiInput([])).toEqual([]) + }) + + describe("string content messages", () => { + it("should convert string content to input_text", () => { + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([{ role: "user", content: [{ type: "input_text", text: "Hello" }] }]) + }) + + it("should convert assistant string content to output_text message format", () => { + const messages: Anthropic.Messages.MessageParam[] = [{ role: "assistant", content: "Hi there" }] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "Hi there" }], + }, + ]) + }) + }) + + describe("user messages with content blocks", () => { + it("should convert text blocks to input_text", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text", text: "What is this?" }], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([{ role: "user", content: [{ type: "input_text", text: "What is this?" }] }]) + }) + + it("should convert image blocks to input_image", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "image", + source: { type: "base64", media_type: "image/png", data: "abc123" }, + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + role: "user", + content: [ + { + type: "input_image", + detail: "auto", + image_url: "data:image/png;base64,abc123", + }, + ], + }, + ]) + }) + + it("should convert tool_result to function_call_output", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "tool_123", + content: "Result text", + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "function_call_output", + call_id: "tool_123", + output: "Result text", + }, + ]) + }) + + it("should use (empty) for empty tool_result content", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "tool_123", + content: "", + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "function_call_output", + call_id: "tool_123", + output: "(empty)", + }, + ]) + }) + + it("should extract text from array tool_result content", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "tool_123", + content: [ + { type: "text", text: "Line 1" }, + { type: "text", text: "Line 2" }, + ], + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "function_call_output", + call_id: "tool_123", + output: "Line 1\nLine 2", + }, + ]) + }) + + it("should flush pending user content before tool_result", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { type: "text", text: "Here is context" }, + { + type: "tool_result", + tool_use_id: "tool_123", + content: "Done", + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { role: "user", content: [{ type: "input_text", text: "Here is context" }] }, + { type: "function_call_output", call_id: "tool_123", output: "Done" }, + ]) + }) + }) + + describe("assistant messages with content blocks", () => { + it("should convert text blocks to output_text messages", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [{ type: "text", text: "Here is my response" }], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "Here is my response" }], + }, + ]) + }) + + it("should convert tool_use to function_call", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { + type: "tool_use", + id: "call_abc", + name: "read_file", + input: { path: "/tmp/test.txt" }, + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "function_call", + call_id: "call_abc", + name: "read_file", + arguments: '{"path":"/tmp/test.txt"}', + }, + ]) + }) + + it("should handle tool_use with string input", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { + type: "tool_use", + id: "call_abc", + name: "run_command", + input: '{"cmd":"ls"}' as any, + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result[0]).toEqual( + expect.objectContaining({ + type: "function_call", + arguments: '{"cmd":"ls"}', + }), + ) + }) + + it("should handle mixed text and tool_use in assistant message", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { type: "text", text: "Let me read that file" }, + { + type: "tool_use", + id: "call_abc", + name: "read_file", + input: { path: "/tmp/test.txt" }, + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual( + expect.objectContaining({ + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "Let me read that file" }], + }), + ) + expect(result[1]).toEqual( + expect.objectContaining({ + type: "function_call", + name: "read_file", + }), + ) + }) + }) + + describe("multi-turn conversations", () => { + it("should handle a complete tool use cycle", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { role: "user", content: "Read /tmp/test.txt" }, + { + role: "assistant", + content: [ + { + type: "tool_use", + id: "call_1", + name: "read_file", + input: { path: "/tmp/test.txt" }, + }, + ], + }, + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "call_1", + content: "file contents here", + }, + ], + }, + { + role: "assistant", + content: [{ type: "text", text: "The file contains: file contents here" }], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toHaveLength(4) + expect(result[0]).toEqual({ role: "user", content: [{ type: "input_text", text: "Read /tmp/test.txt" }] }) + expect(result[1]).toEqual( + expect.objectContaining({ type: "function_call", call_id: "call_1", name: "read_file" }), + ) + expect(result[2]).toEqual( + expect.objectContaining({ + type: "function_call_output", + call_id: "call_1", + output: "file contents here", + }), + ) + expect(result[3]).toEqual( + expect.objectContaining({ + type: "message", + role: "assistant", + }), + ) + }) + }) +}) diff --git a/src/api/transform/__tests__/responses-api-stream.spec.ts b/src/api/transform/__tests__/responses-api-stream.spec.ts new file mode 100644 index 00000000000..6abdbddb659 --- /dev/null +++ b/src/api/transform/__tests__/responses-api-stream.spec.ts @@ -0,0 +1,392 @@ +import { processResponsesApiStream, createUsageNormalizer } from "../responses-api-stream" + +// Helper to create an async iterable from events +async function* mockStream(events: any[]) { + for (const event of events) { + yield event + } +} + +// Helper to collect all chunks from a stream +async function collectChunks(stream: AsyncGenerator) { + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + return chunks +} + +const noopUsage = () => undefined + +describe("processResponsesApiStream", () => { + describe("text deltas", () => { + it("should yield text chunk for response.output_text.delta", async () => { + const stream = mockStream([{ type: "response.output_text.delta", delta: "Hello world" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "text", text: "Hello world" }]) + }) + + it("should yield text chunk for response.text.delta", async () => { + const stream = mockStream([{ type: "response.text.delta", delta: "Hello" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "text", text: "Hello" }]) + }) + + it("should skip text delta with empty delta", async () => { + const stream = mockStream([{ type: "response.output_text.delta", delta: "" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([]) + }) + }) + + describe("reasoning deltas", () => { + it("should yield reasoning chunk for response.reasoning_text.delta", async () => { + const stream = mockStream([{ type: "response.reasoning_text.delta", delta: "Let me think..." }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "reasoning", text: "Let me think..." }]) + }) + + it("should yield reasoning chunk for response.reasoning.delta", async () => { + const stream = mockStream([{ type: "response.reasoning.delta", delta: "Step 1" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "reasoning", text: "Step 1" }]) + }) + + it("should yield reasoning chunk for response.reasoning_summary_text.delta", async () => { + const stream = mockStream([{ type: "response.reasoning_summary_text.delta", delta: "Summary" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "reasoning", text: "Summary" }]) + }) + + it("should yield reasoning chunk for response.reasoning_summary.delta", async () => { + const stream = mockStream([{ type: "response.reasoning_summary.delta", delta: "Summary" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "reasoning", text: "Summary" }]) + }) + }) + + describe("tool calls", () => { + it("should yield tool_call for function_call in output_item.done", async () => { + const stream = mockStream([ + { + type: "response.output_item.done", + item: { + type: "function_call", + call_id: "call_123", + name: "read_file", + arguments: '{"path":"/tmp/test.txt"}', + }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([ + { + type: "tool_call", + id: "call_123", + name: "read_file", + arguments: '{"path":"/tmp/test.txt"}', + }, + ]) + }) + + it("should yield tool_call for tool_call type in output_item.done", async () => { + const stream = mockStream([ + { + type: "response.output_item.done", + item: { + type: "tool_call", + tool_call_id: "call_456", + name: "write_file", + arguments: '{"path":"/tmp/out.txt"}', + }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([ + { + type: "tool_call", + id: "call_456", + name: "write_file", + arguments: '{"path":"/tmp/out.txt"}', + }, + ]) + }) + + it("should handle object arguments by JSON.stringifying", async () => { + const stream = mockStream([ + { + type: "response.output_item.done", + item: { + type: "function_call", + call_id: "call_789", + name: "test", + input: { key: "value" }, + }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks[0].arguments).toBe('{"key":"value"}') + }) + + it("should skip tool_call with missing call_id or name", async () => { + const stream = mockStream([ + { + type: "response.output_item.done", + item: { type: "function_call", call_id: "", name: "test", arguments: "{}" }, + }, + { + type: "response.output_item.done", + item: { type: "function_call", call_id: "call_1", name: "", arguments: "{}" }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([]) + }) + + it("should yield tool_call_partial for function_call_arguments.delta", async () => { + const stream = mockStream([ + { + type: "response.function_call_arguments.delta", + call_id: "call_123", + name: "read_file", + delta: '{"path":', + index: 0, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([ + { + type: "tool_call_partial", + index: 0, + id: "call_123", + name: "read_file", + arguments: '{"path":', + }, + ]) + }) + }) + + describe("completion and usage", () => { + it("should yield usage from response.completed", async () => { + const mockNormalize = (usage: any) => ({ + type: "usage" as const, + inputTokens: usage.input_tokens, + outputTokens: usage.output_tokens, + }) + + const stream = mockStream([ + { + type: "response.completed", + response: { usage: { input_tokens: 100, output_tokens: 50 } }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, mockNormalize)) + + expect(chunks).toEqual([{ type: "usage", inputTokens: 100, outputTokens: 50 }]) + }) + + it("should yield usage from response.done", async () => { + const mockNormalize = (usage: any) => ({ + type: "usage" as const, + inputTokens: usage.input_tokens, + outputTokens: usage.output_tokens, + }) + + const stream = mockStream([ + { + type: "response.done", + response: { usage: { input_tokens: 200, output_tokens: 100 } }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, mockNormalize)) + + expect(chunks).toEqual([{ type: "usage", inputTokens: 200, outputTokens: 100 }]) + }) + + it("should not yield usage when normalizer returns undefined", async () => { + const stream = mockStream([ + { + type: "response.completed", + response: { usage: null }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([]) + }) + }) + + describe("unknown events", () => { + it("should silently ignore unknown event types", async () => { + const stream = mockStream([ + { type: "response.created" }, + { type: "response.in_progress" }, + { type: "response.output_item.added", item: { type: "message" } }, + { type: "response.content_part.added" }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([]) + }) + }) + + describe("full conversation stream", () => { + it("should handle a complete stream with reasoning, text, and usage", async () => { + const mockNormalize = (usage: any) => ({ + type: "usage" as const, + inputTokens: usage.input_tokens, + outputTokens: usage.output_tokens, + }) + + const stream = mockStream([ + { type: "response.reasoning_text.delta", delta: "Thinking..." }, + { type: "response.reasoning_text.delta", delta: " done." }, + { type: "response.output_text.delta", delta: "The answer is " }, + { type: "response.output_text.delta", delta: "42." }, + { + type: "response.completed", + response: { usage: { input_tokens: 50, output_tokens: 30 } }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, mockNormalize)) + + expect(chunks).toEqual([ + { type: "reasoning", text: "Thinking..." }, + { type: "reasoning", text: " done." }, + { type: "text", text: "The answer is " }, + { type: "text", text: "42." }, + { type: "usage", inputTokens: 50, outputTokens: 30 }, + ]) + }) + }) +}) + +describe("createUsageNormalizer", () => { + it("should return undefined for null/undefined usage", () => { + const normalize = createUsageNormalizer() + expect(normalize(null)).toBeUndefined() + expect(normalize(undefined)).toBeUndefined() + }) + + it("should extract input and output tokens", () => { + const normalize = createUsageNormalizer() + + const result = normalize({ input_tokens: 100, output_tokens: 50 }) + + expect(result).toEqual( + expect.objectContaining({ + type: "usage", + inputTokens: 100, + outputTokens: 50, + }), + ) + }) + + it("should extract cached tokens from input_tokens_details", () => { + const normalize = createUsageNormalizer() + + const result = normalize({ + input_tokens: 100, + output_tokens: 50, + input_tokens_details: { cached_tokens: 30 }, + }) + + expect(result?.cacheReadTokens).toBe(30) + }) + + it("should extract cache write tokens", () => { + const normalize = createUsageNormalizer() + + const result = normalize({ + input_tokens: 100, + output_tokens: 50, + cache_creation_input_tokens: 15, + }) + + expect(result?.cacheWriteTokens).toBe(15) + }) + + it("should extract reasoning tokens from output_tokens_details", () => { + const normalize = createUsageNormalizer() + + const result = normalize({ + input_tokens: 100, + output_tokens: 50, + output_tokens_details: { reasoning_tokens: 20 }, + }) + + expect(result?.reasoningTokens).toBe(20) + }) + + it("should not include reasoningTokens when not present", () => { + const normalize = createUsageNormalizer() + + const result = normalize({ input_tokens: 100, output_tokens: 50 }) + + expect(result).not.toHaveProperty("reasoningTokens") + }) + + it("should compute totalCost when calculateCost is provided", () => { + const calculateCost = (input: number, output: number, cached: number) => 0.42 + const normalize = createUsageNormalizer(calculateCost) + + const result = normalize({ input_tokens: 100, output_tokens: 50 }) + + expect(result?.totalCost).toBe(0.42) + }) + + it("should not include totalCost when calculateCost is not provided", () => { + const normalize = createUsageNormalizer() + + const result = normalize({ input_tokens: 100, output_tokens: 50 }) + + expect(result).not.toHaveProperty("totalCost") + }) + + it("should handle Chat Completions style field names as fallback", () => { + const normalize = createUsageNormalizer() + + const result = normalize({ + prompt_tokens: 100, + completion_tokens: 50, + prompt_tokens_details: { cached_tokens: 10 }, + }) + + expect(result).toEqual( + expect.objectContaining({ + inputTokens: 100, + outputTokens: 50, + cacheReadTokens: 10, + }), + ) + }) +}) diff --git a/src/api/transform/responses-api-input.ts b/src/api/transform/responses-api-input.ts new file mode 100644 index 00000000000..a766dfef6ef --- /dev/null +++ b/src/api/transform/responses-api-input.ts @@ -0,0 +1,118 @@ +import { Anthropic } from "@anthropic-ai/sdk" + +/** + * Converts Anthropic-format messages to the OpenAI Responses API input format. + * + * Key differences from Chat Completions format: + * - Content parts use { type: "input_text" } instead of { type: "text" } + * - Images use { type: "input_image" } instead of { type: "image_url" } + * - Tool results use { type: "function_call_output", call_id } instead of { role: "tool", tool_call_id } + * - Tool uses become { type: "function_call", call_id, name, arguments } items + * - System prompt goes via the `instructions` parameter, not as a message + * + * @param messages - Array of Anthropic MessageParam objects + * @returns Array of Responses API input items + */ +export function convertToResponsesApiInput(messages: Anthropic.Messages.MessageParam[]): any[] { + const input: any[] = [] + + for (const message of messages) { + if (typeof message.content === "string") { + if (message.role === "assistant") { + // Assistant messages use output_text in the Responses API format + input.push({ + type: "message", + role: "assistant", + content: [{ type: "output_text", text: message.content }], + }) + } else { + input.push({ + role: message.role, + content: [{ type: "input_text", text: message.content }], + }) + } + continue + } + + if (message.role === "assistant") { + for (const part of message.content) { + switch (part.type) { + case "text": + input.push({ + type: "message", + role: "assistant", + content: [{ type: "output_text", text: part.text }], + }) + break + case "tool_use": + input.push({ + type: "function_call", + call_id: part.id, + name: part.name, + arguments: typeof part.input === "string" ? part.input : JSON.stringify(part.input ?? {}), + }) + break + case "thinking": + // Include reasoning if it has content + if ((part as any).thinking && (part as any).thinking.trim().length > 0) { + input.push({ + type: "message", + role: "assistant", + content: [{ type: "output_text", text: `[Thinking] ${(part as any).thinking}` }], + }) + } + break + } + } + } else { + // User messages + const contentParts: any[] = [] + for (const part of message.content) { + switch (part.type) { + case "text": + contentParts.push({ type: "input_text", text: part.text }) + break + case "image": + contentParts.push({ + type: "input_image", + detail: "auto", + image_url: `data:${part.source.media_type};base64,${part.source.data}`, + }) + break + case "tool_result": { + // Flush any pending user content before the tool result + if (contentParts.length > 0) { + input.push({ role: "user", content: [...contentParts] }) + contentParts.length = 0 + } + // Convert tool result content + let output: string + if (typeof part.content === "string") { + output = part.content || "(empty)" + } else if (Array.isArray(part.content)) { + output = + part.content + .filter((c): c is Anthropic.TextBlockParam => c.type === "text") + .map((c) => c.text) + .join("\n") || "(empty)" + } else { + output = "(empty)" + } + input.push({ + type: "function_call_output", + call_id: part.tool_use_id, + output, + }) + break + } + } + } + // Flush remaining user content + if (contentParts.length > 0) { + input.push({ role: "user", content: contentParts }) + } + } + } + + return input +} diff --git a/src/api/transform/responses-api-stream.ts b/src/api/transform/responses-api-stream.ts new file mode 100644 index 00000000000..884a7551367 --- /dev/null +++ b/src/api/transform/responses-api-stream.ts @@ -0,0 +1,140 @@ +import type { ApiStream, ApiStreamUsageChunk } from "./stream" + +/** + * Processes Responses API stream events and yields ApiStreamChunks. + * + * This is a shared utility for providers that use OpenAI's Responses API + * (POST /v1/responses with stream: true). It handles the core event types: + * + * - Text deltas (response.output_text.delta) + * - Reasoning deltas (response.reasoning_text.delta, response.reasoning_summary_text.delta) + * - Tool/function calls (response.output_item.done with function_call type) + * - Usage data (response.completed) + * + * Provider-specific concerns (WebSocket mode, SSE fallback, duplicate detection, + * pending tool tracking) are intentionally left to individual providers. + * + * @param stream - AsyncIterable of Responses API stream events + * @param normalizeUsage - Provider-specific function to normalize usage data into ApiStreamUsageChunk + */ +export async function* processResponsesApiStream( + stream: AsyncIterable, + normalizeUsage: (usage: any) => ApiStreamUsageChunk | undefined, +): ApiStream { + for await (const event of stream) { + // Text content deltas + if (event?.type === "response.output_text.delta" || event?.type === "response.text.delta") { + if (event?.delta) { + yield { type: "text", text: event.delta } + } + continue + } + + // Reasoning deltas + if ( + event?.type === "response.reasoning_text.delta" || + event?.type === "response.reasoning.delta" || + event?.type === "response.reasoning_summary_text.delta" || + event?.type === "response.reasoning_summary.delta" + ) { + if (event?.delta) { + yield { type: "reasoning", text: event.delta } + } + continue + } + + // Output item events — handle completed function calls and fallback text + if (event?.type === "response.output_item.done") { + const item = event?.item + if (item?.type === "function_call" || item?.type === "tool_call") { + const callId = item.call_id || item.tool_call_id || item.id + const name = item.name || item.function?.name + const argsRaw = item.arguments || item.function?.arguments || item.input + const args = + typeof argsRaw === "string" + ? argsRaw + : argsRaw && typeof argsRaw === "object" + ? JSON.stringify(argsRaw) + : "" + + if (typeof callId === "string" && callId.length > 0 && typeof name === "string" && name.length > 0) { + yield { + type: "tool_call", + id: callId, + name, + arguments: args, + } + } + } + continue + } + + // Function call argument deltas (for streaming tool calls) + if ( + event?.type === "response.function_call_arguments.delta" || + event?.type === "response.tool_call_arguments.delta" + ) { + const callId = event.call_id || event.tool_call_id || event.id || event.item_id + const name = event.name || event.function_name + if (typeof callId === "string" && callId.length > 0) { + yield { + type: "tool_call_partial", + index: event.index ?? 0, + id: callId, + name, + arguments: typeof event.delta === "string" ? event.delta : "", + } + } + continue + } + + // Completion events — extract usage + if (event?.type === "response.completed" || event?.type === "response.done") { + const usage = event?.response?.usage || event?.usage + const usageData = normalizeUsage(usage) + if (usageData) { + yield usageData + } + continue + } + } +} + +/** + * Creates a standard usage normalizer for providers with per-token pricing. + * Extracts input/output tokens, cache tokens, reasoning tokens, and computes cost. + * + * @param calculateCost - Optional function to compute total cost from token counts + */ +export function createUsageNormalizer( + calculateCost?: (inputTokens: number, outputTokens: number, cacheReadTokens: number) => number, +): (usage: any) => ApiStreamUsageChunk | undefined { + return (usage: any): ApiStreamUsageChunk | undefined => { + if (!usage) return undefined + + const inputDetails = usage.input_tokens_details ?? usage.prompt_tokens_details + const cachedTokens = inputDetails?.cached_tokens ?? 0 + + const inputTokens = usage.input_tokens ?? usage.prompt_tokens ?? 0 + const outputTokens = usage.output_tokens ?? usage.completion_tokens ?? 0 + const cacheReadTokens = usage.cache_read_input_tokens ?? cachedTokens ?? 0 + const cacheWriteTokens = usage.cache_creation_input_tokens ?? usage.cache_write_tokens ?? 0 + + const reasoningTokens = + typeof usage.output_tokens_details?.reasoning_tokens === "number" + ? usage.output_tokens_details.reasoning_tokens + : undefined + + const totalCost = calculateCost ? calculateCost(inputTokens, outputTokens, cacheReadTokens) : undefined + + return { + type: "usage", + inputTokens, + outputTokens, + cacheWriteTokens, + cacheReadTokens, + ...(typeof reasoningTokens === "number" ? { reasoningTokens } : {}), + ...(typeof totalCost === "number" ? { totalCost } : {}), + } + } +}