Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions internal/server/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -1064,17 +1064,22 @@ func (p *MCPProxyServer) handleCallToolDestructive(ctx context.Context, request
}

// handleCallToolVariant is the common handler for all call_tool_* variants (Spec 018)
func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.CallToolRequest, toolVariant string) (*mcp.CallToolResult, error) {
func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.CallToolRequest, toolVariant string) (callResult *mcp.CallToolResult, callErr error) {
// Spec 024: Track start time and context for internal tool call logging
internalStartTime := time.Now()

// Add panic recovery to ensure server resilience
// Panic recovery with named return values to prevent returning (nil, nil).
// Issue #318: with unnamed results, a recovered panic made this function return
// its zero values (nil, nil), and mcp-go then hit a second, unrecovered panic
// when it dereferenced the nil *CallToolResult.
defer func() {
if r := recover(); r != nil {
p.logger.Error("Recovered from panic in handleCallToolVariant",
zap.Any("panic", r),
zap.String("tool_variant", toolVariant),
zap.Any("request", request))
zap.Any("request", request),
zap.Stack("stacktrace"))
callResult = mcp.NewToolResultError(fmt.Sprintf("Internal proxy error: recovered from panic: %v", r))
callErr = nil
}
}()

Expand Down Expand Up @@ -1326,15 +1331,21 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.
tokenMetrics = &storage.TokenMetrics{
InputTokens: inputTokens,
Model: model,
Encoding: tokenizer.(*tokens.DefaultTokenizer).GetDefaultEncoding(),
Encoding: tokens.SafeGetDefaultEncoding(tokenizer),
}
}
}

// Derive server ID safely (serverConfig may be nil if storage lookup failed)
var serverID string
if serverConfig != nil {
serverID = storage.GenerateServerID(serverConfig)
}

// Record tool call for history (even if error)
toolCallRecord := &storage.ToolCallRecord{
ID: fmt.Sprintf("%d-%s", time.Now().UnixNano(), actualToolName),
ServerID: storage.GenerateServerID(serverConfig),
ServerID: serverID,
ServerName: serverName,
ToolName: actualToolName,
Arguments: args,
Expand Down Expand Up @@ -1493,13 +1504,16 @@ func (p *MCPProxyServer) handleCallToolVariant(ctx context.Context, request mcp.
}

// handleCallTool is the LEGACY call_tool handler - returns error directing to new variants (Spec 018)
func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Add panic recovery to ensure server resilience
func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToolRequest) (callResult *mcp.CallToolResult, callErr error) {
// Panic recovery with named return values (issue #318).
defer func() {
if r := recover(); r != nil {
p.logger.Error("Recovered from panic in handleCallTool",
zap.Any("panic", r),
zap.Any("request", request))
zap.Any("request", request),
zap.Stack("stacktrace"))
callResult = mcp.NewToolResultError(fmt.Sprintf("Internal proxy error: recovered from panic: %v", r))
callErr = nil
}
}()

Expand Down Expand Up @@ -1705,15 +1719,21 @@ func (p *MCPProxyServer) handleCallTool(ctx context.Context, request mcp.CallToo
tokenMetrics = &storage.TokenMetrics{
InputTokens: inputTokens,
Model: model,
Encoding: tokenizer.(*tokens.DefaultTokenizer).GetDefaultEncoding(),
Encoding: tokens.SafeGetDefaultEncoding(tokenizer),
}
}
}

// Derive server ID safely (serverConfig may be nil if storage lookup failed)
var serverID string
if serverConfig != nil {
serverID = storage.GenerateServerID(serverConfig)
}

