diff --git a/providers/telegram/session.go b/providers/telegram/session.go new file mode 100644 index 00000000..4a959ff5 --- /dev/null +++ b/providers/telegram/session.go @@ -0,0 +1,81 @@ +package telegram + +import ( + "encoding/json" + "errors" + "strings" + "time" + + "github.com/markbates/goth" + "golang.org/x/oauth2" +) + +// Session stores data during the auth process with Telegram. +type Session struct { + AuthURL string + CodeVerifier string + State string + AccessToken string + RefreshToken string + ExpiresAt time.Time + IDToken string + User goth.User +} + +// GetAuthURL will return the URL set by calling BeginAuth on the Telegram provider. +func (s Session) GetAuthURL() (string, error) { + if s.AuthURL == "" { + return "", errors.New(goth.NoAuthUrlErrorMessage) + } + return s.AuthURL, nil +} + +// Authorize the session with Telegram and store the retrieved user information. +func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) { + p := provider.(*Provider) + if params.Get("state") != s.State { + return "", errors.New("invalid telegram state") + } + + token, err := p.config.Exchange(goth.ContextForClient(p.Client()), params.Get("code"), oauth2.VerifierOption(s.CodeVerifier)) + if err != nil { + return "", err + } + + idToken, ok := token.Extra("id_token").(string) + if !ok || idToken == "" { + return "", errors.New("telegram id_token is empty") + } + parsed, err := parseIDToken(goth.ContextForClient(p.Client()), p.Client(), idToken, p.config.ClientID) + if err != nil { + return "", err + } + user, err := userFromClaims(parsed) + if err != nil { + return "", err + } + + s.AccessToken = token.AccessToken + s.RefreshToken = token.RefreshToken + s.ExpiresAt = token.Expiry + s.IDToken = idToken + s.User = user + return token.AccessToken, nil +} + +// Marshal the session into a string. +func (s Session) Marshal() string { + b, _ := json.Marshal(s) + return string(b) +} + +func (s Session) String() string { + return s.Marshal() +} + +// UnmarshalSession will unmarshal a JSON string into a session. +func (p *Provider) UnmarshalSession(data string) (goth.Session, error) { + sess := &Session{} + err := json.NewDecoder(strings.NewReader(data)).Decode(sess) + return sess, err +} diff --git a/providers/telegram/telegram.go b/providers/telegram/telegram.go new file mode 100644 index 00000000..1084f7ab --- /dev/null +++ b/providers/telegram/telegram.go @@ -0,0 +1,273 @@ +// Package telegram implements OAuth2/OpenID Connect authentication for Telegram. +package telegram + +import ( + "context" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "strconv" + + "github.com/golang-jwt/jwt/v5" + "github.com/markbates/goth" + "golang.org/x/oauth2" +) + +const ( + endpointAuth = "https://oauth.telegram.org/auth" + endpointToken = "https://oauth.telegram.org/token" + endpointJWKS = "https://oauth.telegram.org/.well-known/jwks.json" + issuer = "https://oauth.telegram.org" +) + +// New creates a new Telegram provider. +func New(clientKey, secret, callbackURL string, scopes ...string) *Provider { + p := &Provider{ + ClientKey: clientKey, + Secret: secret, + CallbackURL: callbackURL, + providerName: "telegram", + } + p.config = newConfig(p, scopes) + return p +} + +// Provider is the implementation of goth.Provider for Telegram. +type Provider struct { + ClientKey string + Secret string + CallbackURL string + HTTPClient *http.Client + config *oauth2.Config + providerName string +} + +type keySet struct { + Keys []jwkKey `json:"keys"` +} + +type jwkKey struct { + KeyType string `json:"kty"` + KeyID string `json:"kid"` + Use string `json:"use"` + Alg string `json:"alg"` + Modulus string `json:"n"` + Exponent string `json:"e"` +} + +// Name is the name used to retrieve this provider later. +func (p *Provider) Name() string { + return p.providerName +} + +// SetName is to update the name of the provider. +func (p *Provider) SetName(name string) { + p.providerName = name +} + +// Client returns an HTTP client to be used in all fetch operations. +func (p *Provider) Client() *http.Client { + return goth.HTTPClientWithFallBack(p.HTTPClient) +} + +// Debug is a no-op for the telegram package. +func (p *Provider) Debug(debug bool) {} + +// BeginAuth asks Telegram for an authentication endpoint. +func (p *Provider) BeginAuth(state string) (goth.Session, error) { + verifier := oauth2.GenerateVerifier() + url := p.config.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) + return &Session{ + AuthURL: url, + CodeVerifier: verifier, + State: state, + }, nil +} + +// FetchUser returns Telegram user data collected from the ID token. +func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { + sess, ok := session.(*Session) + if !ok { + return goth.User{}, fmt.Errorf("invalid telegram session") + } + if sess.User.UserID == "" { + return goth.User{}, fmt.Errorf("telegram user is empty") + } + return sess.User, nil +} + +// RefreshTokenAvailable returns whether Telegram supports refresh tokens. +func (p *Provider) RefreshTokenAvailable() bool { + return false +} + +// RefreshToken gets a new access token based on a refresh token. +func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { + token := &oauth2.Token{RefreshToken: refreshToken} + ts := p.config.TokenSource(goth.ContextForClient(p.Client()), token) + return ts.Token() +} + +func newConfig(provider *Provider, scopes []string) *oauth2.Config { + if len(scopes) == 0 { + scopes = []string{"openid", "profile"} + } + return &oauth2.Config{ + ClientID: provider.ClientKey, + ClientSecret: provider.Secret, + RedirectURL: provider.CallbackURL, + Scopes: scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: endpointAuth, + TokenURL: endpointToken, + }, + } +} + +func parseIDToken(ctx context.Context, client *http.Client, idToken, clientID string) (jwt.MapClaims, error) { + keys, err := fetchKeySet(ctx, client) + if err != nil { + return nil, err + } + claims := jwt.MapClaims{} + token, err := jwt.ParseWithClaims( + idToken, + claims, + keys.keyfunc, + jwt.WithIssuer(issuer), + jwt.WithAudience(clientID), + jwt.WithExpirationRequired(), + ) + if err != nil { + return nil, fmt.Errorf("parse telegram id_token: %w", err) + } + if !token.Valid { + return nil, fmt.Errorf("telegram id_token is invalid") + } + return claims, nil +} + +func fetchKeySet(ctx context.Context, client *http.Client) (*keySet, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpointJWKS, nil) + if err != nil { + return nil, err + } + resp, err := goth.HTTPClientWithFallBack(client).Do(req) + if err != nil { + return nil, fmt.Errorf("fetch telegram jwks: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("fetch telegram jwks: status %d", resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read telegram jwks: %w", err) + } + keys := &keySet{} + if err := json.Unmarshal(body, keys); err != nil { + return nil, fmt.Errorf("decode telegram jwks: %w", err) + } + return keys, nil +} + +func (ks *keySet) keyfunc(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected telegram signing method: %s", token.Header["alg"]) + } + kid, _ := token.Header["kid"].(string) + for _, key := range ks.Keys { + if key.KeyType != "RSA" || key.Modulus == "" || key.Exponent == "" { + continue + } + if kid != "" && key.KeyID != kid { + continue + } + if key.Alg != "" && key.Alg != token.Method.Alg() { + continue + } + publicKey, err := key.rsaPublicKey() + if err != nil { + return nil, err + } + return publicKey, nil + } + return nil, fmt.Errorf("telegram jwk not found") +} + +func (key jwkKey) rsaPublicKey() (*rsa.PublicKey, error) { + modulus, err := base64.RawURLEncoding.DecodeString(key.Modulus) + if err != nil { + return nil, fmt.Errorf("decode telegram jwk modulus: %w", err) + } + exponent, err := base64.RawURLEncoding.DecodeString(key.Exponent) + if err != nil { + return nil, fmt.Errorf("decode telegram jwk exponent: %w", err) + } + if len(exponent) == 0 { + return nil, fmt.Errorf("telegram jwk exponent is empty") + } + e := 0 + for _, b := range exponent { + e = e<<8 + int(b) + } + return &rsa.PublicKey{N: new(big.Int).SetBytes(modulus), E: e}, nil +} + +func userFromClaims(claims jwt.MapClaims) (goth.User, error) { + subject, err := claims.GetSubject() + if err != nil { + return goth.User{}, fmt.Errorf("read telegram subject: %w", err) + } + if subject == "" { + return goth.User{}, fmt.Errorf("telegram subject is empty") + } + raw, err := json.Marshal(claims) + if err != nil { + return goth.User{}, fmt.Errorf("marshal telegram id_token claims: %w", err) + } + + return goth.User{ + Provider: "telegram", + UserID: subject, + Name: stringClaim(claims, "name"), + NickName: stringClaim(claims, "preferred_username"), + AvatarURL: stringClaim(claims, "picture"), + RawData: map[string]interface{}{ + "raw_profile": string(raw), + "telegram_id": telegramIDClaim(claims), + }, + }, nil +} + +func stringClaim(claims jwt.MapClaims, key string) string { + value, ok := claims[key] + if !ok { + return "" + } + s, _ := value.(string) + return s +} + +func telegramIDClaim(claims jwt.MapClaims) string { + value, ok := claims["id"] + if !ok { + return "" + } + switch v := value.(type) { + case string: + return v + case int: + return strconv.Itoa(v) + case int64: + return strconv.FormatInt(v, 10) + case float64: + return strconv.FormatInt(int64(v), 10) + default: + return fmt.Sprintf("%v", v) + } +} diff --git a/providers/telegram/telegram_test.go b/providers/telegram/telegram_test.go new file mode 100644 index 00000000..deb26bc3 --- /dev/null +++ b/providers/telegram/telegram_test.go @@ -0,0 +1,68 @@ +package telegram + +import ( + "net/url" + "strings" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" +) + +func TestBeginAuthUsesTelegramOIDCEndpointAndPKCE(t *testing.T) { + provider := New("client", "secret", "https://example.com/callback") + session, err := provider.BeginAuth("state-value") + require.NoError(t, err) + + authURL, err := session.GetAuthURL() + require.NoError(t, err) + + parsed, err := url.Parse(authURL) + require.NoError(t, err) + require.Equal(t, endpointAuth, parsed.Scheme+"://"+parsed.Host+parsed.Path) + require.Equal(t, "client", parsed.Query().Get("client_id")) + require.Equal(t, "https://example.com/callback", parsed.Query().Get("redirect_uri")) + require.Equal(t, "state-value", parsed.Query().Get("state")) + require.Equal(t, "S256", parsed.Query().Get("code_challenge_method")) + require.NotEmpty(t, parsed.Query().Get("code_challenge")) + require.True(t, strings.Contains(parsed.Query().Get("scope"), "openid")) + require.True(t, strings.Contains(parsed.Query().Get("scope"), "profile")) + + telegramSession := session.(*Session) + require.Equal(t, "state-value", telegramSession.State) + require.NotEmpty(t, telegramSession.CodeVerifier) +} + +func TestSessionMarshalRoundTrip(t *testing.T) { + provider := New("client", "secret", "https://example.com/callback") + session := &Session{AuthURL: "https://oauth.telegram.org/auth", CodeVerifier: "verifier", State: "state"} + + unmarshaled, err := provider.UnmarshalSession(session.Marshal()) + require.NoError(t, err) + + telegramSession := unmarshaled.(*Session) + require.Equal(t, session.AuthURL, telegramSession.AuthURL) + require.Equal(t, session.CodeVerifier, telegramSession.CodeVerifier) + require.Equal(t, session.State, telegramSession.State) +} + +func TestUserFromClaims(t *testing.T) { + claims := jwt.MapClaims{ + "sub": "telegram-subject", + "id": int64(12345), + "name": "Telegram User", + "preferred_username": "tguser", + "picture": "https://example.com/avatar.jpg", + } + + user, err := userFromClaims(claims) + require.NoError(t, err) + + require.Equal(t, "telegram", user.Provider) + require.Equal(t, "telegram-subject", user.UserID) + require.Equal(t, "Telegram User", user.Name) + require.Equal(t, "tguser", user.NickName) + require.Equal(t, "https://example.com/avatar.jpg", user.AvatarURL) + require.Equal(t, "12345", user.RawData["telegram_id"]) + require.NotEmpty(t, user.RawData["raw_profile"]) +}