Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion auth/authorization_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,17 @@ func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context
if err != nil {
return fmt.Errorf("token exchange failed: %w", err)
}
h.tokenSource = cfg.TokenSource(clientCtx, token)
// The token source outlives this authorization request: it is stored on the
// handler and used by the transport for the lifetime of the connection. The
// oauth2 library captures the context passed to TokenSource and reuses it for
// every subsequent token refresh (see golang.org/x/oauth2: tokenRefresher
// retains the context and passes it to each refresh round-trip). Binding it to
// the per-request ctx makes all later refreshes fail with "context canceled"
// once that request (or the connect operation that triggered authorization)
// completes. Use a background context that still carries the configured HTTP
// client so refreshes keep working for the life of the token source.
refreshCtx := context.WithValue(context.Background(), oauth2.HTTPClient, h.config.Client)
h.tokenSource = cfg.TokenSource(refreshCtx, token)
return nil
}

Expand Down
95 changes: 95 additions & 0 deletions auth/authorization_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,101 @@ func TestAuthorize(t *testing.T) {
}
}

// TestAuthorize_RefreshAfterContextCancel verifies that the token source built
// by Authorize keeps refreshing after the context passed to Authorize is
// cancelled. The token source is stored on the handler and used by the
// transport for the whole connection lifetime, whereas Authorize is typically
// called from a request- or connect-scoped context. golang.org/x/oauth2
// captures the context handed to Config.TokenSource and reuses it for every
// refresh, so binding it to the request context made every refresh after the
// access token expired fail with "context canceled". Regression test for that.
func TestAuthorize_RefreshAfterContextCancel(t *testing.T) {
authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{
// expires_in below oauth2's 10s expiry delta, so the reuse token source
// treats the access token as expired immediately and must refresh.
AccessTokenTTL: 1,
IssueRefreshToken: true,
RegistrationConfig: &oauthtest.RegistrationConfig{
PreregisteredClients: map[string]oauthtest.ClientInfo{
"test_client_id": {
Secret: "test_client_secret",
RedirectURIs: []string{"http://localhost:12345/callback"},
},
},
},
})
authServer.Start(t)

resourceMux := http.NewServeMux()
resourceServer := httptest.NewServer(resourceMux)
t.Cleanup(resourceServer.Close)
resourceURL := resourceServer.URL + "/resource"
resourceMux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{
Resource: resourceURL,
AuthorizationServers: []string{authServer.URL()},
}))

handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{
RedirectURL: "http://localhost:12345/callback",
PreregisteredClient: &oauthex.ClientCredentials{
ClientID: "test_client_id",
ClientSecretAuth: &oauthex.ClientSecretAuth{ClientSecret: "test_client_secret"},
},
AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) {
client := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse }}
resp, err := client.Get(args.URL)
if err != nil {
return nil, fmt.Errorf("failed to visit auth URL: %v", err)
}
defer resp.Body.Close()
location, err := resp.Location()
if err != nil {
return nil, fmt.Errorf("failed to get location header: %v", err)
}
return &AuthorizationResult{
Code: location.Query().Get("code"),
State: location.Query().Get("state"),
Iss: location.Query().Get("iss"),
}, nil
},
})
if err != nil {
t.Fatalf("NewAuthorizationCodeHandler failed: %v", err)
}

req := httptest.NewRequest(http.MethodGet, resourceURL, nil)
resp := &http.Response{
StatusCode: http.StatusUnauthorized,
Header: make(http.Header),
Body: http.NoBody,
Request: req,
}
resp.Header.Set("WWW-Authenticate", "Bearer resource_metadata="+resourceServer.URL+"/.well-known/oauth-protected-resource/resource")

// Authorize under a context that we cancel immediately afterwards, mimicking
// the request/connect context the transport passes and that is already done
// by the time a later token refresh runs.
ctx, cancel := context.WithCancel(context.Background())
if err := handler.Authorize(ctx, req, resp); err != nil {
t.Fatalf("Authorize failed: %v", err)
}
cancel()