// Record tool call for history (even if error)
toolCallRecord := &storage.ToolCallRecord{
ID: fmt.Sprintf("%d-%s", time.Now().UnixNano(), actualToolName),
ServerID: storage.GenerateServerID(serverConfig),
ServerID: serverID,
ServerName: serverName,
ToolName: actualToolName,
Arguments: args,
Expand Down
92 changes: 92 additions & 0 deletions internal/server/mcp_panic_recovery_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package server

import (
"context"
"testing"

"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/smart-mcp-proxy/mcpproxy-go/internal/contracts"
)

func TestHandleCallToolVariant_PanicRecovery(t *testing.T) {
// Issue #318: When handleCallToolVariant panics (e.g., nil tokenizer dereference),
// the recover() block must return a proper error result, not (nil, nil).
// Returning (nil, nil) causes a second unrecovered panic in mcp-go's HandleMessage
// when it dereferences the nil *CallToolResult.

t.Run("normal error path returns error result", func(t *testing.T) {
proxy := createTestMCPProxyServer(t)

request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
"name": "nonexistent-server:some_tool",
}

result, err := proxy.handleCallToolVariant(context.Background(), request, contracts.ToolVariantRead)

// Must never return (nil, nil)
assert.True(t, result != nil || err != nil,
"handleCallToolVariant must never return (nil, nil)")
if result != nil {
assert.True(t, result.IsError, "error results should have IsError=true")
}
})

t.Run("panic in handler returns error result via recover", func(t *testing.T) {
proxy := createTestMCPProxyServer(t)

// Nil out storage to trigger a guaranteed nil pointer panic inside
// handleCallToolVariant at p.storage.GetUpstreamServer(). This exercises
// the actual recover() path — the function panics, the deferred recover
// catches it, and sets callResult via named return values.
proxy.storage = nil

request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
"name": "panic-server:panic_tool",
}

result, err := proxy.handleCallToolVariant(context.Background(), request, contracts.ToolVariantRead)

// Before fix: would return (nil, nil), crashing mcp-go.
// After fix: must return a proper error result.
require.NoError(t, err, "recovered panic should not return an error")
require.NotNil(t, result, "recovered panic must return a non-nil CallToolResult, not (nil, nil)")
assert.True(t, result.IsError, "recovered panic should be an error result")

// Verify the error message mentions the panic
if len(result.Content) > 0 {
if text, ok := result.Content[0].(mcp.TextContent); ok {
assert.Contains(t, text.Text, "Internal proxy error")
assert.Contains(t, text.Text, "recovered from panic")
}
}
})

t.Run("panic in legacy handleCallTool returns error result via recover", func(t *testing.T) {
proxy := createTestMCPProxyServer(t)

// Same approach: nil storage triggers panic in handleCallTool
proxy.storage = nil

request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]interface{}{
"name": "panic-server:panic_tool",
}

result, err := proxy.handleCallTool(context.Background(), request)

require.NoError(t, err, "recovered panic should not return an error")
require.NotNil(t, result, "recovered panic must return a non-nil CallToolResult, not (nil, nil)")
assert.True(t, result.IsError, "recovered panic should be an error result")

if len(result.Content) > 0 {
if text, ok := result.Content[0].(mcp.TextContent); ok {
assert.Contains(t, text.Text, "Internal proxy error")
}
}
})
}
32 changes: 27 additions & 5 deletions internal/server/tokens/tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,22 @@ type DefaultTokenizer struct {
enabled bool
}

