Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions internal/oauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -870,27 +870,20 @@ func createOAuthConfigInternal(serverConfig *config.ServerConfig, storage *stora
zap.String("storage", "memory"))
}

// Create HTTP client with transport wrapper to inject extra params into token requests
// extraParams may contain auto-detected resource (RFC 8707) or manual config params
var httpClient *http.Client

// Create HTTP client with transport wrapper for all OAuth servers.
// The wrapper injects extra params (if any) and normalizes non-standard
// token response status codes (e.g., 201 Created from Supabase → 200 OK).
if len(extraParams) > 0 {
// Log extra params with selective masking for security
masked := maskExtraParams(extraParams)
logger.Debug("OAuth extra parameters will be injected into token requests",
zap.String("server", serverConfig.Name),
zap.Any("extra_params", masked))
}

// Create HTTP client with wrapper to inject extra params into token exchange/refresh
wrapper := NewOAuthTransportWrapper(http.DefaultTransport, extraParams, logger)
httpClient = &http.Client{
Transport: wrapper,
Timeout: 30 * time.Second,
}

logger.Info("✅ Created OAuth HTTP client with extra params wrapper for token requests",
zap.String("server", serverConfig.Name),
zap.Int("extra_params_count", len(extraParams)))
wrapper := NewOAuthTransportWrapper(http.DefaultTransport, extraParams, logger)
httpClient := &http.Client{
Transport: wrapper,
Timeout: 30 * time.Second,
}

// Check if static OAuth credentials are provided in config
Expand Down Expand Up @@ -946,7 +939,7 @@ func createOAuthConfigInternal(serverConfig *config.ServerConfig, storage *stora
TokenStore: tokenStore, // Shared token store for this server
PKCEEnabled: true, // Always enable PKCE for security
AuthServerMetadataURL: authServerMetadataURL, // Explicit metadata URL for proper discovery
HTTPClient: httpClient, // Custom HTTP client with extra params wrapper (if configured)
HTTPClient: httpClient, // Custom HTTP client with transport wrapper (extra params + status normalization)
}

logger.Info("OAuth config created successfully",
Expand Down
55 changes: 42 additions & 13 deletions internal/oauth/transport_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,25 +66,54 @@ func NewOAuthTransportWrapper(transport http.RoundTripper, extraParams map[strin
// 2. Clones the request to avoid modifying the original
// 3. Injects extra parameters into query string (authorization) or body (token)
// 4. Delegates to the wrapped transport for actual HTTP execution
// 5. Logs parameter injection at DEBUG level for observability
// 5. Normalizes HTTP 201 responses to 200 for token requests (some providers like Supabase return 201)
// 6. Logs parameter injection at DEBUG level for observability
func (w *OAuthTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
// Skip if no extra params configured
if len(w.extraParams) == 0 {
return w.inner.RoundTrip(req)
tokenReq := isTokenRequest(req)

if len(w.extraParams) > 0 {
// Clone request to avoid modifying original
clonedReq := req.Clone(req.Context())

// Detect OAuth endpoint type and inject params appropriately
if isAuthorizationRequest(req) {
w.injectQueryParams(clonedReq)
} else if tokenReq {
w.injectFormParams(clonedReq)
}

resp, err := w.inner.RoundTrip(clonedReq)
if err != nil {
return resp, err
}

// Normalize 201 Created to 200 OK for token responses.
// Some OAuth providers (e.g., Supabase) return 201 for token exchange,
// but mcp-go only accepts 200.
if tokenReq && resp.StatusCode == http.StatusCreated {
w.logger.Debug("Normalized token response status 201→200",
zap.String("url", req.URL.String()))
resp.StatusCode = http.StatusOK
resp.Status = "200 OK"
}

return resp, nil
}

// Clone request to avoid modifying original
clonedReq := req.Clone(req.Context())
resp, err := w.inner.RoundTrip(req)
if err != nil {
return resp, err
}

// Detect OAuth endpoint type and inject params appropriately
if isAuthorizationRequest(req) {
w.injectQueryParams(clonedReq)
} else if isTokenRequest(req) {
w.injectFormParams(clonedReq)
// Normalize 201 Created to 200 OK for token responses even without extra params.
if tokenReq && resp.StatusCode == http.StatusCreated {
w.logger.Debug("Normalized token response status 201→200",
zap.String("url", req.URL.String()))
resp.StatusCode = http.StatusOK
resp.Status = "200 OK"
}

// Delegate to wrapped transport
return w.inner.RoundTrip(clonedReq)
return resp, nil
}

// isAuthorizationRequest detects if this is an OAuth authorization request
Expand Down
64 changes: 64 additions & 0 deletions internal/oauth/transport_wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,70 @@ func TestInjectFormParams_EmptyBody(t *testing.T) {
assert.Equal(t, "https://example.com/mcp", actualParams.Get("resource"))
}

func TestRoundTrip_Normalizes201ToOKForTokenRequests(t *testing.T) {
tests := []struct {
name string
method string
url string
statusCode int
extraParams map[string]string
expectedStatus int
}{
{
name: "201 token response normalized to 200 with extra params",
method: "POST",
url: "https://provider.com/token",
statusCode: http.StatusCreated,
extraParams: map[string]string{"resource": "https://example.com"},
expectedStatus: http.StatusOK,
},
{
name: "201 token response normalized to 200 without extra params",
method: "POST",
url: "https://provider.com/token",
statusCode: http.StatusCreated,
extraParams: nil,
expectedStatus: http.StatusOK,
},
{
name: "200 token response unchanged",
method: "POST",
url: "https://provider.com/token",
statusCode: http.StatusOK,
extraParams: nil,
expectedStatus: http.StatusOK,
},
{
name: "201 non-token response not normalized",
method: "GET",
url: "https://provider.com/authorize",
statusCode: http.StatusCreated,
extraParams: nil,
expectedStatus: http.StatusCreated,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inner := &mockRoundTripper{
response: &http.Response{
StatusCode: tt.statusCode,
Status: http.StatusText(tt.statusCode),
Body: io.NopCloser(strings.NewReader(`{"access_token":"test"}`)),
Header: make(http.Header),
},
}
wrapper := NewOAuthTransportWrapper(inner, tt.extraParams, zap.NewNop())
req := httptest.NewRequest(tt.method, tt.url, strings.NewReader("grant_type=authorization_code"))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := wrapper.RoundTrip(req)
require.NoError(t, err)
assert.Equal(t, tt.expectedStatus, resp.StatusCode)
})
}
}

func TestRoundTrip_PreservesOriginalRequest(t *testing.T) {
inner := &mockRoundTripper{}
extraParams := map[string]string{"resource": "https://example.com/mcp"}
Expand Down