diff --git a/auth/authorization_code.go b/auth/authorization_code.go index a3daeecb..963884d0 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -605,7 +605,17 @@ func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context if err != nil { return fmt.Errorf("token exchange failed: %w", err) } - h.tokenSource = cfg.TokenSource(clientCtx, token) + // The token source outlives this authorization request: it is stored on the + // handler and used by the transport for the lifetime of the connection. The + // oauth2 library captures the context passed to TokenSource and reuses it for + // every subsequent token refresh (see golang.org/x/oauth2: tokenRefresher + // retains the context and passes it to each refresh round-trip). Binding it to + // the per-request ctx makes all later refreshes fail with "context canceled" + // once that request (or the connect operation that triggered authorization) + // completes. Use a background context that still carries the configured HTTP + // client so refreshes keep working for the life of the token source. + refreshCtx := context.WithValue(context.Background(), oauth2.HTTPClient, h.config.Client) + h.tokenSource = cfg.TokenSource(refreshCtx, token) return nil } diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go index c84b0032..15ca69f0 100644 --- a/auth/authorization_code_test.go +++ b/auth/authorization_code_test.go @@ -115,6 +115,101 @@ func TestAuthorize(t *testing.T) { } } +// TestAuthorize_RefreshAfterContextCancel verifies that the token source built +// by Authorize keeps refreshing after the context passed to Authorize is +// cancelled. The token source is stored on the handler and used by the +// transport for the whole connection lifetime, whereas Authorize is typically +// called from a request- or connect-scoped context. golang.org/x/oauth2 +// captures the context handed to Config.TokenSource and reuses it for every +// refresh, so binding it to the request context made every refresh after the +// access token expired fail with "context canceled". Regression test for that. +func TestAuthorize_RefreshAfterContextCancel(t *testing.T) { + authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + // expires_in below oauth2's 10s expiry delta, so the reuse token source + // treats the access token as expired immediately and must refresh. + AccessTokenTTL: 1, + IssueRefreshToken: true, + RegistrationConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "test_client_id": { + Secret: "test_client_secret", + RedirectURIs: []string{"http://localhost:12345/callback"}, + }, + }, + }, + }) + authServer.Start(t) + + resourceMux := http.NewServeMux() + resourceServer := httptest.NewServer(resourceMux) + t.Cleanup(resourceServer.Close) + resourceURL := resourceServer.URL + "/resource" + resourceMux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{authServer.URL()}, + })) + + handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{ + RedirectURL: "http://localhost:12345/callback", + PreregisteredClient: &oauthex.ClientCredentials{ + ClientID: "test_client_id", + ClientSecretAuth: &oauthex.ClientSecretAuth{ClientSecret: "test_client_secret"}, + }, + AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { + client := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }} + resp, err := client.Get(args.URL) + if err != nil { + return nil, fmt.Errorf("failed to visit auth URL: %v", err) + } + defer resp.Body.Close() + location, err := resp.Location() + if err != nil { + return nil, fmt.Errorf("failed to get location header: %v", err) + } + return &AuthorizationResult{ + Code: location.Query().Get("code"), + State: location.Query().Get("state"), + Iss: location.Query().Get("iss"), + }, nil + }, + }) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, resourceURL, nil) + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: make(http.Header), + Body: http.NoBody, + Request: req, + } + resp.Header.Set("WWW-Authenticate", "Bearer resource_metadata="+resourceServer.URL+"/.well-known/oauth-protected-resource/resource") + + // Authorize under a context that we cancel immediately afterwards, mimicking + // the request/connect context the transport passes and that is already done + // by the time a later token refresh runs. + ctx, cancel := context.WithCancel(context.Background()) + if err := handler.Authorize(ctx, req, resp); err != nil { + t.Fatalf("Authorize failed: %v", err) + } + cancel() + + tokenSource, err := handler.TokenSource(context.Background()) + if err != nil { + t.Fatalf("Failed to get token source: %v", err) + } + // The access token is already expired, so this forces a refresh round-trip. + // It must not fail with "context canceled" from the cancelled Authorize ctx. + token, err := tokenSource.Token() + if err != nil { + t.Fatalf("token refresh after Authorize context cancellation failed: %v", err) + } + if token.AccessToken != "test_access_token_refreshed" { + t.Errorf("expected refreshed access token %q, got %q", "test_access_token_refreshed", token.AccessToken) + } +} + func TestAuthorize_ScopeAccumulation(t *testing.T) { authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ RegistrationConfig: &oauthtest.RegistrationConfig{ diff --git a/internal/oauthtest/fake_authorization_server.go b/internal/oauthtest/fake_authorization_server.go index c180f862..e5134fb5 100644 --- a/internal/oauthtest/fake_authorization_server.go +++ b/internal/oauthtest/fake_authorization_server.go @@ -89,6 +89,26 @@ type Config struct { // TokenScopeFunc, if set, is called with the scope from the authorization // request and returns the scope string to include in the token response. TokenScopeFunc func(requestedScope string) string + // AccessTokenTTL, if non-zero, is the expires_in (in seconds) returned by the + // /token endpoint for both the authorization_code and refresh_token grants. + // When zero a default of 3600 is used. Set it small to force a client's reuse + // token source to treat the access token as expired and refresh it. + AccessTokenTTL int + // IssueRefreshToken, if true, includes a refresh_token in token responses and + // enables grant_type=refresh_token at the /token endpoint. + IssueRefreshToken bool +} + +// testRefreshToken is the refresh token issued and accepted by the fake server +// when Config.IssueRefreshToken is set. +const testRefreshToken = "test_refresh_token" + +// accessTokenExpiresIn returns the expires_in value to use in token responses. +func (s *FakeAuthorizationServer) accessTokenExpiresIn() int { + if s.config.AccessTokenTTL != 0 { + return s.config.AccessTokenTTL + } + return 3600 } // FakeAuthorizationServer is a fake OAuth 2.0 Authorization Server for testing. @@ -298,6 +318,8 @@ func (s *FakeAuthorizationServer) handleToken(w http.ResponseWriter, r *http.Req s.handleJWTBearerGrant(w, r) case "client_credentials": s.handleClientCredentialsGrant(w, r) + case "refresh_token": + s.handleRefreshTokenGrant(w, r) default: http.Error(w, fmt.Sprintf("unsupported grant_type: %s", grantType), http.StatusBadRequest) } @@ -329,7 +351,10 @@ func (s *FakeAuthorizationServer) handleAuthorizationCodeGrant(w http.ResponseWr resp := map[string]any{ "access_token": "test_access_token", "token_type": "Bearer", - "expires_in": 3600, + "expires_in": s.accessTokenExpiresIn(), + } + if s.config.IssueRefreshToken { + resp["refresh_token"] = testRefreshToken } if s.config.TokenScopeFunc != nil { if scope := s.config.TokenScopeFunc(codeInfo.Scope); scope != "" { @@ -340,6 +365,27 @@ func (s *FakeAuthorizationServer) handleAuthorizationCodeGrant(w http.ResponseWr json.NewEncoder(w).Encode(resp) } +// handleRefreshTokenGrant implements grant_type=refresh_token (RFC 6749 Section +// 6) when Config.IssueRefreshToken is set, returning a distinct access token so +// callers can observe that a refresh occurred. +func (s *FakeAuthorizationServer) handleRefreshTokenGrant(w http.ResponseWriter, r *http.Request) { + if !s.config.IssueRefreshToken { + http.Error(w, "refresh_token grant not supported", http.StatusBadRequest) + return + } + if r.Form.Get("refresh_token") != testRefreshToken { + http.Error(w, "invalid refresh_token", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test_access_token_refreshed", + "token_type": "Bearer", + "expires_in": s.accessTokenExpiresIn(), + "refresh_token": testRefreshToken, + }) +} + func (s *FakeAuthorizationServer) handleJWTBearerGrant(w http.ResponseWriter, r *http.Request) { if s.config.JWTBearerConfig == nil { http.Error(w, "JWT bearer grant not supported", http.StatusBadRequest)