Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions internal/api/oauthserver/authorization_code_race_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package oauthserver

import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"time"

"github.com/gofrs/uuid"
"github.com/stretchr/testify/require"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
)

func (ts *OAuthClientTestSuite) TestAuthorizationCodeSingleUseUnderConcurrency() {
client, _ := ts.createTestOAuthClient()
client.SetGrantTypes([]string{"authorization_code", "refresh_token"})
require.NoError(ts.T(), ts.DB.UpdateOnly(client, "grant_types"))

user := ts.createTestUser("code-race@example.com")

verifier := "code-race-verifier-0123456789012345678901234567890"
sum := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
method := "s256"

code := "code-race-single-use-abcdef0123456789"
redirectURI := "https://example.com/callback"

auth := &models.OAuthServerAuthorization{
ID: uuid.Must(uuid.NewV4()),
AuthorizationID: uuid.Must(uuid.NewV4()).String(),
ClientID: client.ID,
UserID: &user.ID,
RedirectURI: redirectURI,
Scope: "email", // non-openid, skips ID-token signing
CodeChallenge: &challenge,
CodeChallengeMethod: &method,
ResponseType: models.OAuthServerResponseTypeCode,
Status: models.OAuthServerAuthorizationApproved,
AuthorizationCode: &code,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(10 * time.Minute),
}
require.NoError(ts.T(), ts.DB.Create(auth))

const n = 8
var wg sync.WaitGroup
var mu sync.Mutex
tokens := map[string]bool{}
start := make(chan struct{})

for i := 0; i < n; i++ {
wg.Add(1)
go func() {
defer wg.Done()
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("code", code)
form.Set("code_verifier", verifier)
form.Set("redirect_uri", redirectURI)

req := httptest.NewRequest(http.MethodPost, "/oauth/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req = req.WithContext(shared.WithOAuthServerClient(req.Context(), client))
w := httptest.NewRecorder()

<-start
if err := ts.Server.OAuthToken(w, req); err != nil || w.Code != http.StatusOK {
return
}
var resp map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
return
}
if tok, _ := resp["access_token"].(string); tok != "" {
mu.Lock()
tokens[tok] = true
mu.Unlock()
}
}()
}
close(start)
wg.Wait()

require.Equal(ts.T(), 1, len(tokens),
"one authorization code must not mint more than one token, got %d", len(tokens))
}
15 changes: 13 additions & 2 deletions internal/api/oauthserver/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,15 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon
err = db.Transaction(func(tx *storage.Connection) error {
authMethod := models.OAuthProviderAuthorizationCode

// Consume the code under a row lock to keep it single use across concurrent requests
lockedAuthorization, terr := models.FindOAuthServerAuthorizationByCodeForUpdate(tx, params.Code)
if terr != nil {
if models.IsNotFoundError(terr) {
return apierrors.NewOAuthError("invalid_grant", "Invalid authorization code")
}
return terr
}

// Create audit log entry for OAuth token exchange
if terr := models.NewAuditLogEntry(s.config.AuditLog, r, tx, user, models.LoginAction, "", map[string]interface{}{
"provider_type": "oauth_provider_authorization_code",
Expand All @@ -405,15 +414,14 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon
}

// Issue the refresh token and access token
var terr error
tokenResponse, terr = tokenService.IssueRefreshToken(r, w.Header(), tx, user, authMethod, grantParams)
if terr != nil {
return terr
}

// Mark authorization as used - authorization codes are single use
// We could either delete it or mark it as consumed
if terr = tx.Destroy(authorization); terr != nil {
if terr = tx.Destroy(lockedAuthorization); terr != nil {
return terr
}

Expand All @@ -424,6 +432,9 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon
if httpErr, ok := err.(*apierrors.HTTPError); ok {
return httpErr
}
if oauthErr, ok := err.(*apierrors.OAuthError); ok {
return oauthErr
}
return apierrors.NewInternalServerError("Error exchanging authorization code").WithInternalError(err)
}

Expand Down
26 changes: 26 additions & 0 deletions internal/models/oauth_authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,32 @@ func FindOAuthServerAuthorizationByIDForUpdate(tx *storage.Connection, authoriza
return auth, nil
}

// FindOAuthServerAuthorizationByCodeForUpdate finds an approved OAuth
// authorization by authorization code and locks the row with FOR UPDATE SKIP
// LOCKED. Must be called inside a transaction.
func FindOAuthServerAuthorizationByCodeForUpdate(tx *storage.Connection, code string) (*OAuthServerAuthorization, error) {
auth := &OAuthServerAuthorization{}
if err := tx.RawQuery(
fmt.Sprintf("SELECT * FROM %q WHERE authorization_code = ? AND status = ? LIMIT 1 FOR UPDATE SKIP LOCKED", auth.TableName()),
code, OAuthServerAuthorizationApproved,
).First(auth); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, OAuthServerAuthorizationNotFoundError{}
}
return nil, errors.Wrap(err, "error finding OAuth authorization by code")
}

// Load client relationship (always present)
if auth.ClientID != uuid.Nil {
client := &OAuthServerClient{}
if err := tx.Q().Where("id = ?", auth.ClientID).First(client); err == nil {
auth.Client = client
}
}

return auth, nil
}

// FindOAuthServerAuthorizationByCode finds an OAuth authorization by authorization code
func FindOAuthServerAuthorizationByCode(tx *storage.Connection, code string) (*OAuthServerAuthorization, error) {
auth := &OAuthServerAuthorization{}
Expand Down