Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 38 additions & 33 deletions internal/server/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,20 +256,23 @@ func getAuthMetadata(ctx context.Context) map[string]string {
return meta
}

// injectAuthMetadata merges auth identity metadata into an activity arguments map (Spec 028).
// injectAuthMetadata creates a shallow copy of args with auth identity metadata added (Spec 028).
// Uses "_auth_" prefix to clearly separate auth metadata from tool arguments.
// Returns a new map — the original args map is never mutated, so it remains
// safe to forward to upstream servers without leaking internal metadata.
func injectAuthMetadata(ctx context.Context, args map[string]interface{}) map[string]interface{} {
authMeta := getAuthMetadata(ctx)
if authMeta == nil {
return args
}
if args == nil {
args = make(map[string]interface{})
enriched := make(map[string]interface{}, len(args)+len(authMeta))
for k, v := range args {
enriched[k] = v
}
for k, v := range authMeta {
args["_auth_"+k] = v
enriched["_auth_"+k] = v
}
return args
return enriched
}

// emitActivityEvent safely emits an activity event if runtime is available
Expand Down Expand Up @@ -971,17 +974,17 @@ func (p *MCPProxyServer) handleRetrieveTools(ctx context.Context, request mcp.Ca
}
}

// Spec 028: Inject auth identity into activity metadata
args = injectAuthMetadata(ctx, args)
// Spec 028: Inject auth identity into activity metadata (separate copy for logging only)
activityArgs := injectAuthMetadata(ctx, args)

jsonResult, err := json.Marshal(response)
if err != nil {
p.emitActivityInternalToolCall("retrieve_tools", "", "", "", sessionID, requestID, "error", err.Error(), time.Since(startTime).Milliseconds(), args, nil, nil)
p.emitActivityInternalToolCall("retrieve_tools", "", "", "", sessionID, requestID, "error", err.Error(), time.Since(startTime).Milliseconds(), activityArgs, nil, nil)
return mcp.NewToolResultError(fmt.Sprintf("Failed to serialize results: %v", err)), nil
}

// Emit success event with args and response (Spec 024)
p.emitActivityInternalToolCall("retrieve_tools", "", "", "", sessionID, requestID, "success", "", time.Since(startTime).Milliseconds(), args, response, nil)
p.emitActivityInternalToolCall("retrieve_tools", "", "", "", sessionID, requestID, "success", "", time.Since(startTime).Milliseconds(), activityArgs, response, nil)

return mcp.NewToolResultText(string(jsonResult)), nil
}
Expand Down Expand Up @@ -1236,8 +1239,10 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.
// Generate requestID for activity tracking
requestID := fmt.Sprintf("%d-%s-%s", time.Now().UnixNano(), serverName, actualToolName)

// Spec 028: Inject auth identity into activity metadata
args = injectAuthMetadata(ctx, args)
// Spec 028: Inject auth identity into a separate copy for activity logging only.
// The original args must not be mutated — upstream servers reject unknown fields
// (e.g. FastMCP's Pydantic validate_call). See #322.
activityArgs := injectAuthMetadata(ctx, args)

// Check if server is quarantined before calling tool
serverConfig, err := p.storage.GetUpstreamServer(serverName)
Expand All @@ -1249,7 +1254,7 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.
p.emitActivityPolicyDecision(serverName, actualToolName, getSessionID(), "blocked", "Server is quarantined for security review")

// Server is in quarantine - return security warning with tool analysis
return p.handleQuarantinedToolCall(ctx, serverName, actualToolName, args), nil
return p.handleQuarantinedToolCall(ctx, serverName, actualToolName, activityArgs), nil
}

// Check connection status before attempting tool call to prevent hanging
Expand All @@ -1267,8 +1272,8 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.
if intent != nil {
intentMap = intent.ToMap()
}
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, args, errMsg, false, toolVariant, intentMap)
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, activityArgs, errMsg, false, toolVariant, intentMap)
return mcp.NewToolResultError(errMsg), nil
}
} else {
Expand All @@ -1292,15 +1297,15 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.
if intent != nil {
intentMap = intent.ToMap()
}
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, args, errMsg, false, toolVariant, intentMap)
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, activityArgs, errMsg, false, toolVariant, intentMap)
return mcp.NewToolResultError(errMsg), nil
}

// Emit activity started event with determined source
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args)
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs)

// Call tool via upstream manager with circuit breaker pattern
// Call tool via upstream manager — use original args without auth metadata
startTime := time.Now()
result, err := p.upstreamManager.CallTool(ctx, toolName, args)
duration := time.Since(startTime)
Expand Down Expand Up @@ -1348,7 +1353,7 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.
ServerID: serverID,
ServerName: serverName,
ToolName: actualToolName,
Arguments: args,
Arguments: activityArgs,
Duration: int64(duration),
Timestamp: startTime,
ConfigPath: p.mainServer.GetConfigPath(),
Expand Down Expand Up @@ -1390,11 +1395,11 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.
if intent != nil {
intentMap = intent.ToMap()
}
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", err.Error(), duration.Milliseconds(), args, "", false, toolVariant, intentMap)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", err.Error(), duration.Milliseconds(), activityArgs, "", false, toolVariant, intentMap)

// Spec 024: Emit internal tool call event for error
internalToolName := "call_tool_" + intent.OperationType // e.g., "call_tool_read"
p.emitActivityInternalToolCall(internalToolName, serverName, actualToolName, toolVariant, sessionID, requestID, "error", err.Error(), time.Since(internalStartTime).Milliseconds(), args, nil, intentMap)
p.emitActivityInternalToolCall(internalToolName, serverName, actualToolName, toolVariant, sessionID, requestID, "error", err.Error(), time.Since(internalStartTime).Milliseconds(), activityArgs, nil, intentMap)

