From e0db3ca12809c70312e1e21528002aff37849a79 Mon Sep 17 00:00:00 2001 From: kanywst Date: Thu, 2 Jul 2026 00:47:41 +0900 Subject: [PATCH] fix(oauthserver): consume authorization code under a row lock so it stays single-use The authorization_code grant looks up the code outside the transaction and without a lock, then issues the token and deletes the row inside the transaction, so concurrent requests with the same code each mint a token. Re-acquire and consume the code under FOR UPDATE SKIP LOCKED before issuing, matching the consent path. --- .../authorization_code_race_test.go | 93 +++++++++++++++++++ internal/api/oauthserver/handlers.go | 15 ++- internal/models/oauth_authorization.go | 26 ++++++ 3 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 internal/api/oauthserver/authorization_code_race_test.go diff --git a/internal/api/oauthserver/authorization_code_race_test.go b/internal/api/oauthserver/authorization_code_race_test.go new file mode 100644 index 000000000..1e211a9c6 --- /dev/null +++ b/internal/api/oauthserver/authorization_code_race_test.go @@ -0,0 +1,93 @@ +package oauthserver + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "time" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/shared" + "github.com/supabase/auth/internal/models" +) + +func (ts *OAuthClientTestSuite) TestAuthorizationCodeSingleUseUnderConcurrency() { + client, _ := ts.createTestOAuthClient() + client.SetGrantTypes([]string{"authorization_code", "refresh_token"}) + require.NoError(ts.T(), ts.DB.UpdateOnly(client, "grant_types")) + + user := ts.createTestUser("code-race@example.com") + + verifier := "code-race-verifier-0123456789012345678901234567890" + sum := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(sum[:]) + method := "s256" + + code := "code-race-single-use-abcdef0123456789" + redirectURI := "https://example.com/callback" + + auth := &models.OAuthServerAuthorization{ + ID: uuid.Must(uuid.NewV4()), + AuthorizationID: uuid.Must(uuid.NewV4()).String(), + ClientID: client.ID, + UserID: &user.ID, + RedirectURI: redirectURI, + Scope: "email", // non-openid, skips ID-token signing + CodeChallenge: &challenge, + CodeChallengeMethod: &method, + ResponseType: models.OAuthServerResponseTypeCode, + Status: models.OAuthServerAuthorizationApproved, + AuthorizationCode: &code, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + } + require.NoError(ts.T(), ts.DB.Create(auth)) + + const n = 8 + var wg sync.WaitGroup + var mu sync.Mutex + tokens := map[string]bool{} + start := make(chan struct{}) + + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer wg.Done() + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", code) + form.Set("code_verifier", verifier) + form.Set("redirect_uri", redirectURI) + + req := httptest.NewRequest(http.MethodPost, "/oauth/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = req.WithContext(shared.WithOAuthServerClient(req.Context(), client)) + w := httptest.NewRecorder() + + <-start + if err := ts.Server.OAuthToken(w, req); err != nil || w.Code != http.StatusOK { + return + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + return + } + if tok, _ := resp["access_token"].(string); tok != "" { + mu.Lock() + tokens[tok] = true + mu.Unlock() + } + }() + } + close(start) + wg.Wait() + + require.Equal(ts.T(), 1, len(tokens), + "one authorization code must not mint more than one token, got %d", len(tokens)) +} diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index d219e5653..2384ee722 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -396,6 +396,15 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon err = db.Transaction(func(tx *storage.Connection) error { authMethod := models.OAuthProviderAuthorizationCode + // Consume the code under a row lock to keep it single use across concurrent requests + lockedAuthorization, terr := models.FindOAuthServerAuthorizationByCodeForUpdate(tx, params.Code) + if terr != nil { + if models.IsNotFoundError(terr) { + return apierrors.NewOAuthError("invalid_grant", "Invalid authorization code") + } + return terr + } + // Create audit log entry for OAuth token exchange if terr := models.NewAuditLogEntry(s.config.AuditLog, r, tx, user, models.LoginAction, "", map[string]interface{}{ "provider_type": "oauth_provider_authorization_code", @@ -405,7 +414,6 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon } // Issue the refresh token and access token - var terr error tokenResponse, terr = tokenService.IssueRefreshToken(r, w.Header(), tx, user, authMethod, grantParams) if terr != nil { return terr @@ -413,7 +421,7 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon // Mark authorization as used - authorization codes are single use // We could either delete it or mark it as consumed - if terr = tx.Destroy(authorization); terr != nil { + if terr = tx.Destroy(lockedAuthorization); terr != nil { return terr } @@ -424,6 +432,9 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon if httpErr, ok := err.(*apierrors.HTTPError); ok { return httpErr } + if oauthErr, ok := err.(*apierrors.OAuthError); ok { + return oauthErr + } return apierrors.NewInternalServerError("Error exchanging authorization code").WithInternalError(err) } diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go index 2b7fab257..ce118398e 100644 --- a/internal/models/oauth_authorization.go +++ b/internal/models/oauth_authorization.go @@ -269,6 +269,32 @@ func FindOAuthServerAuthorizationByIDForUpdate(tx *storage.Connection, authoriza return auth, nil } +// FindOAuthServerAuthorizationByCodeForUpdate finds an approved OAuth +// authorization by authorization code and locks the row with FOR UPDATE SKIP +// LOCKED. Must be called inside a transaction. +func FindOAuthServerAuthorizationByCodeForUpdate(tx *storage.Connection, code string) (*OAuthServerAuthorization, error) { + auth := &OAuthServerAuthorization{} + if err := tx.RawQuery( + fmt.Sprintf("SELECT * FROM %q WHERE authorization_code = ? AND status = ? LIMIT 1 FOR UPDATE SKIP LOCKED", auth.TableName()), + code, OAuthServerAuthorizationApproved, + ).First(auth); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OAuthServerAuthorizationNotFoundError{} + } + return nil, errors.Wrap(err, "error finding OAuth authorization by code") + } + + // Load client relationship (always present) + if auth.ClientID != uuid.Nil { + client := &OAuthServerClient{} + if err := tx.Q().Where("id = ?", auth.ClientID).First(client); err == nil { + auth.Client = client + } + } + + return auth, nil +} + // FindOAuthServerAuthorizationByCode finds an OAuth authorization by authorization code func FindOAuthServerAuthorizationByCode(tx *storage.Connection, code string) (*OAuthServerAuthorization, error) { auth := &OAuthServerAuthorization{}