From 03fb47c27f6a47ca9be0b41865312355143257c9 Mon Sep 17 00:00:00 2001 From: Enrico Carlesso Date: Thu, 19 Mar 2026 15:52:19 -0700 Subject: [PATCH 1/3] feat: migrate xAI provider to Responses API with reusable transform utilities Migrate the xAI provider from the deprecated Chat Completions API to the Responses API, add grok-4.20 as the new default model, and introduce shared transform utilities for Responses API that other providers can adopt. New shared utilities (src/api/transform/): - responses-api-stream.ts: processResponsesApiStream() handles core Responses API stream events (text, reasoning, tool calls, usage) and createUsageNormalizer() provides configurable token/cost extraction. Designed for reuse by openai-native, openai-codex, or any future Responses API provider. - responses-api-input.ts: convertToResponsesApiInput() converts directly from Anthropic message format to Responses API input format, avoiding the intermediate Chat Completions conversion step. Handles input_text, input_image, function_call, and function_call_output mappings. xAI provider (src/api/providers/xai.ts): - Switch from client.chat.completions.create() to client.responses.create() - Use shared transform utilities for stream handling and input conversion - Enable reasoning traces via include: ["reasoning.encrypted_content"] - System prompt via instructions field, store: false for privacy - completePrompt() also migrated to Responses API Model updates (packages/types/src/providers/xai.ts): - Add grok-4.20 as the new default model (2M context, $2/$6 pricing) - Remove grok-4.20-beta-0309-reasoning and grok-4.20-beta-0309-non-reasoning --- packages/types/src/providers/xai.ts | 20 +- src/api/providers/__tests__/xai.spec.ts | 571 ++++-------------- src/api/providers/xai.ts | 164 ++--- .../__tests__/responses-api-input.spec.ts | 332 ++++++++++ .../__tests__/responses-api-stream.spec.ts | 382 ++++++++++++ src/api/transform/responses-api-input.ts | 109 ++++ src/api/transform/responses-api-stream.ts | 142 +++++ 7 files changed, 1156 insertions(+), 564 deletions(-) create mode 100644 src/api/transform/__tests__/responses-api-input.spec.ts create mode 100644 src/api/transform/__tests__/responses-api-stream.spec.ts create mode 100644 src/api/transform/responses-api-input.ts create mode 100644 src/api/transform/responses-api-stream.ts diff --git a/packages/types/src/providers/xai.ts b/packages/types/src/providers/xai.ts index 755d692f40f..a80e0d73570 100644 --- a/packages/types/src/providers/xai.ts +++ b/packages/types/src/providers/xai.ts @@ -3,10 +3,10 @@ import type { ModelInfo } from "../model.js" // https://docs.x.ai/docs/api-reference export type XAIModelId = keyof typeof xaiModels -export const xaiDefaultModelId: XAIModelId = "grok-4.20-beta-0309-reasoning" +export const xaiDefaultModelId: XAIModelId = "grok-4.20" export const xaiModels = { - "grok-4.20-beta-0309-reasoning": { + "grok-4.20": { maxTokens: 65_536, contextWindow: 2_000_000, supportsImages: true, @@ -15,21 +15,7 @@ export const xaiModels = { outputPrice: 6.0, cacheWritesPrice: 0.5, cacheReadsPrice: 0.5, - description: - "xAI's Grok 4.20 reasoning model with 2M context. Reasoning is internal (not exposed via Chat Completions API).", - includedTools: ["search_replace"], - excludedTools: ["apply_diff"], - }, - "grok-4.20-beta-0309-non-reasoning": { - maxTokens: 65_536, - contextWindow: 2_000_000, - supportsImages: true, - supportsPromptCache: true, - inputPrice: 2.0, - outputPrice: 6.0, - cacheWritesPrice: 0.5, - cacheReadsPrice: 0.5, - description: "xAI's Grok 4.20 non-reasoning model - faster inference with 2M context.", + description: "xAI's flagship Grok 4.20 model with 2M context and reasoning support via Responses API.", includedTools: ["search_replace"], excludedTools: ["apply_diff"], }, diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts index c622c9d4fcf..ff4b14cf14a 100644 --- a/src/api/providers/__tests__/xai.spec.ts +++ b/src/api/providers/__tests__/xai.spec.ts @@ -10,14 +10,16 @@ vitest.mock("@roo-code/telemetry", () => ({ }, })) -const mockCreate = vitest.fn() +const mockResponsesCreate = vitest.fn() vitest.mock("openai", () => { const mockConstructor = vitest.fn() return { __esModule: true, - default: mockConstructor.mockImplementation(() => ({ chat: { completions: { create: mockCreate } } })), + default: mockConstructor.mockImplementation(() => ({ + responses: { create: mockResponsesCreate }, + })), } }) @@ -28,16 +30,30 @@ import { xaiDefaultModelId, xaiModels } from "@roo-code/types" import { XAIHandler } from "../xai" +// Helper to create an async iterable from events +function mockStream(events: any[]) { + return { + [Symbol.asyncIterator]: () => { + let index = 0 + return { + async next() { + if (index < events.length) { + return { done: false, value: events[index++] } + } + return { done: true, value: undefined } + }, + } + }, + } +} + describe("XAIHandler", () => { let handler: XAIHandler beforeEach(() => { - // Reset all mocks vi.clearAllMocks() - mockCreate.mockClear() + mockResponsesCreate.mockClear() mockCaptureException.mockClear() - - // Create handler with mock handler = new XAIHandler({}) }) @@ -50,14 +66,9 @@ describe("XAIHandler", () => { }) it("should use the provided API key", () => { - // Clear mocks before this specific test vi.clearAllMocks() - - // Create a handler with our API key const xaiApiKey = "test-api-key" new XAIHandler({ xaiApiKey }) - - // Verify the OpenAI constructor was called with our API key expect(OpenAI).toHaveBeenCalledWith( expect.objectContaining({ apiKey: xaiApiKey, @@ -75,107 +86,37 @@ describe("XAIHandler", () => { const testModelId = "grok-3" const handlerWithModel = new XAIHandler({ apiModelId: testModelId }) const model = handlerWithModel.getModel() - expect(model.id).toBe(testModelId) expect(model.info).toEqual(xaiModels[testModelId]) }) - it("should include reasoning_effort parameter for mini models", async () => { - const miniModelHandler = new XAIHandler({ - apiModelId: "grok-3-mini", - reasoningEffort: "high", - }) + it("should use Responses API (client.responses.create)", async () => { + mockResponsesCreate.mockResolvedValueOnce(mockStream([])) - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - // Start generating a message - const messageGenerator = miniModelHandler.createMessage("test prompt", []) - await messageGenerator.next() // Start the generator + const stream = handler.createMessage("test prompt", []) + await stream.next() - // Check that reasoning_effort was included - expect(mockCreate).toHaveBeenCalledWith( + expect(mockResponsesCreate).toHaveBeenCalledWith( expect.objectContaining({ - reasoning_effort: "high", + model: xaiDefaultModelId, + instructions: "test prompt", + stream: true, + store: false, + include: ["reasoning.encrypted_content"], }), ) }) - it("should not include reasoning_effort parameter for non-mini models", async () => { - const regularModelHandler = new XAIHandler({ - apiModelId: "grok-3", - reasoningEffort: "high", - }) - - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - // Start generating a message - const messageGenerator = regularModelHandler.createMessage("test prompt", []) - await messageGenerator.next() // Start the generator - - // Check call args for reasoning_effort - const calls = mockCreate.mock.calls - const lastCall = calls[calls.length - 1][0] - expect(lastCall).not.toHaveProperty("reasoning_effort") - }) - - it("completePrompt method should return text from OpenAI API", async () => { - const expectedResponse = "This is a test response" - mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] }) - - const result = await handler.completePrompt("test prompt") - expect(result).toBe(expectedResponse) - }) - - it("should handle errors in completePrompt", async () => { - const errorMessage = "API error" - mockCreate.mockRejectedValueOnce(new Error(errorMessage)) - - await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`) - }) - it("createMessage should yield text content from stream", async () => { const testContent = "This is test content" - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: { content: testContent } }], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) + mockResponsesCreate.mockResolvedValueOnce( + mockStream([{ type: "response.output_text.delta", delta: testContent }]), + ) - // Create and consume the stream const stream = handler.createMessage("system prompt", []) const firstChunk = await stream.next() - // Verify the content expect(firstChunk.done).toBe(false) expect(firstChunk.value).toEqual({ type: "text", @@ -186,28 +127,13 @@ describe("XAIHandler", () => { it("createMessage should yield reasoning content from stream", async () => { const testReasoning = "Test reasoning content" - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: { reasoning_content: testReasoning } }], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) + mockResponsesCreate.mockResolvedValueOnce( + mockStream([{ type: "response.reasoning_text.delta", delta: testReasoning }]), + ) - // Create and consume the stream const stream = handler.createMessage("system prompt", []) const firstChunk = await stream.next() - // Verify the reasoning content expect(firstChunk.done).toBe(false) expect(firstChunk.value).toEqual({ type: "reasoning", @@ -215,373 +141,126 @@ describe("XAIHandler", () => { }) }) - it("createMessage should yield usage data from stream", async () => { - // Setup mock for streaming response that includes usage data - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [{ delta: {} }], // Needs to have choices array to avoid error - usage: { - prompt_tokens: 10, - completion_tokens: 20, - cache_read_input_tokens: 5, - cache_creation_input_tokens: 15, - }, - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) + it("createMessage should yield usage data from response.completed", async () => { + mockResponsesCreate.mockResolvedValueOnce( + mockStream([ + { + type: "response.completed", + response: { + usage: { + input_tokens: 10, + output_tokens: 20, + input_tokens_details: { cached_tokens: 5 }, + output_tokens_details: { reasoning_tokens: 8 }, + }, + }, + }, + ]), + ) - // Create and consume the stream const stream = handler.createMessage("system prompt", []) const firstChunk = await stream.next() - // Verify the usage data expect(firstChunk.done).toBe(false) - expect(firstChunk.value).toEqual({ - type: "usage", - inputTokens: 10, - outputTokens: 20, - cacheReadTokens: 5, - cacheWriteTokens: 15, - }) + expect(firstChunk.value).toEqual( + expect.objectContaining({ + type: "usage", + inputTokens: 10, + outputTokens: 20, + cacheReadTokens: 5, + reasoningTokens: 8, + }), + ) }) - it("createMessage should pass correct parameters to OpenAI client", async () => { - // Setup a handler with specific model - const modelId = "grok-3" - const modelInfo = xaiModels[modelId] - const handlerWithModel = new XAIHandler({ apiModelId: modelId }) - - // Setup mock for streaming response - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } + it("createMessage should yield tool_call from output_item.done", async () => { + mockResponsesCreate.mockResolvedValueOnce( + mockStream([ + { + type: "response.output_item.done", + item: { + type: "function_call", + call_id: "call_123", + name: "test_tool", + arguments: '{"arg1":"value"}', }, - }), - } - }) - - // System prompt and messages - const systemPrompt = "Test system prompt" - const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }] + }, + ]), + ) - // Start generating a message - const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages) - await messageGenerator.next() // Start the generator + const stream = handler.createMessage("system prompt", []) + const firstChunk = await stream.next() - // Check that all parameters were passed correctly - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: modelId, - max_tokens: modelInfo.maxTokens, - temperature: 0, - messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]), - stream: true, - stream_options: { include_usage: true }, - }), - ) + expect(firstChunk.done).toBe(false) + expect(firstChunk.value).toEqual({ + type: "tool_call", + id: "call_123", + name: "test_tool", + arguments: '{"arg1":"value"}', + }) }) - describe("Native Tool Calling", () => { + it("should include tools in Responses API format", async () => { const testTools = [ { type: "function" as const, function: { name: "test_tool", description: "A test tool", - parameters: { - type: "object", - properties: { - arg1: { type: "string", description: "First argument" }, - }, - required: ["arg1"], - }, + parameters: { type: "object", properties: { arg1: { type: "string" } }, required: ["arg1"] }, }, }, ] - it("should include tools in request when model supports native tools and tools are provided (native is default)", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) + mockResponsesCreate.mockResolvedValueOnce(mockStream([])) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) - await messageGenerator.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tools: expect.arrayContaining([ - expect.objectContaining({ - type: "function", - function: expect.objectContaining({ - name: "test_tool", - }), - }), - ]), - parallel_tool_calls: true, - }), - ) + const stream = handler.createMessage("test prompt", [], { + taskId: "test-task-id", + tools: testTools, }) + await stream.next() - it("should include tool_choice when provided", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.objectContaining({ + tools: [ + expect.objectContaining({ + type: "function", + name: "test_tool", + description: "A test tool", }), - } - }) - - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, + ], tool_choice: "auto", - }) - await messageGenerator.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - tool_choice: "auto", - }), - ) - }) - - it("should always include tools and tool_choice (tools are guaranteed to be present after ALWAYS_AVAILABLE_TOOLS)", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - }) - await messageGenerator.next() - - // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0] - expect(callArgs).toHaveProperty("tools") - expect(callArgs).toHaveProperty("tool_choice") - expect(callArgs).toHaveProperty("parallel_tool_calls", true) - }) - - it("should yield tool_call_partial chunks during streaming", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_123", - function: { - name: "test_tool", - arguments: '{"arg1":', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - function: { - arguments: '"value"}', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) - - const stream = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) - - const chunks = [] - for await (const chunk of stream) { - chunks.push(chunk) - } + }), + ) + }) - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: "call_123", - name: "test_tool", - arguments: '{"arg1":', - }) - - expect(chunks).toContainEqual({ - type: "tool_call_partial", - index: 0, - id: undefined, - name: undefined, - arguments: '"value"}', - }) + it("completePrompt should return text from Responses API", async () => { + const expectedResponse = "This is a test response" + mockResponsesCreate.mockResolvedValueOnce({ + output: [ + { + type: "message", + content: [{ type: "output_text", text: expectedResponse }], + }, + ], }) - it("should set parallel_tool_calls based on metadata", async () => { - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) + const result = await handler.completePrompt("test prompt") + expect(result).toBe(expectedResponse) + }) - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - async next() { - return { done: true } - }, - }), - } - }) - - const messageGenerator = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - parallelToolCalls: true, - }) - await messageGenerator.next() - - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - parallel_tool_calls: true, - }), - ) - }) + it("should handle errors in completePrompt", async () => { + const errorMessage = "API error" + mockResponsesCreate.mockRejectedValueOnce(new Error(errorMessage)) - it("should yield tool_call_end events when finish_reason is tool_calls", async () => { - // Import NativeToolCallParser to set up state - const { NativeToolCallParser } = await import("../../../core/assistant-message/NativeToolCallParser") - - // Clear any previous state - NativeToolCallParser.clearRawChunkState() - - const handlerWithTools = new XAIHandler({ apiModelId: "grok-3" }) - - mockCreate.mockImplementationOnce(() => { - return { - [Symbol.asyncIterator]: () => ({ - next: vi - .fn() - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: { - tool_calls: [ - { - index: 0, - id: "call_xai_test", - function: { - name: "test_tool", - arguments: '{"arg1":"value"}', - }, - }, - ], - }, - }, - ], - }, - }) - .mockResolvedValueOnce({ - done: false, - value: { - choices: [ - { - delta: {}, - finish_reason: "tool_calls", - }, - ], - usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, - }, - }) - .mockResolvedValueOnce({ done: true }), - }), - } - }) - - const stream = handlerWithTools.createMessage("test prompt", [], { - taskId: "test-task-id", - tools: testTools, - }) - - const chunks = [] - for await (const chunk of stream) { - // Simulate what Task.ts does: when we receive tool_call_partial, - // process it through NativeToolCallParser to populate rawChunkTracker - if (chunk.type === "tool_call_partial") { - NativeToolCallParser.processRawChunk({ - index: chunk.index, - id: chunk.id, - name: chunk.name, - arguments: chunk.arguments, - }) - } - chunks.push(chunk) - } + await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`) + }) - // Should have tool_call_partial and tool_call_end - const partialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial") - const endChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") + it("should handle errors in createMessage", async () => { + const errorMessage = "Stream error" + mockResponsesCreate.mockRejectedValueOnce(new Error(errorMessage)) - expect(partialChunks).toHaveLength(1) - expect(endChunks).toHaveLength(1) - expect(endChunks[0].id).toBe("call_xai_test") - }) + const stream = handler.createMessage("test prompt", []) + await expect(stream.next()).rejects.toThrow(`xAI completion error: ${errorMessage}`) }) }) diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts index 8b973d41c4e..a750dfbeef5 100644 --- a/src/api/providers/xai.ts +++ b/src/api/providers/xai.ts @@ -4,12 +4,11 @@ import OpenAI from "openai" import { type XAIModelId, xaiDefaultModelId, xaiModels, ApiProviderError } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" -import { NativeToolCallParser } from "../../core/assistant-message/NativeToolCallParser" import type { ApiHandlerOptions } from "../../shared/api" import { ApiStream } from "../transform/stream" -import { convertToOpenAiMessages } from "../transform/openai-format" -import { getModelParams } from "../transform/model-params" +import { convertToResponsesApiInput } from "../transform/responses-api-input" +import { processResponsesApiStream, createUsageNormalizer } from "../transform/responses-api-stream" import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" @@ -42,15 +41,27 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler ? (this.options.apiModelId as XAIModelId) : xaiDefaultModelId - const info = xaiModels[id] - const params = getModelParams({ - format: "openai", - modelId: id, - model: info, - settings: this.options, - defaultTemperature: XAI_DEFAULT_TEMPERATURE, - }) - return { id, info, ...params } + return { id, info: xaiModels[id] } + } + + /** + * Convert tools from OpenAI Chat Completions format to Responses API format. + * Chat Completions: { type: "function", function: { name, description, parameters } } + * Responses API: { type: "function", name, description, parameters } + */ + private mapResponseTools(tools?: any[]): any[] | undefined { + if (!tools?.length) { + return undefined + } + return tools + .filter((tool) => tool?.type === "function") + .map((tool) => ({ + type: "function", + name: tool.function.name, + description: tool.function.description, + parameters: tool.function.parameters ?? null, + strict: false, + })) } override async *createMessage( @@ -58,113 +69,64 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { id: modelId, info: modelInfo, reasoning } = this.getModel() - - // Use the OpenAI-compatible API. - const requestOptions = { - model: modelId, - max_tokens: modelInfo.maxTokens, - temperature: this.options.modelTemperature ?? XAI_DEFAULT_TEMPERATURE, - messages: [ - { role: "system", content: systemPrompt }, - ...convertToOpenAiMessages(messages), - ] as OpenAI.Chat.ChatCompletionMessageParam[], - stream: true as const, - stream_options: { include_usage: true }, - ...(reasoning && reasoning), - tools: this.convertToolsForOpenAI(metadata?.tools), - tool_choice: metadata?.tool_choice, - parallel_tool_calls: metadata?.parallelToolCalls ?? true, - } + const model = this.getModel() + + // Convert directly from Anthropic format to Responses API input format + const input = convertToResponsesApiInput(messages) + const responseTools = this.mapResponseTools(metadata?.tools) let stream try { - stream = await this.client.chat.completions.create(requestOptions) + stream = await this.client.responses.create({ + model: model.id, + instructions: systemPrompt, + input: input, + max_output_tokens: model.info.maxTokens, + temperature: this.options.modelTemperature ?? XAI_DEFAULT_TEMPERATURE, + stream: true, + store: false, // Don't store responses server-side for privacy + tools: responseTools, + tool_choice: responseTools ? "auto" : undefined, + include: ["reasoning.encrypted_content"], + }) } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) - const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "createMessage") + const apiError = new ApiProviderError(errorMessage, this.providerName, model.id, "createMessage") TelemetryService.instance.captureException(apiError) throw handleOpenAIError(error, this.providerName) } - for await (const chunk of stream) { - const delta = chunk.choices[0]?.delta - const finishReason = chunk.choices[0]?.finish_reason - - if (delta?.content) { - yield { - type: "text", - text: delta.content, - } - } - - if (delta && "reasoning_content" in delta && delta.reasoning_content) { - yield { - type: "reasoning", - text: delta.reasoning_content as string, - } - } - - // Handle tool calls in stream - emit partial chunks for NativeToolCallParser - if (delta?.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } - - // Process finish_reason to emit tool_call_end events - // This ensures tool calls are finalized even if the stream doesn't properly close - if (finishReason) { - const endEvents = NativeToolCallParser.processFinishReason(finishReason) - for (const event of endEvents) { - yield event - } - } - - if (chunk.usage) { - // Extract detailed token information if available - // First check for prompt_tokens_details structure (real API response) - const promptDetails = "prompt_tokens_details" in chunk.usage ? chunk.usage.prompt_tokens_details : null - const cachedTokens = promptDetails && "cached_tokens" in promptDetails ? promptDetails.cached_tokens : 0 - - // Fall back to direct fields in usage (used in test mocks) - const readTokens = - cachedTokens || - ("cache_read_input_tokens" in chunk.usage ? (chunk.usage as any).cache_read_input_tokens : 0) - const writeTokens = - "cache_creation_input_tokens" in chunk.usage ? (chunk.usage as any).cache_creation_input_tokens : 0 - - yield { - type: "usage", - inputTokens: chunk.usage.prompt_tokens || 0, - outputTokens: chunk.usage.completion_tokens || 0, - cacheReadTokens: readTokens, - cacheWriteTokens: writeTokens, - } - } - } + const normalizeUsage = createUsageNormalizer(model.info) + yield* processResponsesApiStream(stream, normalizeUsage) } async completePrompt(prompt: string): Promise { - const { id: modelId, reasoning } = this.getModel() + const model = this.getModel() try { - const response = await this.client.chat.completions.create({ - model: modelId, - messages: [{ role: "user", content: prompt }], - ...(reasoning && reasoning), + const response = await this.client.responses.create({ + model: model.id, + input: [{ role: "user", content: [{ type: "input_text", text: prompt }] }], + store: false, }) - return response.choices[0]?.message.content || "" + // Extract text from the response output + const output = (response as any).output + if (Array.isArray(output)) { + for (const item of output) { + if (item.type === "message" && Array.isArray(item.content)) { + for (const content of item.content) { + if (content.type === "output_text" && content.text) { + return content.text + } + } + } + } + } + return (response as any).output_text || "" } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) - const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "completePrompt") + const apiError = new ApiProviderError(errorMessage, this.providerName, model.id, "completePrompt") TelemetryService.instance.captureException(apiError) throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/transform/__tests__/responses-api-input.spec.ts b/src/api/transform/__tests__/responses-api-input.spec.ts new file mode 100644 index 00000000000..c78345c7897 --- /dev/null +++ b/src/api/transform/__tests__/responses-api-input.spec.ts @@ -0,0 +1,332 @@ +import type { Anthropic } from "@anthropic-ai/sdk" +import { convertToResponsesApiInput } from "../responses-api-input" + +describe("convertToResponsesApiInput", () => { + it("should return empty array for empty messages", () => { + expect(convertToResponsesApiInput([])).toEqual([]) + }) + + describe("string content messages", () => { + it("should convert string content to input_text", () => { + const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([{ role: "user", content: [{ type: "input_text", text: "Hello" }] }]) + }) + + it("should convert assistant string content", () => { + const messages: Anthropic.Messages.MessageParam[] = [{ role: "assistant", content: "Hi there" }] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([{ role: "assistant", content: [{ type: "input_text", text: "Hi there" }] }]) + }) + }) + + describe("user messages with content blocks", () => { + it("should convert text blocks to input_text", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text", text: "What is this?" }], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([{ role: "user", content: [{ type: "input_text", text: "What is this?" }] }]) + }) + + it("should convert image blocks to input_image", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "image", + source: { type: "base64", media_type: "image/png", data: "abc123" }, + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + role: "user", + content: [ + { + type: "input_image", + detail: "auto", + image_url: "data:image/png;base64,abc123", + }, + ], + }, + ]) + }) + + it("should convert tool_result to function_call_output", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "tool_123", + content: "Result text", + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "function_call_output", + call_id: "tool_123", + output: "Result text", + }, + ]) + }) + + it("should use (empty) for empty tool_result content", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "tool_123", + content: "", + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "function_call_output", + call_id: "tool_123", + output: "(empty)", + }, + ]) + }) + + it("should extract text from array tool_result content", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "tool_123", + content: [ + { type: "text", text: "Line 1" }, + { type: "text", text: "Line 2" }, + ], + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "function_call_output", + call_id: "tool_123", + output: "Line 1\nLine 2", + }, + ]) + }) + + it("should flush pending user content before tool_result", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { type: "text", text: "Here is context" }, + { + type: "tool_result", + tool_use_id: "tool_123", + content: "Done", + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { role: "user", content: [{ type: "input_text", text: "Here is context" }] }, + { type: "function_call_output", call_id: "tool_123", output: "Done" }, + ]) + }) + }) + + describe("assistant messages with content blocks", () => { + it("should convert text blocks to output_text messages", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [{ type: "text", text: "Here is my response" }], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "Here is my response" }], + }, + ]) + }) + + it("should convert tool_use to function_call", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { + type: "tool_use", + id: "call_abc", + name: "read_file", + input: { path: "/tmp/test.txt" }, + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toEqual([ + { + type: "function_call", + call_id: "call_abc", + name: "read_file", + arguments: '{"path":"/tmp/test.txt"}', + }, + ]) + }) + + it("should handle tool_use with string input", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { + type: "tool_use", + id: "call_abc", + name: "run_command", + input: '{"cmd":"ls"}' as any, + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result[0]).toEqual( + expect.objectContaining({ + type: "function_call", + arguments: '{"cmd":"ls"}', + }), + ) + }) + + it("should handle mixed text and tool_use in assistant message", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "assistant", + content: [ + { type: "text", text: "Let me read that file" }, + { + type: "tool_use", + id: "call_abc", + name: "read_file", + input: { path: "/tmp/test.txt" }, + }, + ], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toHaveLength(2) + expect(result[0]).toEqual( + expect.objectContaining({ + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "Let me read that file" }], + }), + ) + expect(result[1]).toEqual( + expect.objectContaining({ + type: "function_call", + name: "read_file", + }), + ) + }) + }) + + describe("multi-turn conversations", () => { + it("should handle a complete tool use cycle", () => { + const messages: Anthropic.Messages.MessageParam[] = [ + { role: "user", content: "Read /tmp/test.txt" }, + { + role: "assistant", + content: [ + { + type: "tool_use", + id: "call_1", + name: "read_file", + input: { path: "/tmp/test.txt" }, + }, + ], + }, + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: "call_1", + content: "file contents here", + }, + ], + }, + { + role: "assistant", + content: [{ type: "text", text: "The file contains: file contents here" }], + }, + ] + + const result = convertToResponsesApiInput(messages) + + expect(result).toHaveLength(4) + expect(result[0]).toEqual({ role: "user", content: [{ type: "input_text", text: "Read /tmp/test.txt" }] }) + expect(result[1]).toEqual( + expect.objectContaining({ type: "function_call", call_id: "call_1", name: "read_file" }), + ) + expect(result[2]).toEqual( + expect.objectContaining({ + type: "function_call_output", + call_id: "call_1", + output: "file contents here", + }), + ) + expect(result[3]).toEqual( + expect.objectContaining({ + type: "message", + role: "assistant", + }), + ) + }) + }) +}) diff --git a/src/api/transform/__tests__/responses-api-stream.spec.ts b/src/api/transform/__tests__/responses-api-stream.spec.ts new file mode 100644 index 00000000000..4d90ebaacb1 --- /dev/null +++ b/src/api/transform/__tests__/responses-api-stream.spec.ts @@ -0,0 +1,382 @@ +import { processResponsesApiStream, createUsageNormalizer } from "../responses-api-stream" + +// Helper to create an async iterable from events +async function* mockStream(events: any[]) { + for (const event of events) { + yield event + } +} + +// Helper to collect all chunks from a stream +async function collectChunks(stream: AsyncGenerator) { + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + return chunks +} + +const noopUsage = () => undefined + +describe("processResponsesApiStream", () => { + describe("text deltas", () => { + it("should yield text chunk for response.output_text.delta", async () => { + const stream = mockStream([{ type: "response.output_text.delta", delta: "Hello world" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "text", text: "Hello world" }]) + }) + + it("should yield text chunk for response.text.delta", async () => { + const stream = mockStream([{ type: "response.text.delta", delta: "Hello" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "text", text: "Hello" }]) + }) + + it("should skip text delta with empty delta", async () => { + const stream = mockStream([{ type: "response.output_text.delta", delta: "" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([]) + }) + }) + + describe("reasoning deltas", () => { + it("should yield reasoning chunk for response.reasoning_text.delta", async () => { + const stream = mockStream([{ type: "response.reasoning_text.delta", delta: "Let me think..." }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "reasoning", text: "Let me think..." }]) + }) + + it("should yield reasoning chunk for response.reasoning.delta", async () => { + const stream = mockStream([{ type: "response.reasoning.delta", delta: "Step 1" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "reasoning", text: "Step 1" }]) + }) + + it("should yield reasoning chunk for response.reasoning_summary_text.delta", async () => { + const stream = mockStream([{ type: "response.reasoning_summary_text.delta", delta: "Summary" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "reasoning", text: "Summary" }]) + }) + + it("should yield reasoning chunk for response.reasoning_summary.delta", async () => { + const stream = mockStream([{ type: "response.reasoning_summary.delta", delta: "Summary" }]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([{ type: "reasoning", text: "Summary" }]) + }) + }) + + describe("tool calls", () => { + it("should yield tool_call for function_call in output_item.done", async () => { + const stream = mockStream([ + { + type: "response.output_item.done", + item: { + type: "function_call", + call_id: "call_123", + name: "read_file", + arguments: '{"path":"/tmp/test.txt"}', + }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([ + { + type: "tool_call", + id: "call_123", + name: "read_file", + arguments: '{"path":"/tmp/test.txt"}', + }, + ]) + }) + + it("should yield tool_call for tool_call type in output_item.done", async () => { + const stream = mockStream([ + { + type: "response.output_item.done", + item: { + type: "tool_call", + tool_call_id: "call_456", + name: "write_file", + arguments: '{"path":"/tmp/out.txt"}', + }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([ + { + type: "tool_call", + id: "call_456", + name: "write_file", + arguments: '{"path":"/tmp/out.txt"}', + }, + ]) + }) + + it("should handle object arguments by JSON.stringifying", async () => { + const stream = mockStream([ + { + type: "response.output_item.done", + item: { + type: "function_call", + call_id: "call_789", + name: "test", + input: { key: "value" }, + }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks[0].arguments).toBe('{"key":"value"}') + }) + + it("should skip tool_call with missing call_id or name", async () => { + const stream = mockStream([ + { + type: "response.output_item.done", + item: { type: "function_call", call_id: "", name: "test", arguments: "{}" }, + }, + { + type: "response.output_item.done", + item: { type: "function_call", call_id: "call_1", name: "", arguments: "{}" }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([]) + }) + + it("should yield tool_call_partial for function_call_arguments.delta", async () => { + const stream = mockStream([ + { + type: "response.function_call_arguments.delta", + call_id: "call_123", + name: "read_file", + delta: '{"path":', + index: 0, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([ + { + type: "tool_call_partial", + index: 0, + id: "call_123", + name: "read_file", + arguments: '{"path":', + }, + ]) + }) + }) + + describe("completion and usage", () => { + it("should yield usage from response.completed", async () => { + const mockNormalize = (usage: any) => ({ + type: "usage" as const, + inputTokens: usage.input_tokens, + outputTokens: usage.output_tokens, + }) + + const stream = mockStream([ + { + type: "response.completed", + response: { usage: { input_tokens: 100, output_tokens: 50 } }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, mockNormalize)) + + expect(chunks).toEqual([{ type: "usage", inputTokens: 100, outputTokens: 50 }]) + }) + + it("should yield usage from response.done", async () => { + const mockNormalize = (usage: any) => ({ + type: "usage" as const, + inputTokens: usage.input_tokens, + outputTokens: usage.output_tokens, + }) + + const stream = mockStream([ + { + type: "response.done", + response: { usage: { input_tokens: 200, output_tokens: 100 } }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, mockNormalize)) + + expect(chunks).toEqual([{ type: "usage", inputTokens: 200, outputTokens: 100 }]) + }) + + it("should not yield usage when normalizer returns undefined", async () => { + const stream = mockStream([ + { + type: "response.completed", + response: { usage: null }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([]) + }) + }) + + describe("unknown events", () => { + it("should silently ignore unknown event types", async () => { + const stream = mockStream([ + { type: "response.created" }, + { type: "response.in_progress" }, + { type: "response.output_item.added", item: { type: "message" } }, + { type: "response.content_part.added" }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, noopUsage)) + + expect(chunks).toEqual([]) + }) + }) + + describe("full conversation stream", () => { + it("should handle a complete stream with reasoning, text, and usage", async () => { + const mockNormalize = (usage: any) => ({ + type: "usage" as const, + inputTokens: usage.input_tokens, + outputTokens: usage.output_tokens, + }) + + const stream = mockStream([ + { type: "response.reasoning_text.delta", delta: "Thinking..." }, + { type: "response.reasoning_text.delta", delta: " done." }, + { type: "response.output_text.delta", delta: "The answer is " }, + { type: "response.output_text.delta", delta: "42." }, + { + type: "response.completed", + response: { usage: { input_tokens: 50, output_tokens: 30 } }, + }, + ]) + + const chunks = await collectChunks(processResponsesApiStream(stream, mockNormalize)) + + expect(chunks).toEqual([ + { type: "reasoning", text: "Thinking..." }, + { type: "reasoning", text: " done." }, + { type: "text", text: "The answer is " }, + { type: "text", text: "42." }, + { type: "usage", inputTokens: 50, outputTokens: 30 }, + ]) + }) + }) +}) + +describe("createUsageNormalizer", () => { + const mockModelInfo = { contextWindow: 128000, supportsPromptCache: false } as any + + it("should return undefined for null/undefined usage", () => { + const normalize = createUsageNormalizer(mockModelInfo) + expect(normalize(null)).toBeUndefined() + expect(normalize(undefined)).toBeUndefined() + }) + + it("should extract input and output tokens", () => { + const normalize = createUsageNormalizer(mockModelInfo) + + const result = normalize({ input_tokens: 100, output_tokens: 50 }) + + expect(result).toEqual( + expect.objectContaining({ + type: "usage", + inputTokens: 100, + outputTokens: 50, + }), + ) + }) + + it("should extract cached tokens from input_tokens_details", () => { + const normalize = createUsageNormalizer(mockModelInfo) + + const result = normalize({ + input_tokens: 100, + output_tokens: 50, + input_tokens_details: { cached_tokens: 30 }, + }) + + expect(result?.cacheReadTokens).toBe(30) + }) + + it("should extract reasoning tokens from output_tokens_details", () => { + const normalize = createUsageNormalizer(mockModelInfo) + + const result = normalize({ + input_tokens: 100, + output_tokens: 50, + output_tokens_details: { reasoning_tokens: 20 }, + }) + + expect(result?.reasoningTokens).toBe(20) + }) + + it("should not include reasoningTokens when not present", () => { + const normalize = createUsageNormalizer(mockModelInfo) + + const result = normalize({ input_tokens: 100, output_tokens: 50 }) + + expect(result).not.toHaveProperty("reasoningTokens") + }) + + it("should compute totalCost when calculateCost is provided", () => { + const calculateCost = (input: number, output: number, cached: number) => 0.42 + const normalize = createUsageNormalizer(mockModelInfo, calculateCost) + + const result = normalize({ input_tokens: 100, output_tokens: 50 }) + + expect(result?.totalCost).toBe(0.42) + }) + + it("should not include totalCost when calculateCost is not provided", () => { + const normalize = createUsageNormalizer(mockModelInfo) + + const result = normalize({ input_tokens: 100, output_tokens: 50 }) + + expect(result).not.toHaveProperty("totalCost") + }) + + it("should handle Chat Completions style field names as fallback", () => { + const normalize = createUsageNormalizer(mockModelInfo) + + const result = normalize({ + prompt_tokens: 100, + completion_tokens: 50, + prompt_tokens_details: { cached_tokens: 10 }, + }) + + expect(result).toEqual( + expect.objectContaining({ + inputTokens: 100, + outputTokens: 50, + cacheReadTokens: 10, + }), + ) + }) +}) diff --git a/src/api/transform/responses-api-input.ts b/src/api/transform/responses-api-input.ts new file mode 100644 index 00000000000..bfcc52b5425 --- /dev/null +++ b/src/api/transform/responses-api-input.ts @@ -0,0 +1,109 @@ +import { Anthropic } from "@anthropic-ai/sdk" + +/** + * Converts Anthropic-format messages to the OpenAI Responses API input format. + * + * Key differences from Chat Completions format: + * - Content parts use { type: "input_text" } instead of { type: "text" } + * - Images use { type: "input_image" } instead of { type: "image_url" } + * - Tool results use { type: "function_call_output", call_id } instead of { role: "tool", tool_call_id } + * - Tool uses become { type: "function_call", call_id, name, arguments } items + * - System prompt goes via the `instructions` parameter, not as a message + * + * @param messages - Array of Anthropic MessageParam objects + * @returns Array of Responses API input items + */ +export function convertToResponsesApiInput(messages: Anthropic.Messages.MessageParam[]): any[] { + const input: any[] = [] + + for (const message of messages) { + if (typeof message.content === "string") { + input.push({ + role: message.role, + content: [{ type: "input_text", text: message.content }], + }) + continue + } + + if (message.role === "assistant") { + for (const part of message.content) { + switch (part.type) { + case "text": + input.push({ + type: "message", + role: "assistant", + content: [{ type: "output_text", text: part.text }], + }) + break + case "tool_use": + input.push({ + type: "function_call", + call_id: part.id, + name: part.name, + arguments: typeof part.input === "string" ? part.input : JSON.stringify(part.input ?? {}), + }) + break + case "thinking": + // Include reasoning if it has content + if ((part as any).thinking && (part as any).thinking.trim().length > 0) { + input.push({ + type: "message", + role: "assistant", + content: [{ type: "output_text", text: `[Thinking] ${(part as any).thinking}` }], + }) + } + break + } + } + } else { + // User messages + const contentParts: any[] = [] + for (const part of message.content) { + switch (part.type) { + case "text": + contentParts.push({ type: "input_text", text: part.text }) + break + case "image": + contentParts.push({ + type: "input_image", + detail: "auto", + image_url: `data:${part.source.media_type};base64,${part.source.data}`, + }) + break + case "tool_result": { + // Flush any pending user content before the tool result + if (contentParts.length > 0) { + input.push({ role: "user", content: [...contentParts] }) + contentParts.length = 0 + } + // Convert tool result content + let output: string + if (typeof part.content === "string") { + output = part.content || "(empty)" + } else if (Array.isArray(part.content)) { + output = + part.content + .filter((c): c is Anthropic.TextBlockParam => c.type === "text") + .map((c) => c.text) + .join("\n") || "(empty)" + } else { + output = "(empty)" + } + input.push({ + type: "function_call_output", + call_id: part.tool_use_id, + output, + }) + break + } + } + } + // Flush remaining user content + if (contentParts.length > 0) { + input.push({ role: "user", content: contentParts }) + } + } + } + + return input +} diff --git a/src/api/transform/responses-api-stream.ts b/src/api/transform/responses-api-stream.ts new file mode 100644 index 00000000000..ab0cdfbf94d --- /dev/null +++ b/src/api/transform/responses-api-stream.ts @@ -0,0 +1,142 @@ +import type { ModelInfo } from "@roo-code/types" + +import type { ApiStream, ApiStreamUsageChunk } from "./stream" + +/** + * Processes Responses API stream events and yields ApiStreamChunks. + * + * This is a shared utility for providers that use OpenAI's Responses API + * (POST /v1/responses with stream: true). It handles the core event types: + * + * - Text deltas (response.output_text.delta) + * - Reasoning deltas (response.reasoning_text.delta, response.reasoning_summary_text.delta) + * - Tool/function calls (response.output_item.done with function_call type) + * - Usage data (response.completed) + * + * Provider-specific concerns (WebSocket mode, SSE fallback, duplicate detection, + * pending tool tracking) are intentionally left to individual providers. + * + * @param stream - AsyncIterable of Responses API stream events + * @param normalizeUsage - Provider-specific function to normalize usage data into ApiStreamUsageChunk + */ +export async function* processResponsesApiStream( + stream: AsyncIterable, + normalizeUsage: (usage: any) => ApiStreamUsageChunk | undefined, +): ApiStream { + for await (const event of stream) { + // Text content deltas + if (event?.type === "response.output_text.delta" || event?.type === "response.text.delta") { + if (event?.delta) { + yield { type: "text", text: event.delta } + } + continue + } + + // Reasoning deltas + if ( + event?.type === "response.reasoning_text.delta" || + event?.type === "response.reasoning.delta" || + event?.type === "response.reasoning_summary_text.delta" || + event?.type === "response.reasoning_summary.delta" + ) { + if (event?.delta) { + yield { type: "reasoning", text: event.delta } + } + continue + } + + // Output item events — handle completed function calls and fallback text + if (event?.type === "response.output_item.done") { + const item = event?.item + if (item?.type === "function_call" || item?.type === "tool_call") { + const callId = item.call_id || item.tool_call_id || item.id + const name = item.name || item.function?.name + const argsRaw = item.arguments || item.function?.arguments || item.input + const args = + typeof argsRaw === "string" + ? argsRaw + : argsRaw && typeof argsRaw === "object" + ? JSON.stringify(argsRaw) + : "" + + if (typeof callId === "string" && callId.length > 0 && typeof name === "string" && name.length > 0) { + yield { + type: "tool_call", + id: callId, + name, + arguments: args, + } + } + } + continue + } + + // Function call argument deltas (for streaming tool calls) + if ( + event?.type === "response.function_call_arguments.delta" || + event?.type === "response.tool_call_arguments.delta" + ) { + const callId = event.call_id || event.tool_call_id || event.id || event.item_id + const name = event.name || event.function_name + if (typeof callId === "string" && callId.length > 0) { + yield { + type: "tool_call_partial", + index: event.index ?? 0, + id: callId, + name, + arguments: typeof event.delta === "string" ? event.delta : "", + } + } + continue + } + + // Completion events — extract usage + if (event?.type === "response.completed" || event?.type === "response.done") { + const usage = event?.response?.usage || event?.usage + const usageData = normalizeUsage(usage) + if (usageData) { + yield usageData + } + continue + } + } +} + +/** + * Creates a standard usage normalizer for providers with per-token pricing. + * Extracts input/output tokens, cache tokens, reasoning tokens, and computes cost. + * + * @param modelInfo - Model info with pricing details + * @param calculateCost - Optional function to compute total cost from token counts + */ +export function createUsageNormalizer( + modelInfo: ModelInfo, + calculateCost?: (inputTokens: number, outputTokens: number, cacheReadTokens: number) => number, +): (usage: any) => ApiStreamUsageChunk | undefined { + return (usage: any): ApiStreamUsageChunk | undefined => { + if (!usage) return undefined + + const inputDetails = usage.input_tokens_details ?? usage.prompt_tokens_details + const cachedTokens = inputDetails?.cached_tokens ?? 0 + + const inputTokens = usage.input_tokens ?? usage.prompt_tokens ?? 0 + const outputTokens = usage.output_tokens ?? usage.completion_tokens ?? 0 + const cacheReadTokens = usage.cache_read_input_tokens ?? cachedTokens ?? 0 + + const reasoningTokens = + typeof usage.output_tokens_details?.reasoning_tokens === "number" + ? usage.output_tokens_details.reasoning_tokens + : undefined + + const totalCost = calculateCost ? calculateCost(inputTokens, outputTokens, cacheReadTokens) : undefined + + return { + type: "usage", + inputTokens, + outputTokens, + cacheReadTokens, + ...(typeof reasoningTokens === "number" ? { reasoningTokens } : {}), + ...(typeof totalCost === "number" ? { totalCost } : {}), + } + } +} From 059b0db6c32eb87ca72c06158f4afe0b80dbf99d Mon Sep 17 00:00:00 2001 From: Roo Code Date: Fri, 20 Mar 2026 00:02:42 +0000 Subject: [PATCH 2/3] fix: address review issues in xAI Responses API migration - Use base provider convertToolSchemaForOpenAI() for tool schema hardening - Handle MCP tools with isMcpTool() (strict: false for MCP, true otherwise) - Respect metadata.tool_choice instead of always using "auto" - Add parallel_tool_calls pass-through from metadata - Fix assistant string content to use output_text format (not input_text) - Remove unused modelInfo param from createUsageNormalizer - Add cacheWriteTokens extraction to usage normalizer - Fix test() -> it() inconsistency in xai.spec.ts - Update tests to match all changes --- src/api/providers/__tests__/xai.spec.ts | 4 ++- src/api/providers/xai.ts | 29 ++++++++++++------ .../__tests__/responses-api-input.spec.ts | 10 +++++-- .../__tests__/responses-api-stream.spec.ts | 30 ++++++++++++------- src/api/transform/responses-api-input.ts | 17 ++++++++--- src/api/transform/responses-api-stream.ts | 6 ++-- 6 files changed, 66 insertions(+), 30 deletions(-) diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts index ff4b14cf14a..d760c0fd473 100644 --- a/src/api/providers/__tests__/xai.spec.ts +++ b/src/api/providers/__tests__/xai.spec.ts @@ -82,7 +82,7 @@ describe("XAIHandler", () => { expect(model.info).toEqual(xaiModels[xaiDefaultModelId]) }) - test("should return specified model when valid model is provided", () => { + it("should return specified model when valid model is provided", () => { const testModelId = "grok-3" const handlerWithModel = new XAIHandler({ apiModelId: testModelId }) const model = handlerWithModel.getModel() @@ -227,9 +227,11 @@ describe("XAIHandler", () => { type: "function", name: "test_tool", description: "A test tool", + strict: true, }), ], tool_choice: "auto", + parallel_tool_calls: true, }), ) }) diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts index a750dfbeef5..daaad88d6c2 100644 --- a/src/api/providers/xai.ts +++ b/src/api/providers/xai.ts @@ -14,6 +14,7 @@ import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" import { handleOpenAIError } from "./utils/openai-error-handler" +import { isMcpTool } from "../../utils/mcp-name" const XAI_DEFAULT_TEMPERATURE = 0 @@ -48,6 +49,9 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler * Convert tools from OpenAI Chat Completions format to Responses API format. * Chat Completions: { type: "function", function: { name, description, parameters } } * Responses API: { type: "function", name, description, parameters } + * + * Uses base provider's convertToolSchemaForOpenAI() for schema hardening + * (additionalProperties: false, ensureAllRequired) and handles MCP tools. */ private mapResponseTools(tools?: any[]): any[] | undefined { if (!tools?.length) { @@ -55,13 +59,18 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler } return tools .filter((tool) => tool?.type === "function") - .map((tool) => ({ - type: "function", - name: tool.function.name, - description: tool.function.description, - parameters: tool.function.parameters ?? null, - strict: false, - })) + .map((tool) => { + const isMcp = isMcpTool(tool.function.name) + return { + type: "function", + name: tool.function.name, + description: tool.function.description, + parameters: isMcp + ? tool.function.parameters + : this.convertToolSchemaForOpenAI(tool.function.parameters), + strict: !isMcp, + } + }) } override async *createMessage( @@ -86,7 +95,9 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler stream: true, store: false, // Don't store responses server-side for privacy tools: responseTools, - tool_choice: responseTools ? "auto" : undefined, + // Cast tool_choice since metadata uses Chat Completions types but Responses API has its own type + tool_choice: (metadata?.tool_choice ?? (responseTools ? "auto" : undefined)) as any, + parallel_tool_calls: metadata?.parallelToolCalls ?? true, include: ["reasoning.encrypted_content"], }) } catch (error) { @@ -96,7 +107,7 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler throw handleOpenAIError(error, this.providerName) } - const normalizeUsage = createUsageNormalizer(model.info) + const normalizeUsage = createUsageNormalizer() yield* processResponsesApiStream(stream, normalizeUsage) } diff --git a/src/api/transform/__tests__/responses-api-input.spec.ts b/src/api/transform/__tests__/responses-api-input.spec.ts index c78345c7897..c57b61d8958 100644 --- a/src/api/transform/__tests__/responses-api-input.spec.ts +++ b/src/api/transform/__tests__/responses-api-input.spec.ts @@ -15,12 +15,18 @@ describe("convertToResponsesApiInput", () => { expect(result).toEqual([{ role: "user", content: [{ type: "input_text", text: "Hello" }] }]) }) - it("should convert assistant string content", () => { + it("should convert assistant string content to output_text message format", () => { const messages: Anthropic.Messages.MessageParam[] = [{ role: "assistant", content: "Hi there" }] const result = convertToResponsesApiInput(messages) - expect(result).toEqual([{ role: "assistant", content: [{ type: "input_text", text: "Hi there" }] }]) + expect(result).toEqual([ + { + type: "message", + role: "assistant", + content: [{ type: "output_text", text: "Hi there" }], + }, + ]) }) }) diff --git a/src/api/transform/__tests__/responses-api-stream.spec.ts b/src/api/transform/__tests__/responses-api-stream.spec.ts index 4d90ebaacb1..6abdbddb659 100644 --- a/src/api/transform/__tests__/responses-api-stream.spec.ts +++ b/src/api/transform/__tests__/responses-api-stream.spec.ts @@ -291,16 +291,14 @@ describe("processResponsesApiStream", () => { }) describe("createUsageNormalizer", () => { - const mockModelInfo = { contextWindow: 128000, supportsPromptCache: false } as any - it("should return undefined for null/undefined usage", () => { - const normalize = createUsageNormalizer(mockModelInfo) + const normalize = createUsageNormalizer() expect(normalize(null)).toBeUndefined() expect(normalize(undefined)).toBeUndefined() }) it("should extract input and output tokens", () => { - const normalize = createUsageNormalizer(mockModelInfo) + const normalize = createUsageNormalizer() const result = normalize({ input_tokens: 100, output_tokens: 50 }) @@ -314,7 +312,7 @@ describe("createUsageNormalizer", () => { }) it("should extract cached tokens from input_tokens_details", () => { - const normalize = createUsageNormalizer(mockModelInfo) + const normalize = createUsageNormalizer() const result = normalize({ input_tokens: 100, @@ -325,8 +323,20 @@ describe("createUsageNormalizer", () => { expect(result?.cacheReadTokens).toBe(30) }) + it("should extract cache write tokens", () => { + const normalize = createUsageNormalizer() + + const result = normalize({ + input_tokens: 100, + output_tokens: 50, + cache_creation_input_tokens: 15, + }) + + expect(result?.cacheWriteTokens).toBe(15) + }) + it("should extract reasoning tokens from output_tokens_details", () => { - const normalize = createUsageNormalizer(mockModelInfo) + const normalize = createUsageNormalizer() const result = normalize({ input_tokens: 100, @@ -338,7 +348,7 @@ describe("createUsageNormalizer", () => { }) it("should not include reasoningTokens when not present", () => { - const normalize = createUsageNormalizer(mockModelInfo) + const normalize = createUsageNormalizer() const result = normalize({ input_tokens: 100, output_tokens: 50 }) @@ -347,7 +357,7 @@ describe("createUsageNormalizer", () => { it("should compute totalCost when calculateCost is provided", () => { const calculateCost = (input: number, output: number, cached: number) => 0.42 - const normalize = createUsageNormalizer(mockModelInfo, calculateCost) + const normalize = createUsageNormalizer(calculateCost) const result = normalize({ input_tokens: 100, output_tokens: 50 }) @@ -355,7 +365,7 @@ describe("createUsageNormalizer", () => { }) it("should not include totalCost when calculateCost is not provided", () => { - const normalize = createUsageNormalizer(mockModelInfo) + const normalize = createUsageNormalizer() const result = normalize({ input_tokens: 100, output_tokens: 50 }) @@ -363,7 +373,7 @@ describe("createUsageNormalizer", () => { }) it("should handle Chat Completions style field names as fallback", () => { - const normalize = createUsageNormalizer(mockModelInfo) + const normalize = createUsageNormalizer() const result = normalize({ prompt_tokens: 100, diff --git a/src/api/transform/responses-api-input.ts b/src/api/transform/responses-api-input.ts index bfcc52b5425..a766dfef6ef 100644 --- a/src/api/transform/responses-api-input.ts +++ b/src/api/transform/responses-api-input.ts @@ -18,10 +18,19 @@ export function convertToResponsesApiInput(messages: Anthropic.Messages.MessageP for (const message of messages) { if (typeof message.content === "string") { - input.push({ - role: message.role, - content: [{ type: "input_text", text: message.content }], - }) + if (message.role === "assistant") { + // Assistant messages use output_text in the Responses API format + input.push({ + type: "message", + role: "assistant", + content: [{ type: "output_text", text: message.content }], + }) + } else { + input.push({ + role: message.role, + content: [{ type: "input_text", text: message.content }], + }) + } continue } diff --git a/src/api/transform/responses-api-stream.ts b/src/api/transform/responses-api-stream.ts index ab0cdfbf94d..1ea9744180d 100644 --- a/src/api/transform/responses-api-stream.ts +++ b/src/api/transform/responses-api-stream.ts @@ -1,5 +1,3 @@ -import type { ModelInfo } from "@roo-code/types" - import type { ApiStream, ApiStreamUsageChunk } from "./stream" /** @@ -106,11 +104,9 @@ export async function* processResponsesApiStream( * Creates a standard usage normalizer for providers with per-token pricing. * Extracts input/output tokens, cache tokens, reasoning tokens, and computes cost. * - * @param modelInfo - Model info with pricing details * @param calculateCost - Optional function to compute total cost from token counts */ export function createUsageNormalizer( - modelInfo: ModelInfo, calculateCost?: (inputTokens: number, outputTokens: number, cacheReadTokens: number) => number, ): (usage: any) => ApiStreamUsageChunk | undefined { return (usage: any): ApiStreamUsageChunk | undefined => { @@ -122,6 +118,7 @@ export function createUsageNormalizer( const inputTokens = usage.input_tokens ?? usage.prompt_tokens ?? 0 const outputTokens = usage.output_tokens ?? usage.completion_tokens ?? 0 const cacheReadTokens = usage.cache_read_input_tokens ?? cachedTokens ?? 0 + const cacheWriteTokens = usage.cache_creation_input_tokens ?? usage.cache_write_tokens ?? 0 const reasoningTokens = typeof usage.output_tokens_details?.reasoning_tokens === "number" @@ -135,6 +132,7 @@ export function createUsageNormalizer( inputTokens, outputTokens, cacheReadTokens, + cacheWriteTokens, ...(typeof reasoningTokens === "number" ? { reasoningTokens } : {}), ...(typeof totalCost === "number" ? { totalCost } : {}), } From 69511334dc909e452023e78baff7fee850adb650 Mon Sep 17 00:00:00 2001 From: Roo Code Date: Fri, 20 Mar 2026 00:12:00 +0000 Subject: [PATCH 3/3] fix: restore getModelParams, use base class tool conversion, add cacheWriteTokens, simplify completePrompt --- src/api/providers/__tests__/xai.spec.ts | 42 ++++++++++-- src/api/providers/xai.ts | 80 ++++++++++++++--------- src/api/transform/responses-api-stream.ts | 2 +- 3 files changed, 86 insertions(+), 38 deletions(-) diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts index d760c0fd473..763d10d0277 100644 --- a/src/api/providers/__tests__/xai.spec.ts +++ b/src/api/providers/__tests__/xai.spec.ts @@ -239,12 +239,7 @@ describe("XAIHandler", () => { it("completePrompt should return text from Responses API", async () => { const expectedResponse = "This is a test response" mockResponsesCreate.mockResolvedValueOnce({ - output: [ - { - type: "message", - content: [{ type: "output_text", text: expectedResponse }], - }, - ], + output_text: expectedResponse, }) const result = await handler.completePrompt("test prompt") @@ -258,6 +253,41 @@ describe("XAIHandler", () => { await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`) }) + it("should include reasoning_effort for mini models", async () => { + const miniModelHandler = new XAIHandler({ + apiModelId: "grok-3-mini", + reasoningEffort: "high", + }) + + mockResponsesCreate.mockResolvedValueOnce(mockStream([])) + + const stream = miniModelHandler.createMessage("test prompt", []) + await stream.next() + + expect(mockResponsesCreate).toHaveBeenCalledWith( + expect.objectContaining({ + reasoning: expect.objectContaining({ + reasoning_effort: "high", + }), + }), + ) + }) + + it("should not include reasoning for non-mini models", async () => { + const regularHandler = new XAIHandler({ + apiModelId: "grok-3", + reasoningEffort: "high", + }) + + mockResponsesCreate.mockResolvedValueOnce(mockStream([])) + + const stream = regularHandler.createMessage("test prompt", []) + await stream.next() + + const callArgs = mockResponsesCreate.mock.calls[mockResponsesCreate.mock.calls.length - 1][0] + expect(callArgs).not.toHaveProperty("reasoning") + }) + it("should handle errors in createMessage", async () => { const errorMessage = "Stream error" mockResponsesCreate.mockRejectedValueOnce(new Error(errorMessage)) diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts index daaad88d6c2..0cd9cb0273b 100644 --- a/src/api/providers/xai.ts +++ b/src/api/providers/xai.ts @@ -9,6 +9,7 @@ import type { ApiHandlerOptions } from "../../shared/api" import { ApiStream } from "../transform/stream" import { convertToResponsesApiInput } from "../transform/responses-api-input" import { processResponsesApiStream, createUsageNormalizer } from "../transform/responses-api-stream" +import { getModelParams } from "../transform/model-params" import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" @@ -42,7 +43,15 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler ? (this.options.apiModelId as XAIModelId) : xaiDefaultModelId - return { id, info: xaiModels[id] } + const info = xaiModels[id] + const params = getModelParams({ + format: "openai", + modelId: id, + model: info, + settings: this.options, + defaultTemperature: XAI_DEFAULT_TEMPERATURE, + }) + return { id, info, ...params } } /** @@ -54,10 +63,11 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler * (additionalProperties: false, ensureAllRequired) and handles MCP tools. */ private mapResponseTools(tools?: any[]): any[] | undefined { - if (!tools?.length) { + const converted = this.convertToolsForOpenAI(tools) + if (!converted?.length) { return undefined } - return tools + return converted .filter((tool) => tool?.type === "function") .map((tool) => { const isMcp = isMcpTool(tool.function.name) @@ -84,22 +94,42 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler const input = convertToResponsesApiInput(messages) const responseTools = this.mapResponseTools(metadata?.tools) - let stream + // Build request options + const requestBody: Record = { + model: model.id, + instructions: systemPrompt, + input: input, + stream: true, + store: false, // Don't store responses server-side for privacy + include: ["reasoning.encrypted_content"], + } + + if (model.maxTokens) { + requestBody.max_output_tokens = model.maxTokens + } + + if (model.temperature !== undefined) { + requestBody.temperature = model.temperature + } + + if (responseTools) { + requestBody.tools = responseTools + // Cast tool_choice since metadata uses Chat Completions types but Responses API has its own type + requestBody.tool_choice = (metadata?.tool_choice ?? "auto") as any + requestBody.parallel_tool_calls = metadata?.parallelToolCalls ?? true + } + + // Pass reasoning effort for models that support it (e.g., mini models) + if (model.reasoning) { + requestBody.reasoning = model.reasoning + } + + let stream: AsyncIterable try { - stream = await this.client.responses.create({ - model: model.id, - instructions: systemPrompt, - input: input, - max_output_tokens: model.info.maxTokens, - temperature: this.options.modelTemperature ?? XAI_DEFAULT_TEMPERATURE, + stream = (await this.client.responses.create({ + ...requestBody, stream: true, - store: false, // Don't store responses server-side for privacy - tools: responseTools, - // Cast tool_choice since metadata uses Chat Completions types but Responses API has its own type - tool_choice: (metadata?.tool_choice ?? (responseTools ? "auto" : undefined)) as any, - parallel_tool_calls: metadata?.parallelToolCalls ?? true, - include: ["reasoning.encrypted_content"], - }) + } as any)) as unknown as AsyncIterable } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) const apiError = new ApiProviderError(errorMessage, this.providerName, model.id, "createMessage") @@ -121,20 +151,8 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler store: false, }) - // Extract text from the response output - const output = (response as any).output - if (Array.isArray(output)) { - for (const item of output) { - if (item.type === "message" && Array.isArray(item.content)) { - for (const content of item.content) { - if (content.type === "output_text" && content.text) { - return content.text - } - } - } - } - } - return (response as any).output_text || "" + // output_text is a convenience field on the Responses API response + return response.output_text || "" } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) const apiError = new ApiProviderError(errorMessage, this.providerName, model.id, "completePrompt") diff --git a/src/api/transform/responses-api-stream.ts b/src/api/transform/responses-api-stream.ts index 1ea9744180d..884a7551367 100644 --- a/src/api/transform/responses-api-stream.ts +++ b/src/api/transform/responses-api-stream.ts @@ -131,8 +131,8 @@ export function createUsageNormalizer( type: "usage", inputTokens, outputTokens, - cacheReadTokens, cacheWriteTokens, + cacheReadTokens, ...(typeof reasoningTokens === "number" ? { reasoningTokens } : {}), ...(typeof totalCost === "number" ? { totalCost } : {}), }