// NewTokenizer creates a new tokenizer instance
// NewTokenizer creates a new tokenizer instance.
// When enabled=false, encoding validation is skipped so that a disabled
// tokenizer can always be created (even with a corrupted tiktoken cache).
// See issue #318.
func NewTokenizer(defaultEncoding string, logger *zap.SugaredLogger, enabled bool) (*DefaultTokenizer, error) {
if defaultEncoding == "" {
defaultEncoding = DefaultEncoding
}

// Validate encoding exists
_, err := tiktoken.GetEncoding(defaultEncoding)
if err != nil {
return nil, fmt.Errorf("invalid encoding %q: %w", defaultEncoding, err)
// Only validate encoding when tokenizer is enabled.
// A disabled tokenizer never calls tiktoken, so a corrupted cache is harmless.
if enabled {
_, err := tiktoken.GetEncoding(defaultEncoding)
if err != nil {
return nil, fmt.Errorf("invalid encoding %q: %w", defaultEncoding, err)
}
}

return &DefaultTokenizer{
Expand Down Expand Up @@ -191,3 +197,19 @@ func (t *DefaultTokenizer) GetDefaultEncoding() string {
defer t.mu.RUnlock()
return t.defaultEncoding
}

// SafeGetDefaultEncoding returns the default encoding configured on t, or
// DefaultEncoding when it cannot be read safely.
//
// The fallback is used when:
//   - t is a nil interface,
//   - t holds a nil *DefaultTokenizer (a typed nil makes the interface itself
//     non-nil, so a plain `t == nil` check does not catch it), or
//   - t is some other Tokenizer implementation.
//
// See issue #318.
func SafeGetDefaultEncoding(t Tokenizer) string {
	// The comma-ok assertion never panics: for a nil interface or a different
	// concrete type it simply reports false. The extra nil check covers the
	// typed-nil pointer case before any method is invoked on it.
	if concrete, isDefault := t.(*DefaultTokenizer); isDefault && concrete != nil {
		return concrete.GetDefaultEncoding()
	}
	return DefaultEncoding
}
67 changes: 67 additions & 0 deletions internal/server/tokens/tokenizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,70 @@ func TestSupportedEncodings(t *testing.T) {
assert.Contains(t, encodings, "p50k_base")
assert.Contains(t, encodings, "r50k_base")
}

func TestNewTokenizerDisabledWithInvalidEncoding(t *testing.T) {
logger := zap.NewNop().Sugar()

// When disabled, NewTokenizer should succeed even with an invalid encoding
// (e.g., tiktoken cache is corrupted). This prevents the fallback path
// from also failing and leaving a nil tokenizer. See issue #318.
tokenizer, err := NewTokenizer("invalid_encoding_corrupted", logger, false)
require.NoError(t, err, "disabled tokenizer should not validate encoding")
require.NotNil(t, tokenizer, "disabled tokenizer should never be nil")
assert.False(t, tokenizer.IsEnabled())

// All methods should return zero without error when disabled
count, err := tokenizer.CountTokens("Hello, world!")
require.NoError(t, err)
assert.Equal(t, 0, count)

count, err = tokenizer.CountTokensForModel("Hello", "gpt-4")
require.NoError(t, err)
assert.Equal(t, 0, count)

count, err = tokenizer.CountTokensInJSON(map[string]string{"key": "value"})
require.NoError(t, err)
assert.Equal(t, 0, count)
}

func TestNilTokenizerInterfaceSafety(t *testing.T) {
// Simulates the scenario from issue #318 where a nil *DefaultTokenizer
// is assigned to a Tokenizer interface, creating a non-nil interface
// with a nil underlying value. SafeGetDefaultEncoding must handle this.
var nilTokenizer *DefaultTokenizer

// Assign nil concrete to interface — interface is non-nil but underlying is nil.
// In Go, the interface value itself is non-nil (it holds type info), but the
// concrete pointer inside is nil. This is the exact bug scenario.
var iface Tokenizer = nilTokenizer
_ = iface // used below

// Calling methods on this would panic — that's the bug.
// SafeGetDefaultEncoding should handle this safely without panic.
encoding := SafeGetDefaultEncoding(iface)
assert.Equal(t, DefaultEncoding, encoding, "should return default encoding for nil underlying")
}

func TestSafeGetDefaultEncoding(t *testing.T) {
logger := zap.NewNop().Sugar()

t.Run("valid DefaultTokenizer", func(t *testing.T) {
tokenizer, err := NewTokenizer("o200k_base", logger, true)
require.NoError(t, err)

encoding := SafeGetDefaultEncoding(tokenizer)
assert.Equal(t, "o200k_base", encoding)
})

t.Run("nil interface", func(t *testing.T) {
encoding := SafeGetDefaultEncoding(nil)
assert.Equal(t, DefaultEncoding, encoding)
})

t.Run("nil DefaultTokenizer in interface", func(t *testing.T) {
var nilDT *DefaultTokenizer
var iface Tokenizer = nilDT
encoding := SafeGetDefaultEncoding(iface)
assert.Equal(t, DefaultEncoding, encoding)
})
}