From e8d994723329cc285837b6817379706a1014d621 Mon Sep 17 00:00:00 2001 From: im10furry <1936409761@qq.com> Date: Tue, 9 Jun 2026 14:01:29 +0800 Subject: [PATCH 1/4] refactor(mcp): extract connection/manager modules, add discovery caching and tests - Extract connection logic from client.ts into connection.ts (transport candidates, timeout, batch connect) - Add MCPClientManager in manager.ts with connection pooling, ping health checks, config-change detection, and auto-cleanup - Add per-client discovery caching in tools-integration.ts with ToolListChanged/PromptListChanged notification invalidation - Fix stderr listener timing in connection.ts: register before connectWithTimeout to capture handshake-phase output - Add assertJsonPayloadBudget validation on MCP tool arguments in mcp.ts - Use normalizeImageMediaType for safe Anthropic image block conversion - Fix mergeAbortSignals listener cleanup (track and remove all listeners) - Fix ListMcpResourcesTool indentation and add error logging to catch block - Add cache.clear() compatibility shims on getClients/getMCPTools/getMCPCommands - Add tests: MCPClientManager lifecycle, connection internals, content normalization --- src/acp/validation.ts | 113 +++++ src/entrypoints/mcp.ts | 5 + src/services/mcp/client.ts | 316 ++---------- src/services/mcp/connection.ts | 326 ++++++++++++ src/services/mcp/manager.ts | 171 +++++++ src/services/mcp/tools-integration.ts | 479 +++++++++++------- .../ListMcpResourcesTool.tsx | 19 +- src/utils/ai/anthropic.ts | 144 ++++++ tests/unit/mcp-connection-internals.test.ts | 77 +++ tests/unit/mcp-content-normalization.test.ts | 108 ++++ tests/unit/mcp-manager-lifecycle.test.ts | 194 +++++++ 11 files changed, 1484 insertions(+), 468 deletions(-) create mode 100644 src/acp/validation.ts create mode 100644 src/services/mcp/connection.ts create mode 100644 src/services/mcp/manager.ts create mode 100644 src/utils/ai/anthropic.ts create mode 100644 tests/unit/mcp-connection-internals.test.ts create mode 100644 tests/unit/mcp-content-normalization.test.ts create mode 100644 tests/unit/mcp-manager-lifecycle.test.ts diff --git a/src/acp/validation.ts b/src/acp/validation.ts new file mode 100644 index 000000000..d58635ba7 --- /dev/null +++ b/src/acp/validation.ts @@ -0,0 +1,113 @@ +export const MAX_JSON_PAYLOAD_BYTES = 1024 * 1024 +export const MAX_JSON_NESTING_DEPTH = 10 + +export type JsonPayloadBudgetErrorData = { + kind: 'payload_too_large' | 'payload_too_deep' | 'payload_not_serializable' + retryable: false + label: string + sizeBytes?: number + maxBytes?: number + depth?: number + maxDepth?: number +} + +export class JsonPayloadBudgetError extends Error { + readonly code = -32602 + readonly data: JsonPayloadBudgetErrorData + + constructor(message: string, data: JsonPayloadBudgetErrorData) { + super(message) + this.name = 'JsonPayloadBudgetError' + this.data = data + } +} + +function getJsonNestingDepth( + value: unknown, + seen: WeakSet = new WeakSet(), +): number { + if (value === null || typeof value !== 'object') return 0 + + if (seen.has(value)) { + throw new JsonPayloadBudgetError('JSON payload is not serializable', { + kind: 'payload_not_serializable', + retryable: false, + label: 'payload', + }) + } + seen.add(value) + + try { + const children = Array.isArray(value) + ? value + : Object.values(value as Record) + + if (children.length === 0) return 1 + return ( + 1 + Math.max(...children.map(child => getJsonNestingDepth(child, seen))) + ) + } finally { + seen.delete(value) + } +} + +export function getJsonPayloadBudget(value: unknown): { + sizeBytes: number + depth: number +} { + let serialized: string + try { + serialized = JSON.stringify(value) ?? 'null' + } catch { + throw new JsonPayloadBudgetError('JSON payload is not serializable', { + kind: 'payload_not_serializable', + retryable: false, + label: 'payload', + }) + } + + return { + sizeBytes: Buffer.byteLength(serialized, 'utf8'), + depth: getJsonNestingDepth(value), + } +} + +export function assertJsonPayloadBudget( + value: unknown, + options?: { + label?: string + maxBytes?: number + maxDepth?: number + }, +): void { + const label = options?.label ?? 'payload' + const maxBytes = options?.maxBytes ?? MAX_JSON_PAYLOAD_BYTES + const maxDepth = options?.maxDepth ?? MAX_JSON_NESTING_DEPTH + const budget = getJsonPayloadBudget(value) + + if (budget.sizeBytes > maxBytes) { + throw new JsonPayloadBudgetError( + `${label} exceeds maximum serialized size of ${maxBytes} bytes`, + { + kind: 'payload_too_large', + retryable: false, + label, + sizeBytes: budget.sizeBytes, + maxBytes, + }, + ) + } + + if (budget.depth > maxDepth) { + throw new JsonPayloadBudgetError( + `${label} exceeds maximum nesting depth of ${maxDepth}`, + { + kind: 'payload_too_deep', + retryable: false, + label, + depth: budget.depth, + maxDepth, + }, + ) + } +} diff --git a/src/entrypoints/mcp.ts b/src/entrypoints/mcp.ts index 131b923f2..f706be2b8 100644 --- a/src/entrypoints/mcp.ts +++ b/src/entrypoints/mcp.ts @@ -19,6 +19,7 @@ import { Command } from '@commands' import review from '@commands/review' import { lastX } from '@utils/text/generators' import { MACRO } from '@constants/macros' +import { assertJsonPayloadBudget } from '../acp/validation' type ToolInput = Record const state: { @@ -74,6 +75,10 @@ export async function startMCPServer(cwd: string): Promise { } try { + assertJsonPayloadBudget(args ?? {}, { + label: `MCP tool ${name} arguments`, + }) + if (!(await tool.isEnabled())) { throw new Error(`Tool ${name} is not enabled`) } diff --git a/src/services/mcp/client.ts b/src/services/mcp/client.ts index a6a04503e..1fef52c99 100644 --- a/src/services/mcp/client.ts +++ b/src/services/mcp/client.ts @@ -7,232 +7,17 @@ import { import { existsSync, readFileSync } from 'fs' import { resolve } from 'path' import { getCwd } from '@utils/state' -import { Client } from '@modelcontextprotocol/sdk/client/index.js' -import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' -import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' -import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' -import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js' -import { memoize, pickBy } from 'lodash-es' -import { logMCPError } from '@utils/log' -import { PRODUCT_COMMAND } from '@constants/product' +import { pickBy } from 'lodash-es' import { parseJsonOrJsonc } from './internal/jsonc' import { getMcprcServerStatus, listPluginMCPServers } from './discovery' +import { MCPClientManager } from './manager' +import type { WrappedClient } from './connection' -function getMcpServerConnectionBatchSize(): number { - const raw = process.env.MCP_SERVER_CONNECTION_BATCH_SIZE - const parsed = raw ? Number.parseInt(raw, 10) : NaN - if (Number.isFinite(parsed) && parsed > 0 && parsed <= 50) return parsed - return 3 -} - -async function connectToServer( - name: string, - serverRef: McpServerConfig, -): Promise { - type Candidate = { transport: unknown; kind: 'stdio' | 'sse' | 'http' | 'ws' } - - const ensureWebSocketGlobal = async () => { - if (typeof (globalThis as any).WebSocket === 'function') return - try { - const undici = await import('undici') - if (typeof (undici as any).WebSocket === 'function') { - ;(globalThis as any).WebSocket = (undici as any).WebSocket - } - } catch {} - } - - const candidates: Candidate[] = await (async () => { - switch (serverRef.type) { - case 'sse': { - const ref = serverRef - return [ - { - kind: 'sse', - transport: new SSEClientTransport(new URL(ref.url), { - ...(ref.headers ? { requestInit: { headers: ref.headers } } : {}), - }), - }, - { - kind: 'http', - transport: new StreamableHTTPClientTransport(new URL(ref.url), { - ...(ref.headers ? { requestInit: { headers: ref.headers } } : {}), - }), - }, - ] - } - case 'sse-ide': { - const ref = serverRef - return [ - { - kind: 'sse', - transport: new SSEClientTransport(new URL(ref.url), { - ...(ref.headers ? { requestInit: { headers: ref.headers } } : {}), - }), - }, - ] - } - case 'http': { - const ref = serverRef - return [ - { - kind: 'http', - transport: new StreamableHTTPClientTransport(new URL(ref.url), { - ...(ref.headers ? { requestInit: { headers: ref.headers } } : {}), - }), - }, - { - kind: 'sse', - transport: new SSEClientTransport(new URL(ref.url), { - ...(ref.headers ? { requestInit: { headers: ref.headers } } : {}), - }), - }, - ] - } - case 'ws': { - const ref = serverRef - await ensureWebSocketGlobal() - return [ - { - kind: 'ws', - transport: new WebSocketClientTransport(new URL(ref.url)), - }, - ] - } - case 'ws-ide': { - const ref = serverRef - - let url = ref.url - if (ref.authToken) { - try { - const parsed = new URL(url) - if (!parsed.searchParams.has('authToken')) { - parsed.searchParams.set('authToken', ref.authToken) - url = parsed.toString() - } - } catch {} - } - - await ensureWebSocketGlobal() - return [ - { - kind: 'ws', - transport: new WebSocketClientTransport(new URL(url)), - }, - ] - } - case 'stdio': - default: { - const ref = serverRef - return [ - { - kind: 'stdio', - transport: new StdioClientTransport({ - command: ref.command, - args: ref.args, - env: { - ...process.env, - ...ref.env, - } as Record, - stderr: 'pipe', - }), - }, - ] - } - } - })() - - const rawTimeout = process.env.MCP_CONNECTION_TIMEOUT_MS - const parsedTimeout = rawTimeout ? Number.parseInt(rawTimeout, 10) : NaN - const CONNECTION_TIMEOUT_MS = Number.isFinite(parsedTimeout) - ? parsedTimeout - : 30_000 - - let lastError: unknown - - for (const candidate of candidates) { - const client = new Client( - { - name: PRODUCT_COMMAND, - version: '0.1.0', - }, - { - capabilities: {}, - }, - ) - - try { - const connectPromise = client.connect(candidate.transport as any) - if (CONNECTION_TIMEOUT_MS > 0) { - const timeoutPromise = new Promise((_, reject) => { - const timeoutId = setTimeout(() => { - reject( - new Error( - `Connection to MCP server "${name}" timed out after ${CONNECTION_TIMEOUT_MS}ms`, - ), - ) - }, CONNECTION_TIMEOUT_MS) - - connectPromise.then( - () => clearTimeout(timeoutId), - () => clearTimeout(timeoutId), - ) - }) - - await Promise.race([connectPromise, timeoutPromise]) - } else { - await connectPromise - } - - if (candidate.kind === 'stdio') { - ;(candidate.transport as StdioClientTransport).stderr?.on( - 'data', - (data: Buffer) => { - const errorText = data.toString().trim() - if (errorText) { - logMCPError(name, `Server stderr: ${errorText}`) - } - }, - ) - } - - if (candidates.length > 1 && candidate !== candidates[0]) { - logMCPError( - name, - `Connected using fallback transport "${candidate.kind}". Consider setting the server type explicitly in your MCP config.`, - ) - } - - return client - } catch (error) { - lastError = error - try { - await client.close() - } catch {} - } - } - - throw lastError instanceof Error - ? lastError - : new Error(`Failed to connect to MCP server "${name}"`) -} - -type ConnectedClient = { - client: Client - capabilities?: Record | null - name: string - type: 'connected' -} -type FailedClient = { - name: string - type: 'failed' -} -export type WrappedClient = ConnectedClient | FailedClient +export type { WrappedClient } from './connection' -export const getClients = memoize(async (): Promise => { - if (process.env.CI && process.env.NODE_ENV !== 'test') { - return [] - } +const mcpClientManager = new MCPClientManager() +function getConfiguredMcpServers(): Record { const pluginServers = listPluginMCPServers() const globalServers = getGlobalConfig().mcpServers ?? {} const projectFileServers = getProjectMcpServerDefinitions().servers @@ -243,47 +28,34 @@ export const getClients = memoize(async (): Promise => { (_, name) => getMcprcServerStatus(name) === 'approved', ) - const allServers = { + return { ...pluginServers, ...globalServers, ...approvedProjectFileServers, ...projectServers, } +} - const batchSize = getMcpServerConnectionBatchSize() - const entries = Object.entries(allServers) - const results: WrappedClient[] = [] +type GetClientsFn = (() => Promise) & { + cache: { clear: () => void } +} - for (let i = 0; i < entries.length; i += batchSize) { - const batch = entries.slice(i, i + batchSize) - const batchResults = await Promise.all( - batch.map(async ([name, serverRef]) => { - try { - const client = await connectToServer( - name, - serverRef as McpServerConfig, - ) - let capabilities: Record | null = null - try { - capabilities = client.getServerCapabilities() as any - } catch { - capabilities = null - } - return { name, client, capabilities, type: 'connected' as const } - } catch (error) { - logMCPError( - name, - `Connection failed: ${error instanceof Error ? error.message : String(error)}`, - ) - return { name, type: 'failed' as const } - } - }), - ) - results.push(...batchResults) - } +export const getClients: GetClientsFn = Object.assign( + async (): Promise => { + if (process.env.CI && process.env.NODE_ENV !== 'test') { + return [] + } - return results -}) + return mcpClientManager.getClientsForServers(getConfiguredMcpServers()) + }, + { + cache: { + clear: () => { + mcpClientManager.clear() + }, + }, + }, +) function parseMcpServersFromCliConfigEntries(options: { entries: string[] @@ -373,37 +145,7 @@ export async function getClientsForCliMcpConfig(options: { ...(cliServers ?? {}), } - const batchSize = getMcpServerConnectionBatchSize() - const entriesToConnect = Object.entries(allServers) - const results: WrappedClient[] = [] - - for (let i = 0; i < entriesToConnect.length; i += batchSize) { - const batch = entriesToConnect.slice(i, i + batchSize) - const batchResults = await Promise.all( - batch.map(async ([name, serverRef]) => { - try { - const client = await connectToServer( - name, - serverRef as McpServerConfig, - ) - let capabilities: Record | null = null - try { - capabilities = client.getServerCapabilities() as any - } catch { - capabilities = null - } - return { name, client, capabilities, type: 'connected' as const } - } catch (error) { - logMCPError( - name, - `Connection failed: ${error instanceof Error ? error.message : String(error)}`, - ) - return { name, type: 'failed' as const } - } - }), - ) - results.push(...batchResults) - } - - return results + return mcpClientManager.getClientsForServers(allServers, { + closeMissing: false, + }) } diff --git a/src/services/mcp/connection.ts b/src/services/mcp/connection.ts new file mode 100644 index 000000000..f13722056 --- /dev/null +++ b/src/services/mcp/connection.ts @@ -0,0 +1,326 @@ +import type { McpServerConfig } from '@utils/config' +import { PRODUCT_COMMAND } from '@constants/product' +import { logMCPError } from '@utils/log' +import { debug } from '@utils/log/debugLogger' +import { Client } from '@modelcontextprotocol/sdk/client/index.js' +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js' +import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' +import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js' + +export type ConnectedClient = { + client: Client + capabilities?: Record | null + name: string + type: 'connected' +} + +export type FailedClient = { + name: string + type: 'failed' +} + +export type WrappedClient = ConnectedClient | FailedClient + +type TransportKind = 'stdio' | 'sse' | 'http' | 'ws' + +type TransportCandidate = { + kind: TransportKind + transport: unknown +} + +export function getMcpServerConnectionBatchSize(): number { + const raw = process.env.MCP_SERVER_CONNECTION_BATCH_SIZE + const parsed = raw ? Number.parseInt(raw, 10) : NaN + if (Number.isFinite(parsed) && parsed > 0 && parsed <= 50) return parsed + return 3 +} + +export function getMcpConnectionTimeoutMs(): number { + const rawTimeout = process.env.MCP_CONNECTION_TIMEOUT_MS + const parsedTimeout = rawTimeout ? Number.parseInt(rawTimeout, 10) : NaN + return Number.isFinite(parsedTimeout) ? parsedTimeout : 30_000 +} + +async function ensureWebSocketGlobal(): Promise { + if (typeof (globalThis as any).WebSocket === 'function') return + try { + const undici = await import('undici') + if (typeof (undici as any).WebSocket === 'function') { + ;(globalThis as any).WebSocket = (undici as any).WebSocket + } + } catch {} +} + +export async function createMcpTransportCandidates( + serverRef: McpServerConfig, +): Promise { + switch (serverRef.type) { + case 'sse': + return [ + { + kind: 'sse', + transport: new SSEClientTransport(new URL(serverRef.url), { + ...(serverRef.headers + ? { requestInit: { headers: serverRef.headers } } + : {}), + }), + }, + { + kind: 'http', + transport: new StreamableHTTPClientTransport(new URL(serverRef.url), { + ...(serverRef.headers + ? { requestInit: { headers: serverRef.headers } } + : {}), + }), + }, + ] + case 'sse-ide': + return [ + { + kind: 'sse', + transport: new SSEClientTransport(new URL(serverRef.url), { + ...(serverRef.headers + ? { requestInit: { headers: serverRef.headers } } + : {}), + }), + }, + ] + case 'http': + return [ + { + kind: 'http', + transport: new StreamableHTTPClientTransport(new URL(serverRef.url), { + ...(serverRef.headers + ? { requestInit: { headers: serverRef.headers } } + : {}), + }), + }, + { + kind: 'sse', + transport: new SSEClientTransport(new URL(serverRef.url), { + ...(serverRef.headers + ? { requestInit: { headers: serverRef.headers } } + : {}), + }), + }, + ] + case 'ws': + await ensureWebSocketGlobal() + return [ + { + kind: 'ws', + transport: new WebSocketClientTransport(new URL(serverRef.url)), + }, + ] + case 'ws-ide': { + let url = serverRef.url + if (serverRef.authToken) { + try { + const parsed = new URL(url) + if (!parsed.searchParams.has('authToken')) { + parsed.searchParams.set('authToken', serverRef.authToken) + url = parsed.toString() + } + } catch {} + } + + await ensureWebSocketGlobal() + return [ + { + kind: 'ws', + transport: new WebSocketClientTransport(new URL(url)), + }, + ] + } + case 'stdio': + default: + return [ + { + kind: 'stdio', + transport: new StdioClientTransport({ + command: serverRef.command, + args: serverRef.args, + env: { + ...process.env, + ...serverRef.env, + } as Record, + stderr: 'pipe', + }), + }, + ] + } +} + +async function connectWithTimeout( + client: Client, + transport: unknown, + serverName: string, + timeoutMs: number, +): Promise { + const connectPromise = client.connect(transport as any) + if (timeoutMs <= 0) { + await connectPromise + return + } + + const timeoutPromise = new Promise((_, reject) => { + const timeoutId = setTimeout(() => { + reject( + new Error( + `Connection to MCP server "${serverName}" timed out after ${timeoutMs}ms`, + ), + ) + }, timeoutMs) + + connectPromise.then( + () => clearTimeout(timeoutId), + () => clearTimeout(timeoutId), + ) + }) + + await Promise.race([connectPromise, timeoutPromise]) +} + +export function captureMcpCapabilities( + client: Client, +): Record | null { + try { + return client.getServerCapabilities() as any + } catch { + return null + } +} + +export async function connectMcpClient( + name: string, + serverRef: McpServerConfig, + options?: { clientVersion?: string }, +): Promise { + const candidates = await createMcpTransportCandidates(serverRef) + const timeoutMs = getMcpConnectionTimeoutMs() + const startedAt = Date.now() + let lastError: unknown + + debug.info('MCP_CONNECT_START', { + server: name, + type: serverRef.type ?? 'stdio', + candidates: candidates.map(candidate => candidate.kind), + timeoutMs, + }) + + for (const candidate of candidates) { + const client = new Client( + { + name: PRODUCT_COMMAND, + version: options?.clientVersion ?? '0.1.0', + }, + { + capabilities: {}, + }, + ) + + try { + if (candidate.kind === 'stdio') { + ;(candidate.transport as StdioClientTransport).stderr?.on( + 'data', + (data: Buffer) => { + const errorText = data.toString().trim() + if (errorText) { + logMCPError(name, `Server stderr: ${errorText}`) + } + }, + ) + } + + await connectWithTimeout(client, candidate.transport, name, timeoutMs) + + if (candidates.length > 1 && candidate !== candidates[0]) { + logMCPError( + name, + `Connected using fallback transport "${candidate.kind}". Consider setting the server type explicitly in your MCP config.`, + ) + } + + debug.info('MCP_CONNECT_SUCCESS', { + server: name, + transport: candidate.kind, + durationMs: Date.now() - startedAt, + }) + + return client + } catch (error) { + lastError = error + try { + await client.close() + } catch {} + } + } + + debug.warn('MCP_CONNECT_FAILED', { + server: name, + durationMs: Date.now() - startedAt, + error: lastError instanceof Error ? lastError.message : String(lastError), + }) + + throw lastError instanceof Error + ? lastError + : new Error(`Failed to connect to MCP server "${name}"`) +} + +export async function connectMcpServer( + name: string, + serverRef: McpServerConfig, + options?: { clientVersion?: string }, +): Promise { + try { + const client = await connectMcpClient(name, serverRef, options) + return { + name, + client, + capabilities: captureMcpCapabilities(client), + type: 'connected', + } + } catch (error) { + logMCPError( + name, + `Connection failed: ${error instanceof Error ? error.message : String(error)}`, + ) + return { name, type: 'failed' } + } +} + +export async function connectMcpServers( + servers: Record, + options?: { clientVersion?: string }, +): Promise { + const batchSize = getMcpServerConnectionBatchSize() + const entries = Object.entries(servers) + const results: WrappedClient[] = [] + + for (let i = 0; i < entries.length; i += batchSize) { + const batch = entries.slice(i, i + batchSize) + const startedAt = Date.now() + debug.info('MCP_CONNECT_BATCH_START', { + offset: i, + size: batch.length, + total: entries.length, + }) + const batchResults = await Promise.all( + batch.map(([name, serverRef]) => + connectMcpServer(name, serverRef, options), + ), + ) + debug.info('MCP_CONNECT_BATCH_DONE', { + offset: i, + size: batch.length, + durationMs: Date.now() - startedAt, + connected: batchResults.filter(result => result.type === 'connected') + .length, + failed: batchResults.filter(result => result.type === 'failed').length, + }) + results.push(...batchResults) + } + + return results +} diff --git a/src/services/mcp/manager.ts b/src/services/mcp/manager.ts new file mode 100644 index 000000000..59498c53c --- /dev/null +++ b/src/services/mcp/manager.ts @@ -0,0 +1,171 @@ +import type { McpServerConfig } from '@utils/config' +import { debug } from '@utils/log/debugLogger' +import { + captureMcpCapabilities, + connectMcpServer, + getMcpConnectionTimeoutMs, + getMcpServerConnectionBatchSize, + type WrappedClient, +} from './connection' + +type ManagedClientEntry = { + configKey: string + serverRef: McpServerConfig + wrapped: WrappedClient + lastConnectAttemptAt: number + lastHealthCheckAt: number +} + +const HEALTH_CHECK_INTERVAL_MS = 5_000 +const FAILED_RETRY_INTERVAL_MS = 30_000 + +function stableStringify(value: unknown): string { + if (value === undefined) return 'undefined' + if (value === null || typeof value !== 'object') return JSON.stringify(value) + if (Array.isArray(value)) { + return `[${value.map(item => stableStringify(item)).join(',')}]` + } + + const entries = Object.entries(value as Record).sort( + ([a], [b]) => a.localeCompare(b), + ) + return `{${entries + .map(([key, item]) => `${JSON.stringify(key)}:${stableStringify(item)}`) + .join(',')}}` +} + +function serverConfigKey(name: string, serverRef: McpServerConfig): string { + return `${name}:${stableStringify(serverRef)}` +} + +async function closeWrappedClient(client: WrappedClient): Promise { + if (client.type !== 'connected') return + try { + await client.client.close() + } catch {} +} + +async function pingWrappedClient(client: WrappedClient): Promise { + if (client.type !== 'connected') return false + + const configuredTimeoutMs = getMcpConnectionTimeoutMs() + const timeoutMs = + configuredTimeoutMs > 0 ? Math.min(configuredTimeoutMs, 5_000) : 5_000 + + try { + await client.client.ping({ timeout: timeoutMs }) + client.capabilities = captureMcpCapabilities(client.client) + return true + } catch { + return false + } +} + +export class MCPClientManager { + private readonly clients = new Map() + + async getClientsForServers( + servers: Record, + options?: { clientVersion?: string; closeMissing?: boolean }, + ): Promise { + const entries = Object.entries(servers) + const activeNames = new Set(entries.map(([name]) => name)) + + if (options?.closeMissing !== false) { + for (const [name, entry] of this.clients.entries()) { + if (activeNames.has(name)) continue + this.clients.delete(name) + void closeWrappedClient(entry.wrapped) + } + } + + const batchSize = getMcpServerConnectionBatchSize() + const results: WrappedClient[] = [] + + for (let i = 0; i < entries.length; i += batchSize) { + const batch = entries.slice(i, i + batchSize) + const startedAt = Date.now() + + debug.info('MCP_MANAGER_BATCH_START', { + offset: i, + size: batch.length, + total: entries.length, + }) + + const batchResults = await Promise.all( + batch.map(([name, serverRef]) => + this.getClientForServer(name, serverRef, options), + ), + ) + + debug.info('MCP_MANAGER_BATCH_DONE', { + offset: i, + size: batch.length, + durationMs: Date.now() - startedAt, + connected: batchResults.filter(result => result.type === 'connected') + .length, + failed: batchResults.filter(result => result.type === 'failed').length, + }) + + results.push(...batchResults) + } + + return results + } + + clear(): void { + for (const entry of this.clients.values()) { + void closeWrappedClient(entry.wrapped) + } + this.clients.clear() + } + + private async getClientForServer( + name: string, + serverRef: McpServerConfig, + options?: { clientVersion?: string }, + ): Promise { + const now = Date.now() + const configKey = serverConfigKey(name, serverRef) + const existing = this.clients.get(name) + + if (existing && existing.configKey !== configKey) { + this.clients.delete(name) + void closeWrappedClient(existing.wrapped) + } else if (existing) { + if (existing.wrapped.type === 'connected') { + if (now - existing.lastHealthCheckAt < HEALTH_CHECK_INTERVAL_MS) { + return existing.wrapped + } + + existing.lastHealthCheckAt = now + const healthy = await pingWrappedClient(existing.wrapped) + if (healthy) return existing.wrapped + + debug.warn('MCP_MANAGER_RECONNECT_AFTER_PING_FAILED', { + server: name, + }) + this.clients.delete(name) + void closeWrappedClient(existing.wrapped) + } else if ( + now - existing.lastConnectAttemptAt < + FAILED_RETRY_INTERVAL_MS + ) { + return existing.wrapped + } + } + + const wrapped = await connectMcpServer(name, serverRef, { + clientVersion: options?.clientVersion, + }) + this.clients.set(name, { + configKey, + serverRef, + wrapped, + lastConnectAttemptAt: now, + lastHealthCheckAt: now, + }) + + return wrapped + } +} diff --git a/src/services/mcp/tools-integration.ts b/src/services/mcp/tools-integration.ts index 1d7e11010..ad6c2edcf 100644 --- a/src/services/mcp/tools-integration.ts +++ b/src/services/mcp/tools-integration.ts @@ -1,15 +1,14 @@ -import { zipObject, memoize } from 'lodash-es' +import { zipObject } from 'lodash-es' import type { Tool } from '@tool' import { MCPTool } from '@tools/mcp/MCPTool/MCPTool' import { logMCPError } from '@utils/log' +import { debug } from '@utils/log/debugLogger' import type { Command } from '@commands' -import type { - ImageBlockParam, - MessageParam, - ToolResultBlockParam, -} from '@anthropic-ai/sdk/resources/index.mjs' +import type { MessageParam, ToolResultBlockParam } from '@anthropic-ai/sdk/resources/index.mjs' import { CallToolResultSchema, + PromptListChangedNotificationSchema, + ToolListChangedNotificationSchema, type ClientRequest, type ListPromptsResult, ListPromptsResultSchema, @@ -18,10 +17,24 @@ import { type Result, ResultSchema, } from '@modelcontextprotocol/sdk/types.js' +import { normalizeImageMediaType } from '@utils/ai/anthropic' import { getClients, type WrappedClient } from './client' type ConnectedClient = Extract +type CachedDiscovery = { + client: ConnectedClient + promise: Promise +} + +const toolDiscoveryCache = new Map>() +const promptDiscoveryCache = new Map< + string, + CachedDiscovery +>() +const toolNotificationClients = new WeakSet() +const promptNotificationClients = new WeakSet() + function sanitizeMcpIdentifierPart(value: string): string { return value.replace(/[^a-zA-Z0-9_-]/g, '_') } @@ -54,6 +67,7 @@ function mergeAbortSignals( if (active.length === 1) return { signal: active[0]!, cleanup: () => {} } const controller = new AbortController() + const listeners: Array<{ signal: AbortSignal; abort: () => void }> = [] const abort = () => { try { @@ -61,15 +75,26 @@ function mergeAbortSignals( } catch {} } - for (const s of active) { - if (s.aborted) { + for (const signal of active) { + if (signal.aborted) { abort() + for (const listener of listeners) { + listener.signal.removeEventListener('abort', listener.abort) + } return { signal: controller.signal, cleanup: () => {} } } - s.addEventListener('abort', abort, { once: true }) + signal.addEventListener('abort', abort, { once: true }) + listeners.push({ signal, abort }) } - return { signal: controller.signal, cleanup: () => {} } + return { + signal: controller.signal, + cleanup: () => { + for (const listener of listeners) { + listener.signal.removeEventListener('abort', listener.abort) + } + }, + } } const IDE_MCP_TOOL_ALLOWLIST = new Set([ @@ -77,148 +102,234 @@ const IDE_MCP_TOOL_ALLOWLIST = new Set([ 'mcp__ide__getDiagnostics', ]) -async function requestAll< +function getServerCapabilities( + client: ConnectedClient, +): Record | null { + let capabilities: Record | null = client.capabilities ?? null + if (capabilities) return capabilities + + try { + capabilities = client.client.getServerCapabilities() as any + } catch { + capabilities = null + } + client.capabilities = capabilities + return capabilities +} + +function hasCapability( + client: ConnectedClient, + requiredCapability: string, +): boolean { + const capabilities = getServerCapabilities(client) + return Boolean((capabilities as any)?.[requiredCapability]) +} + +function supportsListChanged( + client: ConnectedClient, + capability: 'tools' | 'prompts', +): boolean { + const capabilities = getServerCapabilities(client) + return Boolean((capabilities as any)?.[capability]?.listChanged) +} + +function registerToolListChangedHandler(client: ConnectedClient): void { + if (!supportsListChanged(client, 'tools')) return + if (toolNotificationClients.has(client.client)) return + + client.client.setNotificationHandler( + ToolListChangedNotificationSchema, + async () => { + toolDiscoveryCache.delete(client.name) + debug.info('MCP_TOOLS_CACHE_INVALIDATED', { server: client.name }) + }, + ) + toolNotificationClients.add(client.client) +} + +function registerPromptListChangedHandler(client: ConnectedClient): void { + if (!supportsListChanged(client, 'prompts')) return + if (promptNotificationClients.has(client.client)) return + + client.client.setNotificationHandler( + PromptListChangedNotificationSchema, + async () => { + promptDiscoveryCache.delete(client.name) + debug.info('MCP_PROMPTS_CACHE_INVALIDATED', { server: client.name }) + }, + ) + promptNotificationClients.add(client.client) +} + +async function requestFromClient< ResultT extends Result, ResultSchemaT extends typeof ResultSchema, >( + client: ConnectedClient, req: ClientRequest, resultSchema: ResultSchemaT, - requiredCapability: string, -): Promise<{ client: ConnectedClient; result: ResultT }[]> { +): Promise { const timeoutMs = getMcpToolTimeoutMs() - const clients = await getClients() - const results = await Promise.allSettled( - clients.map(async client => { - if (client.type === 'failed') return null - - let timeoutSignal: TimeoutSignal | null = null + let timeoutSignal: TimeoutSignal | null = null + let merged: { signal: AbortSignal; cleanup: () => void } | null = null + const startedAt = Date.now() - try { - let capabilities: Record | null = - client.capabilities ?? null + debug.info('MCP_DISCOVERY_REFRESH_START', { + server: client.name, + method: req.method, + }) - if (!capabilities) { - try { - capabilities = client.client.getServerCapabilities() as any - } catch { - capabilities = null - } - client.capabilities = capabilities - } + try { + timeoutSignal = timeoutMs ? createTimeoutSignal(timeoutMs) : null + merged = mergeAbortSignals([timeoutSignal?.signal]) - if (!(capabilities as any)?.[requiredCapability]) { - return null - } + const result = (await client.client.request( + req, + resultSchema, + merged?.signal ? ({ signal: merged.signal } as any) : undefined, + )) as ResultT - timeoutSignal = timeoutMs ? createTimeoutSignal(timeoutMs) : null - const merged = mergeAbortSignals([timeoutSignal?.signal]) + debug.info('MCP_DISCOVERY_REFRESH_DONE', { + server: client.name, + method: req.method, + durationMs: Date.now() - startedAt, + }) - return { - client, - result: (await client.client.request( - req, - resultSchema, - merged?.signal ? ({ signal: merged.signal } as any) : undefined, - )) as ResultT, - } - } catch (error) { - if (client.type === 'connected') { - logMCPError( - client.name, - `Failed to request '${req.method}': ${error instanceof Error ? error.message : String(error)}`, - ) - } - return null - } finally { - timeoutSignal?.cleanup() - } - }), - ) - return results - .filter( - ( - result, - ): result is PromiseFulfilledResult<{ - client: ConnectedClient - result: ResultT - } | null> => result.status === 'fulfilled', - ) - .map(result => result.value) - .filter( - (result): result is { client: ConnectedClient; result: ResultT } => - result !== null, + return result + } catch (error) { + logMCPError( + client.name, + `Failed to request '${req.method}': ${error instanceof Error ? error.message : String(error)}`, ) + return null + } finally { + merged?.cleanup() + timeoutSignal?.cleanup() + } } -export const getMCPTools = memoize(async (): Promise => { - const toolsList = await requestAll< - ListToolsResult, - typeof ListToolsResultSchema - >( - { - method: 'tools/list', - }, - ListToolsResultSchema, - 'tools', +async function getCachedServerDiscovery< + ResultT extends Result, + ResultSchemaT extends typeof ResultSchema, +>(options: { + client: ConnectedClient + cache: Map> + req: ClientRequest + resultSchema: ResultSchemaT + requiredCapability: 'tools' | 'prompts' +}): Promise { + const { client, cache, req, resultSchema, requiredCapability } = options + + if (!hasCapability(client, requiredCapability)) return null + if (requiredCapability === 'tools') registerToolListChangedHandler(client) + else registerPromptListChangedHandler(client) + + const cached = cache.get(client.name) + if (cached && cached.client.client === client.client) return cached.promise + + const promise = requestFromClient( + client, + req, + resultSchema, + ).then(result => { + if (result === null) cache.delete(client.name) + return result + }) + cache.set(client.name, { client, promise }) + return promise +} + +async function getConnectedClients(): Promise { + const clients = await getClients() + return clients.filter( + (client): client is ConnectedClient => client.type === 'connected', ) +} - return toolsList.flatMap(({ client, result: { tools } }) => { - const serverPart = sanitizeMcpIdentifierPart(client.name) +type CacheClearableFn = (() => Promise) & { cache: { clear: () => void } } - return tools - .map((tool): Tool | null => { - const toolPart = sanitizeMcpIdentifierPart(tool.name) - const name = `mcp__${serverPart}__${toolPart}` +export const getMCPTools: CacheClearableFn = Object.assign( + async (): Promise => { + const clients = await getConnectedClients() + const toolsList = await Promise.all( + clients.map(async client => ({ + client, + result: await getCachedServerDiscovery({ + client, + cache: toolDiscoveryCache, + req: { method: 'tools/list' }, + resultSchema: ListToolsResultSchema, + requiredCapability: 'tools', + }), + })), + ) - if ( - name.startsWith('mcp__ide__') && - !IDE_MCP_TOOL_ALLOWLIST.has(name) - ) { - return null - } + return toolsList.flatMap(({ client, result }) => { + if (!result) return [] + const serverPart = sanitizeMcpIdentifierPart(client.name) - return { - ...MCPTool, - name, - isConcurrencySafe() { - return tool.annotations?.readOnlyHint ?? false - }, - isReadOnly() { - return tool.annotations?.readOnlyHint ?? false - }, - async description() { - return tool.description ?? '' - }, - async prompt() { - return tool.description ?? '' - }, - inputJSONSchema: tool.inputSchema as Tool['inputJSONSchema'], - async validateInput() { - return { result: true } - }, - async *call(args: Record, context) { - const data = await callMCPTool({ - client, - tool: tool.name, - args, - toolUseId: context.toolUseId, - signal: context.abortController.signal, - }) - yield { - type: 'result' as const, - data, - resultForAssistant: data, - } - }, - userFacingName() { - const title = tool.annotations?.title || tool.name - return `${client.name} - ${title} (MCP)` - }, - } - }) - .filter((tool): tool is Tool => tool !== null) - }) -}) + return result.tools + .map((tool): Tool | null => { + const toolPart = sanitizeMcpIdentifierPart(tool.name) + const name = `mcp__${serverPart}__${toolPart}` + + if ( + name.startsWith('mcp__ide__') && + !IDE_MCP_TOOL_ALLOWLIST.has(name) + ) { + return null + } + + return { + ...MCPTool, + name, + isConcurrencySafe() { + return tool.annotations?.readOnlyHint ?? false + }, + isReadOnly() { + return tool.annotations?.readOnlyHint ?? false + }, + async description() { + return tool.description ?? '' + }, + async prompt() { + return tool.description ?? '' + }, + inputJSONSchema: tool.inputSchema as Tool['inputJSONSchema'], + async validateInput() { + return { result: true } + }, + async *call(args: Record, context) { + const data = await callMCPTool({ + client, + tool: tool.name, + args, + toolUseId: context.toolUseId, + signal: context.abortController.signal, + }) + yield { + type: 'result' as const, + data, + resultForAssistant: data, + } + }, + userFacingName() { + const title = tool.annotations?.title || tool.name + return `${client.name} - ${title} (MCP)` + }, + } + }) + .filter((tool): tool is Tool => tool !== null) + }) + }, + { + cache: { + clear: () => { + toolDiscoveryCache.clear() + }, + }, + }, +) async function callMCPTool({ client: { client, name }, @@ -290,60 +401,77 @@ async function callMCPTool({ source: { type: 'base64', data: String(item.data), - media_type: item.mimeType as ImageBlockParam.Source['media_type'], + media_type: normalizeImageMediaType(item.mimeType), }, } } return item - }) + }) as ToolResultBlockParam['content'] } throw Error(`Unexpected response format from tool ${tool}`) } finally { + merged?.cleanup() timeoutSignal?.cleanup() } } -export const getMCPCommands = memoize(async (): Promise => { - const results = await requestAll< - ListPromptsResult, - typeof ListPromptsResultSchema - >( - { - method: 'prompts/list', - }, - ListPromptsResultSchema, - 'prompts', - ) +export const getMCPCommands: CacheClearableFn = Object.assign( + async (): Promise => { + const clients = await getConnectedClients() + const results = await Promise.all( + clients.map(async client => ({ + client, + result: await getCachedServerDiscovery({ + client, + cache: promptDiscoveryCache, + req: { method: 'prompts/list' }, + resultSchema: ListPromptsResultSchema, + requiredCapability: 'prompts', + }), + })), + ) - return results.flatMap(({ client, result }) => - result.prompts?.map(_ => { - const serverPart = sanitizeMcpIdentifierPart(client.name) - const argNames = Object.values(_.arguments ?? {}).map(k => k.name) - return { - type: 'prompt', - name: `mcp__${serverPart}__${_.name}`, - description: _.description ?? '', - isEnabled: true, - isHidden: false, - progressMessage: 'running', - userFacingName() { - const title = - typeof (_ as any).title === 'string' ? (_ as any).title : _.name - return `${client.name}:${title} (MCP)` - }, - argNames, - async getPromptForCommand(args: string) { - const argsArray = args.split(' ') - return await runCommand( - { name: _.name, client }, - zipObject(argNames, argsArray), - ) - }, - } satisfies Command - }), - ) -}) + return results.flatMap(({ client, result }) => { + if (!result) return [] + + return result.prompts?.map(prompt => { + const serverPart = sanitizeMcpIdentifierPart(client.name) + const argNames = Object.values(prompt.arguments ?? {}).map(k => k.name) + return { + type: 'prompt', + name: `mcp__${serverPart}__${prompt.name}`, + description: prompt.description ?? '', + isEnabled: true, + isHidden: false, + progressMessage: 'running', + userFacingName() { + const title = + typeof (prompt as any).title === 'string' + ? (prompt as any).title + : prompt.name + return `${client.name}:${title} (MCP)` + }, + argNames, + async getPromptForCommand(args: string) { + const argsArray = args.split(' ') + return await runCommand( + { name: prompt.name, client }, + zipObject(argNames, argsArray), + ) + }, + } satisfies Command + }) + }) + }, + { + cache: { + clear: () => { + promptDiscoveryCache.clear() + }, + }, + }, +) export async function runCommand( { name, client }: { name: string; client: ConnectedClient }, @@ -372,8 +500,7 @@ export async function runCommand( type: 'image', source: { data: String((content as any).data), - media_type: (content as any) - .mimeType as ImageBlockParam.Source['media_type'], + media_type: normalizeImageMediaType((content as any).mimeType), type: 'base64', }, }, diff --git a/src/tools/mcp/ListMcpResourcesTool/ListMcpResourcesTool.tsx b/src/tools/mcp/ListMcpResourcesTool/ListMcpResourcesTool.tsx index efc27c5e2..4a71482ba 100644 --- a/src/tools/mcp/ListMcpResourcesTool/ListMcpResourcesTool.tsx +++ b/src/tools/mcp/ListMcpResourcesTool/ListMcpResourcesTool.tsx @@ -5,6 +5,7 @@ import { Cost } from '@components/Cost' import { FallbackToolUseRejectedMessage } from '@components/FallbackToolUseRejectedMessage' import type { Tool, ToolUseContext } from '@tool' import { getClients } from '@services/mcpClient' +import { logMCPError } from '@utils/log' import { ListResourcesResultSchema } from '@modelcontextprotocol/sdk/types.js' import { DESCRIPTION, PROMPT, TOOL_NAME } from './prompt' @@ -26,6 +27,7 @@ type OutputItem = { } type Output = OutputItem[] +type ListedResource = Omit export const ListMcpResourcesTool = { name: TOOL_NAME, @@ -118,12 +120,19 @@ export const ListMcpResourcesTool = { ) if (!result.resources) continue resources.push( - ...result.resources.map(r => ({ - ...r, - server: wrapped.name, - })), + ...(result.resources as ListedResource[]).map( + (r: ListedResource) => ({ + ...r, + server: wrapped.name, + }), + ), ) - } catch {} + } catch (error) { + logMCPError( + wrapped.name, + `Failed to list resources: ${error instanceof Error ? error.message : String(error)}`, + ) + } } yield { diff --git a/src/utils/ai/anthropic.ts b/src/utils/ai/anthropic.ts new file mode 100644 index 000000000..63c888944 --- /dev/null +++ b/src/utils/ai/anthropic.ts @@ -0,0 +1,144 @@ +import type { + Base64ImageSource, + ContentBlock, + ContentBlockParam, + TextBlock, + TextBlockParam, + ToolUseBlockParam, + Usage, +} from '@anthropic-ai/sdk/resources/index.mjs' + +export type AnthropicImageMediaType = Base64ImageSource['media_type'] + +export type AnthropicUsage = Usage & { + prompt_tokens?: number + completion_tokens?: number + promptTokens?: number + completionTokens?: number + totalTokens?: number + reasoningTokens?: number +} + +export type ToolUseLikeBlockParam = Omit & { + type: 'tool_use' | 'server_tool_use' | 'mcp_tool_use' +} + +export function createAnthropicUsage( + overrides: Partial = {}, +): AnthropicUsage { + return { + cache_creation: null, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + inference_geo: null, + input_tokens: 0, + output_tokens: 0, + output_tokens_details: null, + server_tool_use: null, + service_tier: null, + ...overrides, + } +} + +export function normalizeAnthropicUsage(usage?: unknown): AnthropicUsage { + if (!usage || typeof usage !== 'object') { + return createAnthropicUsage() + } + + const source = usage as Record + const inputTokens = numberValue( + source.input_tokens, + source.prompt_tokens, + source.inputTokens, + ) + const outputTokens = numberValue( + source.output_tokens, + source.completion_tokens, + source.outputTokens, + ) + const cacheReadInputTokens = numberValue( + source.cache_read_input_tokens, + objectValue(source.prompt_token_details)?.cached_tokens, + source.cacheReadInputTokens, + ) + const cacheCreationInputTokens = numberValue( + source.cache_creation_input_tokens, + source.cacheCreatedInputTokens, + ) + + return createAnthropicUsage({ + ...(source as Partial), + input_tokens: inputTokens, + output_tokens: outputTokens, + cache_read_input_tokens: cacheReadInputTokens, + cache_creation_input_tokens: cacheCreationInputTokens, + }) +} + +export function isTextBlock( + block: unknown, +): block is TextBlock | TextBlockParam { + return ( + !!block && + typeof block === 'object' && + (block as { type?: unknown }).type === 'text' && + typeof (block as { text?: unknown }).text === 'string' + ) +} + +export function extractTextFromContent(content: unknown): string | null { + if (typeof content === 'string') { + return content + } + if (!Array.isArray(content)) { + return null + } + const textBlock = content.find(isTextBlock) + return textBlock?.text ?? null +} + +export function isToolUseLikeBlockParam( + block: unknown, +): block is ToolUseLikeBlockParam { + return ( + !!block && + typeof block === 'object' && + ((block as { type?: unknown }).type === 'tool_use' || + (block as { type?: unknown }).type === 'server_tool_use' || + (block as { type?: unknown }).type === 'mcp_tool_use') + ) +} + +export function normalizeImageMediaType( + mimeType: unknown, +): AnthropicImageMediaType { + switch (mimeType) { + case 'image/jpeg': + case 'image/png': + case 'image/gif': + case 'image/webp': + return mimeType + default: + return 'image/png' + } +} + +export type AnthropicContentBlockLike = + | ContentBlock + | ContentBlockParam + | ToolUseLikeBlockParam + +function objectValue(value: unknown): Record | null { + return value && typeof value === 'object' + ? (value as Record) + : null +} + +function numberValue(...values: unknown[]): number { + for (const value of values) { + if (typeof value === 'number' && Number.isFinite(value)) { + return value + } + } + return 0 +} diff --git a/tests/unit/mcp-connection-internals.test.ts b/tests/unit/mcp-connection-internals.test.ts new file mode 100644 index 000000000..863662976 --- /dev/null +++ b/tests/unit/mcp-connection-internals.test.ts @@ -0,0 +1,77 @@ +import { afterEach, describe, expect, test } from 'bun:test' +import { + createMcpTransportCandidates, + getMcpConnectionTimeoutMs, + getMcpServerConnectionBatchSize, +} from '../../src/services/mcp/connection' +import { getClients, getMCPCommands, getMCPTools } from '../../src/services/mcp' + +describe('MCP connection internals', () => { + const originalBatchSize = process.env.MCP_SERVER_CONNECTION_BATCH_SIZE + const originalTimeout = process.env.MCP_CONNECTION_TIMEOUT_MS + + afterEach(() => { + if (originalBatchSize === undefined) + delete process.env.MCP_SERVER_CONNECTION_BATCH_SIZE + else process.env.MCP_SERVER_CONNECTION_BATCH_SIZE = originalBatchSize + + if (originalTimeout === undefined) + delete process.env.MCP_CONNECTION_TIMEOUT_MS + else process.env.MCP_CONNECTION_TIMEOUT_MS = originalTimeout + }) + + test('preserves transport fallback ordering for HTTP and SSE configs', async () => { + const sseCandidates = await createMcpTransportCandidates({ + type: 'sse', + url: 'http://127.0.0.1:3999/mcp', + headers: { Authorization: 'Bearer token' }, + }) + expect(sseCandidates.map(candidate => candidate.kind)).toEqual([ + 'sse', + 'http', + ]) + + const httpCandidates = await createMcpTransportCandidates({ + type: 'http', + url: 'http://127.0.0.1:3999/mcp', + headers: { Authorization: 'Bearer token' }, + }) + expect(httpCandidates.map(candidate => candidate.kind)).toEqual([ + 'http', + 'sse', + ]) + }) + + test('uses stdio as a single transport candidate by default', async () => { + const candidates = await createMcpTransportCandidates({ + command: process.execPath, + args: ['--version'], + env: { TEST_ENV: '1' }, + }) + + expect(candidates.map(candidate => candidate.kind)).toEqual(['stdio']) + }) + + test('parses connection env vars without changing defaults', () => { + delete process.env.MCP_SERVER_CONNECTION_BATCH_SIZE + delete process.env.MCP_CONNECTION_TIMEOUT_MS + expect(getMcpServerConnectionBatchSize()).toBe(3) + expect(getMcpConnectionTimeoutMs()).toBe(30_000) + + process.env.MCP_SERVER_CONNECTION_BATCH_SIZE = '7' + process.env.MCP_CONNECTION_TIMEOUT_MS = '1234' + expect(getMcpServerConnectionBatchSize()).toBe(7) + expect(getMcpConnectionTimeoutMs()).toBe(1234) + + process.env.MCP_SERVER_CONNECTION_BATCH_SIZE = '0' + process.env.MCP_CONNECTION_TIMEOUT_MS = 'not-a-number' + expect(getMcpServerConnectionBatchSize()).toBe(3) + expect(getMcpConnectionTimeoutMs()).toBe(30_000) + }) + + test('preserves cache.clear compatibility shims on public getters', () => { + expect(typeof (getClients as any).cache?.clear).toBe('function') + expect(typeof (getMCPTools as any).cache?.clear).toBe('function') + expect(typeof (getMCPCommands as any).cache?.clear).toBe('function') + }) +}) diff --git a/tests/unit/mcp-content-normalization.test.ts b/tests/unit/mcp-content-normalization.test.ts new file mode 100644 index 000000000..9874aac8f --- /dev/null +++ b/tests/unit/mcp-content-normalization.test.ts @@ -0,0 +1,108 @@ +import { describe, expect, test } from 'bun:test' +import { runCommand } from '@services/mcp/tools-integration' + +describe('MCP content normalization', () => { + test('runCommand converts MCP image prompt content to Anthropic image blocks', async () => { + const client: any = { + name: 'fixture', + client: { + async getPrompt() { + return { + messages: [ + { + role: 'user', + content: { + type: 'image', + data: 'abc123', + mimeType: 'image/jpeg', + }, + }, + ], + } + }, + }, + } + + const messages = await runCommand( + { name: 'screenshot', client }, + {}, + ) + + expect(messages).toEqual([ + { + role: 'user', + content: [ + { + type: 'image', + source: { + type: 'base64', + data: 'abc123', + media_type: 'image/jpeg', + }, + }, + ], + }, + ]) + }) + + test('runCommand falls back to png when MCP image mime type is absent', async () => { + const client: any = { + name: 'fixture', + client: { + async getPrompt() { + return { + messages: [ + { + role: 'assistant', + content: { + type: 'image', + data: 'abc123', + }, + }, + ], + } + }, + }, + } + + const messages = await runCommand( + { name: 'screenshot', client }, + {}, + ) + + expect((messages[0]!.content as any[])[0].source.media_type).toBe( + 'image/png', + ) + }) + + test('runCommand falls back to png when MCP image mime type is unsupported', async () => { + const client: any = { + name: 'fixture', + client: { + async getPrompt() { + return { + messages: [ + { + role: 'user', + content: { + type: 'image', + data: 'abc123', + mimeType: 'application/octet-stream', + }, + }, + ], + } + }, + }, + } + + const messages = await runCommand( + { name: 'screenshot', client }, + {}, + ) + + expect((messages[0]!.content as any[])[0].source.media_type).toBe( + 'image/png', + ) + }) +}) diff --git a/tests/unit/mcp-manager-lifecycle.test.ts b/tests/unit/mcp-manager-lifecycle.test.ts new file mode 100644 index 000000000..7d203aafc --- /dev/null +++ b/tests/unit/mcp-manager-lifecycle.test.ts @@ -0,0 +1,194 @@ +import { beforeEach, describe, expect, mock, test } from 'bun:test' +import type { WrappedClient } from '../../src/services/mcp/connection' + +function createMockSdkClient() { + return { + ping: mock(async () => {}), + close: mock(async () => {}), + getServerCapabilities: () => null, + request: mock(async () => ({})), + setNotificationHandler: mock(() => {}), + } as any +} + +const sdkClientsByName = new Map>() + +function getOrCreateSdkClient(name: string) { + let client = sdkClientsByName.get(name) + if (!client) { + client = createMockSdkClient() + sdkClientsByName.set(name, client) + } + return client +} + +const mockConnectMcpServer = mock( + async (name: string): Promise => ({ + name, + client: getOrCreateSdkClient(name), + capabilities: null, + type: 'connected', + }), +) + +mock.module('../../src/services/mcp/connection', () => ({ + connectMcpServer: mockConnectMcpServer, + captureMcpCapabilities: () => null, + getMcpConnectionTimeoutMs: () => 5_000, + getMcpServerConnectionBatchSize: () => 3, +})) + +const { MCPClientManager } = await import('../../src/services/mcp/manager') + +const stdioServer = { command: 'echo', args: [] } as any + +function staleHealthCheck(manager: any, name: string) { + const entry = (manager as any).clients.get(name) + if (entry) entry.lastHealthCheckAt = 0 +} + +describe('MCPClientManager', () => { + beforeEach(() => { + mockConnectMcpServer.mockClear() + sdkClientsByName.clear() + }) + + test('connects to new servers', async () => { + const manager = new MCPClientManager() + + const results = await manager.getClientsForServers({ alpha: stdioServer }) + + expect(results).toHaveLength(1) + expect(results[0]!.type).toBe('connected') + expect(results[0]!.name).toBe('alpha') + expect(mockConnectMcpServer).toHaveBeenCalledTimes(1) + }) + + test('reuses connection when health check is not yet due', async () => { + const manager = new MCPClientManager() + const servers = { alpha: stdioServer } + + const first = await manager.getClientsForServers(servers) + const second = await manager.getClientsForServers(servers) + + expect(mockConnectMcpServer).toHaveBeenCalledTimes(1) + expect(second[0]!.client).toBe(first[0]!.client) + }) + + test('reconnects when ping fails', async () => { + const manager = new MCPClientManager() + const servers = { alpha: stdioServer } + + await manager.getClientsForServers(servers) + + const oldClient = sdkClientsByName.get('alpha')! + oldClient.ping = mock(async () => { + throw new Error('ping timeout') + }) + + const newClient = createMockSdkClient() + sdkClientsByName.set('alpha', newClient) + + staleHealthCheck(manager, 'alpha') + + const results = await manager.getClientsForServers(servers) + + expect(mockConnectMcpServer).toHaveBeenCalledTimes(2) + expect(oldClient.close).toHaveBeenCalledTimes(1) + expect(results[0]!.client).toBe(newClient) + }) + + test('closes removed servers by default (closeMissing=true)', async () => { + const manager = new MCPClientManager() + + await manager.getClientsForServers({ + alpha: stdioServer, + beta: stdioServer, + }) + + const betaClient = sdkClientsByName.get('beta')! + + await manager.getClientsForServers({ alpha: stdioServer }) + + expect(betaClient.close).toHaveBeenCalledTimes(1) + }) + + test('keeps removed servers when closeMissing=false', async () => { + const manager = new MCPClientManager() + + await manager.getClientsForServers({ + alpha: stdioServer, + beta: stdioServer, + }) + + const betaClient = sdkClientsByName.get('beta')! + + await manager.getClientsForServers( + { alpha: stdioServer }, + { closeMissing: false }, + ) + + expect(betaClient.close).not.toHaveBeenCalled() + }) + + test('reconnects when server config changes', async () => { + const manager = new MCPClientManager() + + await manager.getClientsForServers({ + alpha: { command: 'echo', args: ['v1'] } as any, + }) + + const oldClient = sdkClientsByName.get('alpha')! + sdkClientsByName.delete('alpha') + + await manager.getClientsForServers({ + alpha: { command: 'echo', args: ['v2'] } as any, + }) + + expect(oldClient.close).toHaveBeenCalledTimes(1) + expect(mockConnectMcpServer).toHaveBeenCalledTimes(2) + }) + + test('clear() closes all connections', async () => { + const manager = new MCPClientManager() + + await manager.getClientsForServers({ + alpha: stdioServer, + beta: stdioServer, + }) + + const alphaClient = sdkClientsByName.get('alpha')! + const betaClient = sdkClientsByName.get('beta')! + + manager.clear() + + expect(alphaClient.close).toHaveBeenCalledTimes(1) + expect(betaClient.close).toHaveBeenCalledTimes(1) + }) + + test('returns failed type when connection fails', async () => { + const manager = new MCPClientManager() + mockConnectMcpServer.mockImplementationOnce(async () => ({ + name: 'alpha', + type: 'failed' as const, + })) + + const results = await manager.getClientsForServers({ alpha: stdioServer }) + + expect(results).toHaveLength(1) + expect(results[0]!.type).toBe('failed') + }) + + test('does not retry failed server within FAILED_RETRY_INTERVAL_MS', async () => { + const manager = new MCPClientManager() + mockConnectMcpServer.mockImplementation(async () => ({ + name: 'alpha', + type: 'failed' as const, + })) + + await manager.getClientsForServers({ alpha: stdioServer }) + await manager.getClientsForServers({ alpha: stdioServer }) + + expect(mockConnectMcpServer).toHaveBeenCalledTimes(1) + }) +}) From 60e15f732daac2f5a45179354fe2b77b3567ad26 Mon Sep 17 00:00:00 2001 From: im10furry <1936409761@qq.com> Date: Tue, 9 Jun 2026 14:14:59 +0800 Subject: [PATCH 2/4] style: fix prettier formatting in mcp files --- src/services/mcp/tools-integration.ts | 5 ++++- tests/unit/mcp-content-normalization.test.ts | 15 +++------------ tests/unit/mcp-manager-lifecycle.test.ts | 5 ++++- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/services/mcp/tools-integration.ts b/src/services/mcp/tools-integration.ts index ad6c2edcf..d2858bdfb 100644 --- a/src/services/mcp/tools-integration.ts +++ b/src/services/mcp/tools-integration.ts @@ -4,7 +4,10 @@ import { MCPTool } from '@tools/mcp/MCPTool/MCPTool' import { logMCPError } from '@utils/log' import { debug } from '@utils/log/debugLogger' import type { Command } from '@commands' -import type { MessageParam, ToolResultBlockParam } from '@anthropic-ai/sdk/resources/index.mjs' +import type { + MessageParam, + ToolResultBlockParam, +} from '@anthropic-ai/sdk/resources/index.mjs' import { CallToolResultSchema, PromptListChangedNotificationSchema, diff --git a/tests/unit/mcp-content-normalization.test.ts b/tests/unit/mcp-content-normalization.test.ts index 9874aac8f..c52b1b640 100644 --- a/tests/unit/mcp-content-normalization.test.ts +++ b/tests/unit/mcp-content-normalization.test.ts @@ -23,10 +23,7 @@ describe('MCP content normalization', () => { }, } - const messages = await runCommand( - { name: 'screenshot', client }, - {}, - ) + const messages = await runCommand({ name: 'screenshot', client }, {}) expect(messages).toEqual([ { @@ -65,10 +62,7 @@ describe('MCP content normalization', () => { }, } - const messages = await runCommand( - { name: 'screenshot', client }, - {}, - ) + const messages = await runCommand({ name: 'screenshot', client }, {}) expect((messages[0]!.content as any[])[0].source.media_type).toBe( 'image/png', @@ -96,10 +90,7 @@ describe('MCP content normalization', () => { }, } - const messages = await runCommand( - { name: 'screenshot', client }, - {}, - ) + const messages = await runCommand({ name: 'screenshot', client }, {}) expect((messages[0]!.content as any[])[0].source.media_type).toBe( 'image/png', diff --git a/tests/unit/mcp-manager-lifecycle.test.ts b/tests/unit/mcp-manager-lifecycle.test.ts index 7d203aafc..818ae21f5 100644 --- a/tests/unit/mcp-manager-lifecycle.test.ts +++ b/tests/unit/mcp-manager-lifecycle.test.ts @@ -11,7 +11,10 @@ function createMockSdkClient() { } as any } -const sdkClientsByName = new Map>() +const sdkClientsByName = new Map< + string, + ReturnType +>() function getOrCreateSdkClient(name: string) { let client = sdkClientsByName.get(name) From 5ac65e629c4def4275d58653382e048e904f5434 Mon Sep 17 00:00:00 2001 From: im10furry <1936409761@qq.com> Date: Tue, 9 Jun 2026 14:36:13 +0800 Subject: [PATCH 3/4] fix: adapt types for @anthropic-ai/sdk@0.102 and diff@9 - Use normalizeAnthropicUsage/createAnthropicUsage for Usage type compliance - Replace Hunk import with StructuredPatchHunk from diff@9 - Remove @anthropic-ai/sdk/shims/node import (no longer exists) - Use normalizeImageMediaType for safe image content type conversion - Add ToolUseLikeBlockParam type for mcp_tool_use/server_tool_use handling - Fix ContentBlock text property access with proper type narrowing - Fix ToolUseBlock caller property and Message container/stop_details fields --- src/app/binaryFeedback.ts | 24 +++++---- src/commands/agents/generation.ts | 13 +---- src/commands/compact.ts | 8 ++- src/entrypoints/cli.tsx | 3 -- .../ai/adapters/responsesStreaming.ts | 12 ++++- src/services/ai/llm.ts | 54 ++++++++----------- .../filesystem/FileEditTool/FileEditTool.tsx | 4 +- src/tools/filesystem/FileEditTool/utils.ts | 4 +- .../filesystem/FileReadTool/FileReadTool.tsx | 20 +++---- .../FileWriteTool/FileWriteTool.tsx | 4 +- .../network/WebFetchTool/WebFetchTool.tsx | 4 +- .../components/FileEditToolUpdatedMessage.tsx | 4 +- src/ui/components/Message.tsx | 3 +- src/ui/components/StructuredDiff.tsx | 4 +- .../messages/AssistantToolUseMessage.tsx | 4 +- .../user-tool-result-message/utils.tsx | 18 +++---- src/utils/messages/core.ts | 10 ++-- src/utils/messages/userInput.tsx | 9 +++- .../permissions/bashToolPermissionEngine.ts | 3 +- src/utils/session/autoCompactCore.ts | 8 ++- src/utils/session/messageContextManager.ts | 10 ++++ src/utils/text/diff.ts | 4 +- 22 files changed, 117 insertions(+), 110 deletions(-) diff --git a/src/app/binaryFeedback.ts b/src/app/binaryFeedback.ts index ef6a100e2..f756a0a62 100644 --- a/src/app/binaryFeedback.ts +++ b/src/app/binaryFeedback.ts @@ -1,4 +1,8 @@ -import { TextBlock, ToolUseBlock } from '@anthropic-ai/sdk/resources/index.mjs' +import type { + ContentBlock, + TextBlock, + ToolUseBlock, +} from '@anthropic-ai/sdk/resources/index.mjs' import type { AssistantMessage, BinaryFeedbackResult } from './query' import { isEqual, zip } from 'lodash-es' @@ -30,23 +34,25 @@ function textContentBlocksEqual(cb1: TextBlock, cb2: TextBlock): boolean { return cb1.text === cb2.text } -function contentBlocksEqual( - cb1: TextBlock | ToolUseBlock, - cb2: TextBlock | ToolUseBlock, -): boolean { +function contentBlocksEqual(cb1: ContentBlock, cb2: ContentBlock): boolean { if (cb1.type !== cb2.type) { return false } if (cb1.type === 'text') { return textContentBlocksEqual(cb1, cb2 as TextBlock) } - cb2 = cb2 as ToolUseBlock - return cb1.name === cb2.name && isEqual(cb1.input, cb2.input) + if (cb1.type === 'tool_use') { + const toolUseBlock = cb2 as ToolUseBlock + return ( + cb1.name === toolUseBlock.name && isEqual(cb1.input, toolUseBlock.input) + ) + } + return isEqual(cb1, cb2) } function allContentBlocksEqual( - content1: (TextBlock | ToolUseBlock)[], - content2: (TextBlock | ToolUseBlock)[], + content1: ContentBlock[], + content2: ContentBlock[], ): boolean { if (content1.length !== content2.length) { return false diff --git a/src/commands/agents/generation.ts b/src/commands/agents/generation.ts index ca1fde43e..d95cd0476 100644 --- a/src/commands/agents/generation.ts +++ b/src/commands/agents/generation.ts @@ -2,6 +2,7 @@ import { randomUUID } from 'crypto' import type { AgentConfig } from '@utils/agent/loader' import { debug as debugLogger } from '@utils/log/debugLogger' import { logError } from '@utils/log' +import { extractTextFromContent } from '@utils/ai/anthropic' export type GeneratedAgent = { identifier: string @@ -33,17 +34,7 @@ Make the agent highly specialized and effective for the described use case.` ] as any const response = await queryModel('main', messages, [systemPrompt]) - let responseText = '' - if (typeof response.message?.content === 'string') { - responseText = response.message.content - } else if (Array.isArray(response.message?.content)) { - const textContent = response.message.content.find( - (c: any) => c.type === 'text', - ) - responseText = textContent?.text || '' - } else if (response.message?.content?.[0]?.text) { - responseText = response.message.content[0].text - } + const responseText = extractTextFromContent(response.message?.content) ?? '' if (!responseText) { throw new Error('No text content in model response') diff --git a/src/commands/compact.ts b/src/commands/compact.ts index a236251ec..60cf92345 100644 --- a/src/commands/compact.ts +++ b/src/commands/compact.ts @@ -9,6 +9,7 @@ import { getCodeStyle } from '@utils/config/style' import { clearTerminal } from '@utils/terminal' import { resetReminderSession } from '@services/systemReminder' import { resetFileFreshnessSession } from '@services/fileFreshness' +import { createAnthropicUsage } from '@utils/ai/anthropic' const COMPRESSION_PROMPT = `Please provide a comprehensive summary of our conversation structured as follows: @@ -88,12 +89,9 @@ const compact = { throw new Error(summary) } - summaryResponse.message.usage = { - input_tokens: 0, + summaryResponse.message.usage = createAnthropicUsage({ output_tokens: summaryResponse.message.usage.output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } + }) await clearTerminal() getMessagesSetter()([]) diff --git a/src/entrypoints/cli.tsx b/src/entrypoints/cli.tsx index 6578d46dd..50abaadc5 100644 --- a/src/entrypoints/cli.tsx +++ b/src/entrypoints/cli.tsx @@ -11,8 +11,5 @@ initSentry() ensurePackagedRuntimeEnv() ensureYogaWasmPath(import.meta.url) -import * as dontcare from '@anthropic-ai/sdk/shims/node' -Object.keys(dontcare) - installProcessHandlers() void runCli() diff --git a/src/services/ai/adapters/responsesStreaming.ts b/src/services/ai/adapters/responsesStreaming.ts index c4e8b1222..789b56c2f 100644 --- a/src/services/ai/adapters/responsesStreaming.ts +++ b/src/services/ai/adapters/responsesStreaming.ts @@ -1,6 +1,7 @@ import { StreamingEvent } from './base' import { AssistantMessage } from '@query' import { setRequestStatus } from '@utils/session/requestStatus' +import { createAnthropicUsage } from '@utils/ai/anthropic' export async function processResponsesStream( stream: AsyncGenerator, @@ -74,9 +75,16 @@ export async function processResponsesStream( const assistantMessage: AssistantMessage = { type: 'assistant', message: { + id: responseId, + container: null, + model: '', role: 'assistant', content: contentBlocks, - usage: { + stop_details: null, + stop_reason: 'end_turn', + stop_sequence: null, + type: 'message', + usage: createAnthropicUsage({ input_tokens: usage.prompt_tokens ?? 0, output_tokens: usage.completion_tokens ?? 0, prompt_tokens: usage.prompt_tokens ?? 0, @@ -85,7 +93,7 @@ export async function processResponsesStream( usage.totalTokens ?? (usage.prompt_tokens || 0) + (usage.completion_tokens || 0), reasoningTokens: usage.reasoningTokens, - }, + }), }, costUSD: 0, durationMs: Date.now() - startTime, diff --git a/src/services/ai/llm.ts b/src/services/ai/llm.ts index 8e0c30d14..1dd67f718 100644 --- a/src/services/ai/llm.ts +++ b/src/services/ai/llm.ts @@ -1,4 +1,3 @@ -import '@anthropic-ai/sdk/shims/node' import Anthropic, { APIConnectionError, APIError } from '@anthropic-ai/sdk' import { StreamingEvent } from './adapters/base' import { AnthropicBedrock } from '@anthropic-ai/bedrock-sdk' @@ -74,6 +73,10 @@ import { NO_CONTENT_MESSAGE, PROMPT_TOO_LONG_ERROR_MESSAGE, } from './llmConstants' +import { + createAnthropicUsage, + normalizeAnthropicUsage, +} from '@utils/ai/anthropic' function isGPT5Model(modelName: string): boolean { return modelName.startsWith('gpt-5') @@ -613,6 +616,7 @@ function convertOpenAIResponseToAnthropic( input: toolArgs, name: toolName, id: toolCall.id?.length > 0 ? toolCall.id : nanoid(), + caller: { type: 'direct' }, }) } } @@ -1417,6 +1421,7 @@ async function queryAnthropicNative( return { type: 'text' as const, text: block.text, + citations: block.citations ?? null, } } else if (block.type === 'tool_use') { return { @@ -1424,6 +1429,7 @@ async function queryAnthropicNative( id: block.id, name: block.name, input: block.input, + caller: block.caller, } } return block @@ -1432,9 +1438,11 @@ async function queryAnthropicNative( const assistantMessage: AssistantMessage = { message: { id: response.id, + container: (response as any).container ?? null, content, model: response.model, role: 'assistant', + stop_details: (response as any).stop_details ?? null, stop_reason: response.stop_reason, stop_sequence: response.stop_sequence, type: 'message', @@ -1909,6 +1917,7 @@ function buildAssistantMessageFromUnifiedResponse( input: toolArgs, name: toolName, id: toolCall.id?.length > 0 ? toolCall.id : nanoid(), + caller: { type: 'direct' }, }) } } @@ -1916,9 +1925,19 @@ function buildAssistantMessageFromUnifiedResponse( return { type: 'assistant', message: { + id: + unifiedResponse.responseId ?? + unifiedResponse.id ?? + `resp_${Date.now()}`, + container: null, + model: unifiedResponse.model ?? '', role: 'assistant', content: contentBlocks, - usage: { + stop_details: null, + stop_reason: unifiedResponse.stopReason ?? 'end_turn', + stop_sequence: null, + type: 'message', + usage: createAnthropicUsage({ input_tokens: unifiedResponse.usage?.promptTokens ?? unifiedResponse.usage?.input_tokens ?? @@ -1951,7 +1970,7 @@ function buildAssistantMessageFromUnifiedResponse( (unifiedResponse.usage?.completionTokens ?? unifiedResponse.usage?.output_tokens ?? 0), - }, + }), }, costUSD: 0, durationMs: Date.now() - startTime, @@ -1961,34 +1980,7 @@ function buildAssistantMessageFromUnifiedResponse( } function normalizeUsage(usage?: any) { - if (!usage) { - return { - input_tokens: 0, - output_tokens: 0, - cache_read_input_tokens: 0, - cache_creation_input_tokens: 0, - } - } - - const inputTokens = - usage.input_tokens ?? usage.prompt_tokens ?? usage.inputTokens ?? 0 - const outputTokens = - usage.output_tokens ?? usage.completion_tokens ?? usage.outputTokens ?? 0 - const cacheReadInputTokens = - usage.cache_read_input_tokens ?? - usage.prompt_token_details?.cached_tokens ?? - usage.cacheReadInputTokens ?? - 0 - const cacheCreationInputTokens = - usage.cache_creation_input_tokens ?? usage.cacheCreatedInputTokens ?? 0 - - return { - ...usage, - input_tokens: inputTokens, - output_tokens: outputTokens, - cache_read_input_tokens: cacheReadInputTokens, - cache_creation_input_tokens: cacheCreationInputTokens, - } + return normalizeAnthropicUsage(usage) } function getModelInputTokenCostUSD(model: string): number { diff --git a/src/tools/filesystem/FileEditTool/FileEditTool.tsx b/src/tools/filesystem/FileEditTool/FileEditTool.tsx index ea72abc50..d2a7377db 100644 --- a/src/tools/filesystem/FileEditTool/FileEditTool.tsx +++ b/src/tools/filesystem/FileEditTool/FileEditTool.tsx @@ -1,4 +1,4 @@ -import { Hunk } from 'diff' +import type { StructuredPatchHunk } from 'diff' import { mkdirSync, readFileSync, statSync } from 'fs' import { Box, Text } from 'ink' import { dirname, isAbsolute, relative, resolve, sep } from 'path' @@ -355,7 +355,7 @@ ${addLineNumbers({ oldString: string newString: string originalFile: string - structuredPatch: Hunk[] + structuredPatch: StructuredPatchHunk[] } > diff --git a/src/tools/filesystem/FileEditTool/utils.ts b/src/tools/filesystem/FileEditTool/utils.ts index a4e7a074b..2492ba779 100644 --- a/src/tools/filesystem/FileEditTool/utils.ts +++ b/src/tools/filesystem/FileEditTool/utils.ts @@ -1,7 +1,7 @@ import { isAbsolute, resolve } from 'path' import { getCwd } from '@utils/state' import { readFileBun } from '@utils/bun/file' -import { type Hunk } from 'diff' +import { type StructuredPatchHunk } from 'diff' import { getPatch } from '@utils/text/diff' import { normalizeLineEndings } from '@utils/terminal/paste' @@ -10,7 +10,7 @@ export async function applyEdit( old_string: string, new_string: string, replace_all = false, -): Promise<{ patch: Hunk[]; updatedFile: string }> { +): Promise<{ patch: StructuredPatchHunk[]; updatedFile: string }> { const fullFilePath = isAbsolute(file_path) ? file_path : resolve(getCwd(), file_path) diff --git a/src/tools/filesystem/FileReadTool/FileReadTool.tsx b/src/tools/filesystem/FileReadTool/FileReadTool.tsx index c690fbc2c..30c9a7fb1 100644 --- a/src/tools/filesystem/FileReadTool/FileReadTool.tsx +++ b/src/tools/filesystem/FileReadTool/FileReadTool.tsx @@ -1,7 +1,4 @@ -import { - DocumentBlockParam, - ImageBlockParam, -} from '@anthropic-ai/sdk/resources/index.mjs' +import { DocumentBlockParam } from '@anthropic-ai/sdk/resources/index.mjs' import { statSync } from 'fs' import { Box, Text } from 'ink' import * as path from 'node:path' @@ -29,6 +26,10 @@ import { DESCRIPTION, PROMPT } from './prompt' import { hasReadPermission } from '@utils/permissions/filesystem' import { secureFileService } from '@utils/fs/secureFile' import { readFileBun, fileExistsBun, getFileSizeBun } from '@utils/bun/file' +import { + type AnthropicImageMediaType, + normalizeImageMediaType, +} from '@utils/ai/anthropic' const MAX_LINES_TO_RENDER = 5 const MAX_LINE_LENGTH = 2000 @@ -432,7 +433,7 @@ export const FileReadTool = { type: 'image' file: { base64: string - type: ImageBlockParam.Source['media_type'] + type: AnthropicImageMediaType originalSize: number } } @@ -454,18 +455,19 @@ function createImageResponse( type: 'image' file: { base64: string - type: ImageBlockParam.Source['media_type'] + type: AnthropicImageMediaType originalSize: number } } { - const normalized: ImageBlockParam.Source['media_type'] = + const normalized = normalizeImageMediaType( ext === '.jpg' || ext === '.jpeg' ? 'image/jpeg' : ext === '.png' ? 'image/png' : ext === '.gif' ? 'image/gif' - : 'image/webp' + : 'image/webp', + ) return { type: 'image', file: { @@ -483,7 +485,7 @@ async function readImage( type: 'image' file: { base64: string - type: ImageBlockParam.Source['media_type'] + type: AnthropicImageMediaType originalSize: number } }> { diff --git a/src/tools/filesystem/FileWriteTool/FileWriteTool.tsx b/src/tools/filesystem/FileWriteTool/FileWriteTool.tsx index dd08bdfaf..76220d540 100644 --- a/src/tools/filesystem/FileWriteTool/FileWriteTool.tsx +++ b/src/tools/filesystem/FileWriteTool/FileWriteTool.tsx @@ -1,4 +1,4 @@ -import { Hunk } from 'diff' +import type { StructuredPatchHunk } from 'diff' import { mkdirSync, readFileSync, statSync } from 'fs' import { Box, Text } from 'ink' import { EOL } from 'os' @@ -305,6 +305,6 @@ ${addLineNumbers({ type: 'create' | 'update' filePath: string content: string - structuredPatch: Hunk[] + structuredPatch: StructuredPatchHunk[] } > diff --git a/src/tools/network/WebFetchTool/WebFetchTool.tsx b/src/tools/network/WebFetchTool/WebFetchTool.tsx index fee0ee4fe..9f61347d8 100644 --- a/src/tools/network/WebFetchTool/WebFetchTool.tsx +++ b/src/tools/network/WebFetchTool/WebFetchTool.tsx @@ -8,6 +8,7 @@ import { queryQuick } from '@services/llmLazy' import { PROMPT, TOOL_NAME_FOR_PROMPT } from './prompt' import { convertHtmlToMarkdown } from './htmlToMarkdown' import { urlCache } from './cache' +import { extractTextFromContent } from '@utils/ai/anthropic' const inputSchema = z.strictObject({ url: z.string().url().describe('The URL to fetch content from'), @@ -399,7 +400,8 @@ To complete your request, I need to fetch content from the redirected URL. Pleas }) const result = - aiResponse.message.content[0]?.text || 'No response from model' + extractTextFromContent(aiResponse.message.content) || + 'No response from model' const output: Output = { bytes, diff --git a/src/ui/components/FileEditToolUpdatedMessage.tsx b/src/ui/components/FileEditToolUpdatedMessage.tsx index e4806d80c..116081515 100644 --- a/src/ui/components/FileEditToolUpdatedMessage.tsx +++ b/src/ui/components/FileEditToolUpdatedMessage.tsx @@ -1,4 +1,4 @@ -import { Hunk } from 'diff' +import type { StructuredPatchHunk } from 'diff' import { Box, Text } from 'ink' import * as React from 'react' import { intersperse } from '@utils/text/array' @@ -10,7 +10,7 @@ import { useTerminalSize } from '@hooks/useTerminalSize' type Props = { filePath: string - structuredPatch?: Hunk[] + structuredPatch?: StructuredPatchHunk[] verbose: boolean } diff --git a/src/ui/components/Message.tsx b/src/ui/components/Message.tsx index f8bf6734a..eecf5817a 100644 --- a/src/ui/components/Message.tsx +++ b/src/ui/components/Message.tsx @@ -21,6 +21,7 @@ import { NormalizedMessage } from '@utils/messages' import { AssistantThinkingMessage } from './messages/AssistantThinkingMessage' import { AssistantRedactedThinkingMessage } from './messages/AssistantRedactedThinkingMessage' import { useTerminalSize } from '@hooks/useTerminalSize' +import type { ToolUseLikeBlockParam } from '@utils/ai/anthropic' type Props = { message: UserMessage | AssistantMessage @@ -160,7 +161,7 @@ function AssistantMessage({ | TextBlockParam | ImageBlockParam | ThinkingBlockParam - | ToolUseBlockParam + | ToolUseLikeBlockParam | ToolResultBlockParam costUSD: number durationMs: number diff --git a/src/ui/components/StructuredDiff.tsx b/src/ui/components/StructuredDiff.tsx index 21a8e47f1..08f5f29be 100644 --- a/src/ui/components/StructuredDiff.tsx +++ b/src/ui/components/StructuredDiff.tsx @@ -1,12 +1,12 @@ import { Box, Text } from 'ink' import * as React from 'react' -import { Hunk } from 'diff' +import type { StructuredPatchHunk } from 'diff' import { getTheme, ThemeNames } from '@utils/theme' import { useMemo } from 'react' import { wrapText } from '@utils/terminal/format' type Props = { - patch: Hunk + patch: StructuredPatchHunk dim: boolean width: number overrideTheme?: ThemeNames diff --git a/src/ui/components/messages/AssistantToolUseMessage.tsx b/src/ui/components/messages/AssistantToolUseMessage.tsx index 91396f239..3b15e70e8 100644 --- a/src/ui/components/messages/AssistantToolUseMessage.tsx +++ b/src/ui/components/messages/AssistantToolUseMessage.tsx @@ -1,7 +1,7 @@ import { Box, Text } from 'ink' import React from 'react' import { logError } from '@utils/log' -import { ToolUseBlockParam } from '@anthropic-ai/sdk/resources/index.mjs' +import type { ToolUseLikeBlockParam } from '@utils/ai/anthropic' import { Tool } from '@tool' import { Cost } from '@components/Cost' import { ToolUseLoader } from '@components/ToolUseLoader' @@ -11,7 +11,7 @@ import { TaskToolMessage } from './TaskToolMessage' import { resolveToolNameAlias } from '@utils/tooling/toolNameAliases' type Props = { - param: ToolUseBlockParam + param: ToolUseLikeBlockParam costUSD: number durationMs: number addMargin: boolean diff --git a/src/ui/components/messages/user-tool-result-message/utils.tsx b/src/ui/components/messages/user-tool-result-message/utils.tsx index e33ed0c55..81230cd53 100644 --- a/src/ui/components/messages/user-tool-result-message/utils.tsx +++ b/src/ui/components/messages/user-tool-result-message/utils.tsx @@ -1,15 +1,18 @@ -import { ToolUseBlockParam } from '@anthropic-ai/sdk/resources/index.mjs' import { Message } from '@query' import { useMemo } from 'react' import { Tool } from '@tool' import { GlobTool } from '@tools/GlobTool/GlobTool' import { GrepTool } from '@tools/search/GrepTool/GrepTool' +import { + isToolUseLikeBlockParam, + type ToolUseLikeBlockParam, +} from '@utils/ai/anthropic' function getToolUseFromMessages( toolUseID: string, messages: Message[], -): ToolUseBlockParam | null { - let toolUse: ToolUseBlockParam | null = null +): ToolUseLikeBlockParam | null { + let toolUse: ToolUseLikeBlockParam | null = null for (const message of messages) { if ( message.type !== 'assistant' || @@ -18,12 +21,7 @@ function getToolUseFromMessages( continue } for (const content of message.message.content) { - if ( - (content.type === 'tool_use' || - content.type === 'server_tool_use' || - content.type === 'mcp_tool_use') && - content.id === toolUseID - ) { + if (isToolUseLikeBlockParam(content) && content.id === toolUseID) { toolUse = content } } @@ -35,7 +33,7 @@ export function useGetToolFromMessages( toolUseID: string, tools: Tool[], messages: Message[], -): { tool: Tool; toolUse: ToolUseBlockParam } | null { +): { tool: Tool; toolUse: ToolUseLikeBlockParam } | null { return useMemo(() => { const toolUse = getToolUseFromMessages(toolUseID, messages) if (!toolUse) { diff --git a/src/utils/messages/core.ts b/src/utils/messages/core.ts index b6921e478..1b86e7b2c 100644 --- a/src/utils/messages/core.ts +++ b/src/utils/messages/core.ts @@ -12,6 +12,7 @@ import { ContentBlockParam, ContentBlock, } from '@anthropic-ai/sdk/resources/index.mjs' +import { createAnthropicUsage } from '@utils/ai/anthropic' export const INTERRUPT_MESSAGE = '[Request interrupted by user]' export const INTERRUPT_MESSAGE_FOR_TOOL_USE = @@ -48,17 +49,14 @@ function baseCreateAssistantMessage( uuid: randomUUID(), message: { id: randomUUID(), + container: null, model: '', role: 'assistant', + stop_details: null, stop_reason: 'stop_sequence', stop_sequence: '', type: 'message', - usage: { - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }, + usage: createAnthropicUsage(), content, }, ...extra, diff --git a/src/utils/messages/userInput.tsx b/src/utils/messages/userInput.tsx index 9bcafdcfc..dc0efdb4f 100644 --- a/src/utils/messages/userInput.tsx +++ b/src/utils/messages/userInput.tsx @@ -146,12 +146,17 @@ export async function processUserInput( return [] } + const secondMessageContent = + newMessages[1]?.type === 'assistant' + ? (newMessages[1].message as { content?: unknown }).content + : null + if ( newMessages.length === 2 && newMessages[0]!.type === 'user' && newMessages[1]!.type === 'assistant' && - typeof newMessages[1]!.message.content === 'string' && - newMessages[1]!.message.content.startsWith('Unknown command:') + typeof secondMessageContent === 'string' && + secondMessageContent.startsWith('Unknown command:') ) { return newMessages } diff --git a/src/utils/permissions/bashToolPermissionEngine.ts b/src/utils/permissions/bashToolPermissionEngine.ts index 315f183c6..088a5431c 100644 --- a/src/utils/permissions/bashToolPermissionEngine.ts +++ b/src/utils/permissions/bashToolPermissionEngine.ts @@ -481,7 +481,8 @@ function maybeConsumeRedirection( redirections: Redirection[], outputTokens: ParseEntry[], ): { skip: number } { - const isFd = (v: unknown) => typeof v === 'string' && /^\d+$/.test(v.trim()) + const isFd = (v: unknown): v is string => + typeof v === 'string' && /^\d+$/.test(v.trim()) if (isOpToken(token, '>') || isOpToken(token, '>>')) { const operator = String((token as any).op) as '>' | '>>' diff --git a/src/utils/session/autoCompactCore.ts b/src/utils/session/autoCompactCore.ts index 7259608fd..7b2b21322 100644 --- a/src/utils/session/autoCompactCore.ts +++ b/src/utils/session/autoCompactCore.ts @@ -13,6 +13,7 @@ import { getModelManager } from '@utils/model' import { debug as debugLogger } from '@utils/log/debugLogger' import { logError } from '@utils/log' import { calculateAutoCompactThresholds } from './autoCompactThreshold' +import { createAnthropicUsage } from '@utils/ai/anthropic' async function getMainConversationContextLimit(): Promise { try { @@ -164,12 +165,9 @@ async function executeAutoCompact( ) } - summaryResponse.message.usage = { - input_tokens: 0, + summaryResponse.message.usage = createAnthropicUsage({ output_tokens: summaryResponse.message.usage.output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } + }) const recoveredFiles = await selectAndReadFiles() diff --git a/src/utils/session/messageContextManager.ts b/src/utils/session/messageContextManager.ts index 0e2a61ec1..31d8be6f8 100644 --- a/src/utils/session/messageContextManager.ts +++ b/src/utils/session/messageContextManager.ts @@ -2,6 +2,7 @@ import { Message } from '@query' import type { UUID } from '@kode-types/common' import { countTokens } from '@utils/model/tokens' import crypto from 'crypto' +import { createAnthropicUsage } from '@utils/ai/anthropic' export interface MessageRetentionStrategy { type: @@ -116,11 +117,20 @@ export class MessageContextManager { const summaryMessage: Message = { type: 'assistant', message: { + id: crypto.randomUUID(), + container: null, + model: '', role: 'assistant', + stop_details: null, + stop_reason: 'stop_sequence', + stop_sequence: '', + type: 'message', + usage: createAnthropicUsage(), content: [ { type: 'text', text: `[CONVERSATION SUMMARY - ${olderMessages.length} messages compressed]\n\n${summary}\n\n[END SUMMARY - Recent context follows...]`, + citations: [], }, ], }, diff --git a/src/utils/text/diff.ts b/src/utils/text/diff.ts index 6c850a5cb..2d7b94e49 100644 --- a/src/utils/text/diff.ts +++ b/src/utils/text/diff.ts @@ -1,4 +1,4 @@ -import { type Hunk, structuredPatch } from 'diff' +import { type StructuredPatchHunk, structuredPatch } from 'diff' const CONTEXT_LINES = 3 @@ -16,7 +16,7 @@ export function getPatch({ fileContents: string oldStr: string newStr: string -}): Hunk[] { +}): StructuredPatchHunk[] { return structuredPatch( filePath, filePath, From 165595fdfc174fe23ba6599b1e635bd3ed2ecff2 Mon Sep 17 00:00:00 2001 From: im10furry <1936409761@qq.com> Date: Tue, 9 Jun 2026 14:48:27 +0800 Subject: [PATCH 4/4] fix(mcp): use dependency injection to avoid mock.module cross-test pollution --- src/services/mcp/manager.ts | 7 +++++- tests/unit/mcp-manager-lifecycle.test.ts | 28 +++++++++--------------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/services/mcp/manager.ts b/src/services/mcp/manager.ts index 59498c53c..423d626c5 100644 --- a/src/services/mcp/manager.ts +++ b/src/services/mcp/manager.ts @@ -63,6 +63,11 @@ async function pingWrappedClient(client: WrappedClient): Promise { export class MCPClientManager { private readonly clients = new Map() + private readonly connector: typeof connectMcpServer + + constructor(connector?: typeof connectMcpServer) { + this.connector = connector ?? connectMcpServer + } async getClientsForServers( servers: Record, @@ -155,7 +160,7 @@ export class MCPClientManager { } } - const wrapped = await connectMcpServer(name, serverRef, { + const wrapped = await this.connector(name, serverRef, { clientVersion: options?.clientVersion, }) this.clients.set(name, { diff --git a/tests/unit/mcp-manager-lifecycle.test.ts b/tests/unit/mcp-manager-lifecycle.test.ts index 818ae21f5..095b36ef6 100644 --- a/tests/unit/mcp-manager-lifecycle.test.ts +++ b/tests/unit/mcp-manager-lifecycle.test.ts @@ -1,5 +1,6 @@ import { beforeEach, describe, expect, mock, test } from 'bun:test' import type { WrappedClient } from '../../src/services/mcp/connection' +import { MCPClientManager } from '../../src/services/mcp/manager' function createMockSdkClient() { return { @@ -34,15 +35,6 @@ const mockConnectMcpServer = mock( }), ) -mock.module('../../src/services/mcp/connection', () => ({ - connectMcpServer: mockConnectMcpServer, - captureMcpCapabilities: () => null, - getMcpConnectionTimeoutMs: () => 5_000, - getMcpServerConnectionBatchSize: () => 3, -})) - -const { MCPClientManager } = await import('../../src/services/mcp/manager') - const stdioServer = { command: 'echo', args: [] } as any function staleHealthCheck(manager: any, name: string) { @@ -57,7 +49,7 @@ describe('MCPClientManager', () => { }) test('connects to new servers', async () => { - const manager = new MCPClientManager() + const manager = new MCPClientManager(mockConnectMcpServer as any) const results = await manager.getClientsForServers({ alpha: stdioServer }) @@ -68,7 +60,7 @@ describe('MCPClientManager', () => { }) test('reuses connection when health check is not yet due', async () => { - const manager = new MCPClientManager() + const manager = new MCPClientManager(mockConnectMcpServer as any) const servers = { alpha: stdioServer } const first = await manager.getClientsForServers(servers) @@ -79,7 +71,7 @@ describe('MCPClientManager', () => { }) test('reconnects when ping fails', async () => { - const manager = new MCPClientManager() + const manager = new MCPClientManager(mockConnectMcpServer as any) const servers = { alpha: stdioServer } await manager.getClientsForServers(servers) @@ -102,7 +94,7 @@ describe('MCPClientManager', () => { }) test('closes removed servers by default (closeMissing=true)', async () => { - const manager = new MCPClientManager() + const manager = new MCPClientManager(mockConnectMcpServer as any) await manager.getClientsForServers({ alpha: stdioServer, @@ -117,7 +109,7 @@ describe('MCPClientManager', () => { }) test('keeps removed servers when closeMissing=false', async () => { - const manager = new MCPClientManager() + const manager = new MCPClientManager(mockConnectMcpServer as any) await manager.getClientsForServers({ alpha: stdioServer, @@ -135,7 +127,7 @@ describe('MCPClientManager', () => { }) test('reconnects when server config changes', async () => { - const manager = new MCPClientManager() + const manager = new MCPClientManager(mockConnectMcpServer as any) await manager.getClientsForServers({ alpha: { command: 'echo', args: ['v1'] } as any, @@ -153,7 +145,7 @@ describe('MCPClientManager', () => { }) test('clear() closes all connections', async () => { - const manager = new MCPClientManager() + const manager = new MCPClientManager(mockConnectMcpServer as any) await manager.getClientsForServers({ alpha: stdioServer, @@ -170,7 +162,7 @@ describe('MCPClientManager', () => { }) test('returns failed type when connection fails', async () => { - const manager = new MCPClientManager() + const manager = new MCPClientManager(mockConnectMcpServer as any) mockConnectMcpServer.mockImplementationOnce(async () => ({ name: 'alpha', type: 'failed' as const, @@ -183,7 +175,7 @@ describe('MCPClientManager', () => { }) test('does not retry failed server within FAILED_RETRY_INTERVAL_MS', async () => { - const manager = new MCPClientManager() + const manager = new MCPClientManager(mockConnectMcpServer as any) mockConnectMcpServer.mockImplementation(async () => ({ name: 'alpha', type: 'failed' as const,