return p.createDetailedErrorResponse(err, serverName, actualToolName), nil
}
Expand Down Expand Up @@ -1494,11 +1499,11 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.
if intent != nil {
intentMap = intent.ToMap()
}
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "success", "", duration.Milliseconds(), args, response, responseTruncated, toolVariant, intentMap)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "success", "", duration.Milliseconds(), activityArgs, response, responseTruncated, toolVariant, intentMap)

// Spec 024: Emit internal tool call event for success
internalToolName := "call_tool_" + intent.OperationType // e.g., "call_tool_read"
p.emitActivityInternalToolCall(internalToolName, serverName, actualToolName, toolVariant, sessionID, requestID, "success", "", time.Since(internalStartTime).Milliseconds(), args, result, intentMap)
p.emitActivityInternalToolCall(internalToolName, serverName, actualToolName, toolVariant, sessionID, requestID, "success", "", time.Since(internalStartTime).Milliseconds(), activityArgs, result, intentMap)

return mcp.NewToolResultText(response), nil
}
Expand Down Expand Up @@ -1631,8 +1636,8 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo
// Generate requestID for activity tracking
requestID := fmt.Sprintf("%d-%s-%s", time.Now().UnixNano(), serverName, actualToolName)

// Spec 028: Inject auth identity into activity metadata
args = injectAuthMetadata(ctx, args)
// Spec 028: Inject auth identity into activity metadata (separate copy for logging only)
activityArgs := injectAuthMetadata(ctx, args)

// Check if server is quarantined before calling tool
serverConfig, err := p.storage.GetUpstreamServer(serverName)
Expand Down Expand Up @@ -1666,17 +1671,17 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo
errMsg = fmt.Sprintf("Server '%s' is not connected (state: %s) - use 'upstream_servers' tool to check server configuration", serverName, state.String())
}
// Log the early failure to activity (Spec 024)
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, args, errMsg, false, "", nil)
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, activityArgs, errMsg, false, "", nil)
return mcp.NewToolResultError(errMsg), nil
}
} else {
p.logger.Error("handleCallTool: no client found for server",
zap.String("server_name", serverName))
errMsg := fmt.Sprintf("No client found for server: %s", serverName)
// Log the early failure to activity (Spec 024)
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, args, errMsg, false, "", nil)
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", errMsg, 0, activityArgs, errMsg, false, "", nil)
return mcp.NewToolResultError(errMsg), nil
}

Expand All @@ -1685,7 +1690,7 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo
zap.String("server_name", serverName))

// Emit activity started event with determined source
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, args)
p.emitActivityToolCallStarted(serverName, actualToolName, sessionID, requestID, activitySource, activityArgs)

// Call tool via upstream manager with circuit breaker pattern
startTime := time.Now()
Expand Down Expand Up @@ -1736,7 +1741,7 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo
ServerID: serverID,
ServerName: serverName,
ToolName: actualToolName,
Arguments: args,
Arguments: activityArgs,
Duration: int64(duration),
Timestamp: startTime,
ConfigPath: p.mainServer.GetConfigPath(),
Expand Down Expand Up @@ -1794,7 +1799,7 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo
}

// Emit activity completed event for error with determined source (legacy - no intent)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", err.Error(), duration.Milliseconds(), args, "", false, "", nil)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "error", err.Error(), duration.Milliseconds(), activityArgs, "", false, "", nil)

return p.createDetailedErrorResponse(err, serverName, actualToolName), nil
}
Expand Down Expand Up @@ -1890,7 +1895,7 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo

// Emit activity completed event for success with determined source (legacy - no intent)
responseTruncated := tokenMetrics != nil && tokenMetrics.WasTruncated
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "success", "", duration.Milliseconds(), args, response, responseTruncated, "", nil)
p.emitActivityToolCallCompleted(serverName, actualToolName, sessionID, requestID, activitySource, "success", "", duration.Milliseconds(), activityArgs, response, responseTruncated, "", nil)

return mcp.NewToolResultText(response), nil
}
Expand Down
29 changes: 29 additions & 0 deletions internal/server/mcp_activity_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,35 @@ func TestInjectAuthMetadata_NilArgs(t *testing.T) {
assert.Equal(t, "mcp_123", result["_auth_token_prefix"])
}

// TestInjectAuthMetadata_DoesNotMutateOriginal verifies the original args map is never modified.
// This is the core fix for issue #322: _auth_* fields must not leak to upstream MCP servers.
func TestInjectAuthMetadata_DoesNotMutateOriginal(t *testing.T) {
agentCtx := &auth.AuthContext{
Type: auth.AuthTypeAgent,
AgentName: "test-bot",
TokenPrefix: "mcp_xyz789abc",
}
ctx := auth.WithAuthContext(context.Background(), agentCtx)

args := map[string]interface{}{
"query": "search term",
"limit": 10,
}

result := injectAuthMetadata(ctx, args)

// Result should have auth metadata
assert.Equal(t, "agent", result["_auth_auth_type"])
assert.Equal(t, "test-bot", result["_auth_agent_name"])

// Original args must NOT be modified
assert.Len(t, args, 2, "original args should still have exactly 2 keys")
_, hasAuthType := args["_auth_auth_type"]
assert.False(t, hasAuthType, "original args must not contain _auth_ fields (issue #322)")
_, hasAgentName := args["_auth_agent_name"]
assert.False(t, hasAgentName, "original args must not contain _auth_ fields (issue #322)")
}

// TestInjectAuthMetadata_NoAuthContext verifies no injection when no auth context.
func TestInjectAuthMetadata_NoAuthContext(t *testing.T) {
ctx := context.Background()
Expand Down