diff --git a/internal/api/token_access_token.go b/internal/api/token_access_token.go index 0f693dd53..38338ab44 100644 --- a/internal/api/token_access_token.go +++ b/internal/api/token_access_token.go @@ -15,8 +15,9 @@ import ( // AccessTokenGrantParams are the parameters the AccessTokenGrant method accepts type AccessTokenGrantParams struct { - Provider string `json:"provider"` - AccessToken string `json:"access_token"` + Provider string `json:"provider"` + AccessToken string `json:"access_token"` + LinkIdentity bool `json:"link_identity"` } // AccessTokenGrant implements the access_token grant type flow, which allows @@ -28,6 +29,10 @@ type AccessTokenGrantParams struct { // an OIDC id token (AuthenticationToken) on the first authorization, which // makes the id_token grant unusable for repeat logins without falling back to // the browser flow. +// +// When link_identity is set and a valid user access token is provided in the +// Authorization header, the provider identity is linked to that user instead of +// signing in or creating a new account. func (a *API) AccessTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) @@ -44,6 +49,25 @@ func (a *API) AccessTokenGrant(ctx context.Context, w http.ResponseWriter, r *ht return apierrors.NewOAuthError("invalid request", "provider required") } + if params.LinkIdentity { + if r.Header.Get("Authorization") == "" { + return apierrors.NewOAuthError("invalid request", "Linking requires a valid user access token in Authorization") + } + + requireAuthCtx, err := a.requireAuthentication(w, r) + if err != nil { + return err + } + + targetUser := getUser(requireAuthCtx) + if targetUser == nil { + return apierrors.NewOAuthError("invalid request", "Linking requires a valid user authentication") + } + + // set it so linkIdentityToUser works below + ctx = withTargetUser(ctx, targetUser) + } + oauthProvider, pConfig, err := a.OAuthProvider(ctx, params.Provider) if err != nil { return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %q", params.Provider).WithInternalError(err) @@ -84,8 +108,10 @@ func (a *API) AccessTokenGrant(ctx context.Context, w http.ResponseWriter, r *ht var grantParams models.GrantParams grantParams.FillGrantParams(r) - if err := a.triggerBeforeUserCreatedExternal(r, db, userData, params.Provider); err != nil { - return err + if !params.LinkIdentity { + if err := a.triggerBeforeUserCreatedExternal(r, db, userData, params.Provider); err != nil { + return err + } } var createdUser bool @@ -95,7 +121,11 @@ func (a *API) AccessTokenGrant(ctx context.Context, w http.ResponseWriter, r *ht var terr error var decision models.AccountLinkingDecision - decision, user, terr = a.createAccountFromExternalIdentity(tx, r, userData, params.Provider, pConfig.EmailOptional) + if params.LinkIdentity { + user, terr = a.linkIdentityToUser(r, ctx, tx, userData, params.Provider) + } else { + decision, user, terr = a.createAccountFromExternalIdentity(tx, r, userData, params.Provider, pConfig.EmailOptional) + } if terr != nil { return terr } diff --git a/internal/api/token_access_token_test.go b/internal/api/token_access_token_test.go index dd20c0f1f..620914d7f 100644 --- a/internal/api/token_access_token_test.go +++ b/internal/api/token_access_token_test.go @@ -6,6 +6,8 @@ import ( "fmt" "net/http" "net/http/httptest" + + "github.com/supabase/auth/internal/models" ) // FacebookAccessTokenSetup spins up a mock Graph API that answers the @@ -146,3 +148,65 @@ func (ts *ExternalTestSuite) TestAccessTokenGrantSignupDisabled() { w := ts.accessTokenGrant("facebook", "valid_access_token") ts.Require().Equal(http.StatusUnprocessableEntity, w.Code) } + +func (ts *ExternalTestSuite) generateAccessTokenAndSession(u *models.User) string { + s, err := models.NewSession(u.ID, nil) + ts.Require().NoError(err) + ts.Require().NoError(ts.API.db.Create(s)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) + ts.Require().NoError(err) + return token +} + +func (ts *ExternalTestSuite) accessTokenLinkGrant(provider, accessToken, authToken string) *httptest.ResponseRecorder { + var buffer bytes.Buffer + ts.Require().NoError(json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "provider": provider, + "access_token": accessToken, + "link_identity": true, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=access_token", &buffer) + req.Header.Set("Content-Type", "application/json") + if authToken != "" { + req.Header.Set("Authorization", "Bearer "+authToken) + } + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + return w +} + +func (ts *ExternalTestSuite) TestAccessTokenGrantLinkIdentity() { + user, err := ts.createUser("link-user-sub", "linkme@example.com", "Link Me", "", "") + ts.Require().NoError(err) + token := ts.generateAccessTokenAndSession(user) + + server := FacebookAccessTokenSetup(ts, ts.Config.External.Facebook.ClientID[0], true, "USER", facebookUser) + defer server.Close() + + w := ts.accessTokenLinkGrant("facebook", "valid_access_token", token) + ts.Require().Equal(http.StatusOK, w.Code, w.Body.String()) + + var response AccessTokenResponse + ts.Require().NoError(json.NewDecoder(w.Body).Decode(&response)) + ts.Require().NotNil(response.User) + // Linking attaches the facebook identity to the already-authenticated user. + ts.Equal(user.ID, response.User.ID) + + ts.Require().NoError(ts.API.db.Load(user, "Identities")) + var providers []string + for _, identity := range user.Identities { + providers = append(providers, identity.Provider) + } + ts.Contains(providers, "facebook") +} + +func (ts *ExternalTestSuite) TestAccessTokenGrantLinkIdentityRequiresAuth() { + server := FacebookAccessTokenSetup(ts, ts.Config.External.Facebook.ClientID[0], true, "USER", facebookUser) + defer server.Close() + + w := ts.accessTokenLinkGrant("facebook", "valid_access_token", "") + ts.Require().Equal(http.StatusBadRequest, w.Code) +}