diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index d219e5653..bdf0b9b25 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -394,6 +394,13 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon grantParams.Scopes = &scopes err = db.Transaction(func(tx *storage.Connection) error { + if _, terr := models.FindOAuthServerAuthorizationByIDForUpdate(tx, authorization.AuthorizationID); terr != nil { + if models.IsNotFoundError(terr) { + return apierrors.NewOAuthError("invalid_grant", "Invalid authorization code") + } + return apierrors.NewInternalServerError("Error locking authorization code").WithInternalError(terr) + } + authMethod := models.OAuthProviderAuthorizationCode // Create audit log entry for OAuth token exchange @@ -424,6 +431,9 @@ func (s *Server) handleAuthorizationCodeGrant(ctx context.Context, w http.Respon if httpErr, ok := err.(*apierrors.HTTPError); ok { return httpErr } + if oauthErr, ok := err.(*apierrors.OAuthError); ok { + return oauthErr + } return apierrors.NewInternalServerError("Error exchanging authorization code").WithInternalError(err) } diff --git a/internal/api/oauthserver/handlers_replay_test.go b/internal/api/oauthserver/handlers_replay_test.go new file mode 100644 index 000000000..3c2fa891e --- /dev/null +++ b/internal/api/oauthserver/handlers_replay_test.go @@ -0,0 +1,65 @@ +package oauthserver + +import ( + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "time" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/shared" + "github.com/supabase/auth/internal/models" +) + +func (ts *OAuthClientTestSuite) mintApprovedCode(clientID, userID uuid.UUID) string { + auth := models.NewOAuthServerAuthorization(models.NewOAuthServerAuthorizationParams{ + ClientID: clientID, + RedirectURI: "https://example.com/callback", + Scope: "profile", + TTL: time.Hour, + }) + require.NoError(ts.T(), models.CreateOAuthServerAuthorization(ts.DB, auth)) + require.NoError(ts.T(), auth.SetUser(ts.DB, userID)) + require.NoError(ts.T(), auth.Approve(ts.DB)) + return *auth.AuthorizationCode +} + +func (ts *OAuthClientTestSuite) TestAuthCodeReplayRace() { + client, _ := ts.createTestOAuthClient() + user := ts.createTestUser("replay-race@example.com") + code := ts.mintApprovedCode(client.ID, user.ID) + + const n = 10 + start := make(chan struct{}) + var wg sync.WaitGroup + var success int32 + + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + req := httptest.NewRequest(http.MethodPost, "/oauth/token", nil) + ctx := shared.WithOAuthServerClient(req.Context(), client) + req = req.WithContext(ctx) + w := httptest.NewRecorder() + params := &OAuthTokenParams{GrantType: GrantTypeAuthorizationCode, Code: code} + + <-start + if err := ts.Server.handleAuthorizationCodeGrant(ctx, w, req, params); err == nil { + atomic.AddInt32(&success, 1) + } + }() + } + + close(start) + wg.Wait() + + assert.Equal(ts.T(), int32(1), success, "authorization code must be single-use: expected exactly one successful redemption, got %d", success) + + _, err := models.FindOAuthServerAuthorizationByCode(ts.DB, code) + assert.True(ts.T(), models.IsNotFoundError(err), "authorization code should be consumed after redemption") +}