From 542e7de5f952371302d8225372a6093cf6886fa4 Mon Sep 17 00:00:00 2001 From: Claude Code Date: Mon, 9 Mar 2026 16:37:52 +0200 Subject: [PATCH] fix: prevent _auth_ metadata from leaking to upstream MCP servers (#322) injectAuthMetadata() was mutating the args map in-place, causing _auth_* fields to be forwarded to upstream servers via CallTool(). FastMCP-based servers using Pydantic's validate_call reject unknown keyword arguments, breaking tool calls for authenticated users. Fix: injectAuthMetadata() now returns a shallow copy. All call sites use separate activityArgs for logging and clean args for upstream forwarding. Closes #322 Co-Authored-By: Claude Opus 4.6 --- internal/server/mcp.go | 71 ++++++++++++---------- internal/server/mcp_activity_agent_test.go | 29 +++++++++ 2 files changed, 67 insertions(+), 33 deletions(-) diff --git a/internal/server/mcp.go b/internal/server/mcp.go index 17aee5fa..7a4b6647 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -256,20 +256,23 @@ func getAuthMetadata(ctx context.Context) map[string]string { return meta } -// injectAuthMetadata merges auth identity metadata into an activity arguments map (Spec 028). +// injectAuthMetadata creates a shallow copy of args with auth identity metadata added (Spec 028). // Uses "_auth_" prefix to clearly separate auth metadata from tool arguments. +// Returns a new map — the original args map is never mutated, so it remains +// safe to forward to upstream servers without leaking internal metadata. func injectAuthMetadata(ctx context.Context, args map[string]interface{}) map[string]interface{} { authMeta := getAuthMetadata(ctx) if authMeta == nil { return args } - if args == nil { - args = make(map[string]interface{}) + enriched := make(map[string]interface{}, len(args)+len(authMeta)) + for k, v := range args { + enriched[k] = v } for k, v := range authMeta { - args["_auth_"+k] = v + enriched["_auth_"+k] = v } - return args + return enriched } // emitActivityEvent safely emits an activity event if runtime is available @@ -971,17 +974,17 @@ func (p *MCPProxyServer) handleRetrieveTools(ctx context.Context, request mcp.Ca } } - // Spec 028: Inject auth identity into activity metadata - args = injectAuthMetadata(ctx, args) + // Spec 028: Inject auth identity into activity metadata (separate copy for logging only) + activityArgs := injectAuthMetadata(ctx, args) jsonResult, err := json.Marshal(response) if err != nil { - p.emitActivityInternalToolCall("retrieve_tools", "", "", "", sessionID, requestID, "error", err.Error(), time.Since(startTime).Milliseconds(), args, nil, nil) + p.emitActivityInternalToolCall("retrieve_tools", "", "", "", sessionID, requestID, "error", err.Error(), time.Since(startTime).Milliseconds(), activityArgs, nil, nil) return mcp.NewToolResultError(fmt.Sprintf("Failed to serialize results: %v", err)), nil } // Emit success event with args and response (Spec 024) - p.emitActivityInternalToolCall("retrieve_tools", "", "", "", sessionID, requestID, "success", "", time.Since(startTime).Milliseconds(), args, response, nil) + p.emitActivityInternalToolCall("retrieve_tools", "", "", "", sessionID, requestID, "success", "", time.Since(startTime).Milliseconds(), activityArgs, response, nil) return mcp.NewToolResultText(string(jsonResult)), nil } @@ -1236,8 +1239,10 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp. // Generate requestID for activity tracking requestID := fmt.Sprintf("%d-%s-%s", time.Now().UnixNano(), serverName, actualToolName) - // Spec 028: Inject auth identity into activity metadata - args = injectAuthMetadata(ctx, args) + // Spec 028: Inject auth identity into a separate copy for activity logging only. + // The original args must not be mutated — upstream servers reject unknown fields + // (e.g. FastMCP's Pydantic validate_call). See #322. + activityArgs := injectAuthMetadata(ctx, args) // Check if server is quarantined before calling tool serverConfig, err := p.storage.GetUpstreamServer(serverName) @@ -1249,7 +1254,7 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp. p.emitActivityPolicyDecision(serverName, actualToolName, getSessionID(), "blocked", "Server is quarantined for security review") // Server is in quarantine - return security warning with tool analysis - return p.handleQuarantinedToolCall(ctx, serverName, actualToolName, args), nil + return p.handleQuarantinedToolCall(ctx, serverName, actualToolName, activityArgs), nil } // Check connection status before attempting tool call to prevent hanging @@ -1267,8 +1272,8 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp. if intent != nil { intentMap = intent.ToMap() } - p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args) - p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, args, errMsg, false, toolVariant, intentMap) + p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs) + p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, activityArgs, errMsg, false, toolVariant, intentMap) return mcp.NewToolResultError(errMsg), nil } } else { @@ -1292,15 +1297,15 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp. if intent != nil { intentMap = intent.ToMap() } - p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args) - p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, args, errMsg, false, toolVariant, intentMap) + p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs) + p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, activityArgs, errMsg, false, toolVariant, intentMap) return mcp.NewToolResultError(errMsg), nil } // Emit activity started event with determined source - p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args) + p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs) - // Call tool via upstream manager with circuit breaker pattern + // Call tool via upstream manager — use original args without auth metadata startTime := time.Now() result, err := p.upstreamManager.CallTool(ctx, toolName, args) duration := time.Since(startTime) @@ -1348,7 +1353,7 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp. ServerID: serverID, ServerName: serverName, ToolName: actualToolName, - Arguments: args, + Arguments: activityArgs, Duration: int64(duration), Timestamp: startTime, ConfigPath: p.mainServer.GetConfigPath(), @@ -1390,11 +1395,11 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp. if intent != nil { intentMap = intent.ToMap() } - p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", err.Error(), duration.Milliseconds(), args, "", false, toolVariant, intentMap) + p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", err.Error(), duration.Milliseconds(), activityArgs, "", false, toolVariant, intentMap) // Spec 024: Emit internal tool call event for error internalToolName := "call_tool_" + intent.OperationType // e.g., "call_tool_read" - p.emitActivityInternalToolCall(internalToolName, serverName, actualToolName, toolVariant, sessionID, requestID, "error", err.Error(), time.Since(internalStartTime).Milliseconds(), args, nil, intentMap) + p.emitActivityInternalToolCall(internalToolName, serverName, actualToolName, toolVariant, sessionID, requestID, "error", err.Error(), time.Since(internalStartTime).Milliseconds(), activityArgs, nil, intentMap) return p.createDetailedErrorResponse(err, serverName, actualToolName), nil } @@ -1494,11 +1499,11 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp. if intent != nil { intentMap = intent.ToMap() } - p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "success", "", duration.Milliseconds(), args, response, responseTruncated, toolVariant, intentMap) + p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "success", "", duration.Milliseconds(), activityArgs, response, responseTruncated, toolVariant, intentMap) // Spec 024: Emit internal tool call event for success internalToolName := "call_tool_" + intent.OperationType // e.g., "call_tool_read" - p.emitActivityInternalToolCall(internalToolName, serverName, actualToolName, toolVariant, sessionID, requestID, "success", "", time.Since(internalStartTime).Milliseconds(), args, result, intentMap) + p.emitActivityInternalToolCall(internalToolName, serverName, actualToolName, toolVariant, sessionID, requestID, "success", "", time.Since(internalStartTime).Milliseconds(), activityArgs, result, intentMap) return mcp.NewToolResultText(response), nil } @@ -1631,8 +1636,8 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo // Generate requestID for activity tracking requestID := fmt.Sprintf("%d-%s-%s", time.Now().UnixNano(), serverName, actualToolName) - // Spec 028: Inject auth identity into activity metadata - args = injectAuthMetadata(ctx, args) + // Spec 028: Inject auth identity into activity metadata (separate copy for logging only) + activityArgs := injectAuthMetadata(ctx, args) // Check if server is quarantined before calling tool serverConfig, err := p.storage.GetUpstreamServer(serverName) @@ -1666,8 +1671,8 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo errMsg = fmt.Sprintf("Server '%s' is not connected (state: %s) - use 'upstream_servers' tool to check server configuration", serverName, state.String()) } // Log the early failure to activity (Spec 024) - p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args) - p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, args, errMsg, false, "", nil) + p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs) + p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, activityArgs, errMsg, false, "", nil) return mcp.NewToolResultError(errMsg), nil } } else { @@ -1675,8 +1680,8 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo zap.String("server_name", serverName)) errMsg := fmt.Sprintf("No client found for server: %s", serverName) // Log the early failure to activity (Spec 024) - p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args) - p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, args, errMsg, false, "", nil) + p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs) + p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, activityArgs, errMsg, false, "", nil) return mcp.NewToolResultError(errMsg), nil } @@ -1685,7 +1690,7 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo zap.String("server_name", serverName)) // Emit activity started event with determined source - p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args) + p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs) // Call tool via upstream manager with circuit breaker pattern startTime := time.Now() @@ -1736,7 +1741,7 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo ServerID: serverID, ServerName: serverName, ToolName: actualToolName, - Arguments: args, + Arguments: activityArgs, Duration: int64(duration), Timestamp: startTime, ConfigPath: p.mainServer.GetConfigPath(), @@ -1794,7 +1799,7 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo } // Emit activity completed event for error with determined source (legacy - no intent) - p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", err.Error(), duration.Milliseconds(), args, "", false, "", nil) + p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", err.Error(), duration.Milliseconds(), activityArgs, "", false, "", nil) return p.createDetailedErrorResponse(err, serverName, actualToolName), nil } @@ -1890,7 +1895,7 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo // Emit activity completed event for success with determined source (legacy - no intent) responseTruncated := tokenMetrics != nil && tokenMetrics.WasTruncated - p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "success", "", duration.Milliseconds(), args, response, responseTruncated, "", nil) + p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "success", "", duration.Milliseconds(), activityArgs, response, responseTruncated, "", nil) return mcp.NewToolResultText(response), nil } diff --git a/internal/server/mcp_activity_agent_test.go b/internal/server/mcp_activity_agent_test.go index ac813cc8..60aeef60 100644 --- a/internal/server/mcp_activity_agent_test.go +++ b/internal/server/mcp_activity_agent_test.go @@ -110,6 +110,35 @@ func TestInjectAuthMetadata_NilArgs(t *testing.T) { assert.Equal(t, "mcp_123", result["_auth_token_prefix"]) } +// TestInjectAuthMetadata_DoesNotMutateOriginal verifies the original args map is never modified. +// This is the core fix for issue #322: _auth_* fields must not leak to upstream MCP servers. +func TestInjectAuthMetadata_DoesNotMutateOriginal(t *testing.T) { + agentCtx := &auth.AuthContext{ + Type: auth.AuthTypeAgent, + AgentName: "test-bot", + TokenPrefix: "mcp_xyz789abc", + } + ctx := auth.WithAuthContext(context.Background(), agentCtx) + + args := map[string]interface{}{ + "query": "search term", + "limit": 10, + } + + result := injectAuthMetadata(ctx, args) + + // Result should have auth metadata + assert.Equal(t, "agent", result["_auth_auth_type"]) + assert.Equal(t, "test-bot", result["_auth_agent_name"]) + + // Original args must NOT be modified + assert.Len(t, args, 2, "original args should still have exactly 2 keys") + _, hasAuthType := args["_auth_auth_type"] + assert.False(t, hasAuthType, "original args must not contain _auth_ fields (issue #322)") + _, hasAgentName := args["_auth_agent_name"] + assert.False(t, hasAgentName, "original args must not contain _auth_ fields (issue #322)") +} + // TestInjectAuthMetadata_NoAuthContext verifies no injection when no auth context. func TestInjectAuthMetadata_NoAuthContext(t *testing.T) { ctx := context.Background()