tokenSource, err := handler.TokenSource(context.Background())
if err != nil {
t.Fatalf("Failed to get token source: %v", err)
}
// The access token is already expired, so this forces a refresh round-trip.
// It must not fail with "context canceled" from the cancelled Authorize ctx.
token, err := tokenSource.Token()
if err != nil {
t.Fatalf("token refresh after Authorize context cancellation failed: %v", err)
}
if token.AccessToken != "test_access_token_refreshed" {
t.Errorf("expected refreshed access token %q, got %q", "test_access_token_refreshed", token.AccessToken)
}
}

func TestAuthorize_ScopeAccumulation(t *testing.T) {
authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{
RegistrationConfig: &oauthtest.RegistrationConfig{
Expand Down
48 changes: 47 additions & 1 deletion internal/oauthtest/fake_authorization_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ type Config struct {
// TokenScopeFunc, if set, is called with the scope from the authorization
// request and returns the scope string to include in the token response.
TokenScopeFunc func(requestedScope string) string
// AccessTokenTTL, if non-zero, is the expires_in (in seconds) returned by the
// /token endpoint for both the authorization_code and refresh_token grants.
// When zero a default of 3600 is used. Set it small to force a client's reuse
// token source to treat the access token as expired and refresh it.
AccessTokenTTL int
// IssueRefreshToken, if true, includes a refresh_token in token responses and
// enables grant_type=refresh_token at the /token endpoint.
IssueRefreshToken bool
}

// testRefreshToken is the refresh token issued and accepted by the fake server
// when Config.IssueRefreshToken is set.
const testRefreshToken = "test_refresh_token"

// accessTokenExpiresIn returns the expires_in value to use in token responses.
func (s *FakeAuthorizationServer) accessTokenExpiresIn() int {
if s.config.AccessTokenTTL != 0 {
return s.config.AccessTokenTTL
}
return 3600
}

// FakeAuthorizationServer is a fake OAuth 2.0 Authorization Server for testing.
Expand Down Expand Up @@ -298,6 +318,8 @@ func (s *FakeAuthorizationServer) handleToken(w http.ResponseWriter, r *http.Req
s.handleJWTBearerGrant(w, r)
case "client_credentials":
s.handleClientCredentialsGrant(w, r)
case "refresh_token":
s.handleRefreshTokenGrant(w, r)
default:
http.Error(w, fmt.Sprintf("unsupported grant_type: %s", grantType), http.StatusBadRequest)
}
Expand Down Expand Up @@ -329,7 +351,10 @@ func (s *FakeAuthorizationServer) handleAuthorizationCodeGrant(w http.ResponseWr
resp := map[string]any{
"access_token": "test_access_token",
"token_type": "Bearer",
"expires_in": 3600,
"expires_in": s.accessTokenExpiresIn(),
}
if s.config.IssueRefreshToken {
resp["refresh_token"] = testRefreshToken
}
if s.config.TokenScopeFunc != nil {
if scope := s.config.TokenScopeFunc(codeInfo.Scope); scope != "" {
Expand All @@ -340,6 +365,27 @@ func (s *FakeAuthorizationServer) handleAuthorizationCodeGrant(w http.ResponseWr
json.NewEncoder(w).Encode(resp)
}

// handleRefreshTokenGrant implements grant_type=refresh_token (RFC 6749 Section
// 6) when Config.IssueRefreshToken is set, returning a distinct access token so
// callers can observe that a refresh occurred.
func (s *FakeAuthorizationServer) handleRefreshTokenGrant(w http.ResponseWriter, r *http.Request) {
if !s.config.IssueRefreshToken {
http.Error(w, "refresh_token grant not supported", http.StatusBadRequest)
return
}
if r.Form.Get("refresh_token") != testRefreshToken {
http.Error(w, "invalid refresh_token", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"access_token": "test_access_token_refreshed",
"token_type": "Bearer",
"expires_in": s.accessTokenExpiresIn(),
"refresh_token": testRefreshToken,
})
}

func (s *FakeAuthorizationServer) handleJWTBearerGrant(w http.ResponseWriter, r *http.Request) {
if s.config.JWTBearerConfig == nil {
http.Error(w, "JWT bearer grant not supported", http.StatusBadRequest)
Expand Down