Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 22 additions & 25 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ const (
oauthErrInvalidToken = "invalid_token"
)

// Token store backend modes for the -token-store flag
const (
tokenStoreModeAuto = "auto"
tokenStoreModeFile = "file"
tokenStoreModeKeyring = "keyring"
)

// tokenResponse is the common structure for OAuth token endpoint responses.
type tokenResponse struct {
AccessToken string `json:"access_token"`
Expand Down Expand Up @@ -208,14 +215,15 @@ func doInitConfig() {
// Initialize token store based on mode
fileStore := credstore.NewTokenFileStore(tokenFile)
switch tokenStoreMode {
case "file":
case tokenStoreModeFile:
tokenStore = fileStore
case "keyring":
case tokenStoreModeKeyring:
tokenStore = credstore.NewTokenKeyringStore(defaultKeyringService)
case "auto":
case tokenStoreModeAuto:
kr := credstore.NewTokenKeyringStore(defaultKeyringService)
tokenStore = credstore.NewSecureStore[credstore.Token](kr, fileStore)
if !tokenStore.(*credstore.SecureStore[credstore.Token]).UseKeyring() {
secureStore := credstore.NewSecureStore(kr, fileStore)
tokenStore = secureStore
if !secureStore.UseKeyring() {
fmt.Fprintln(
os.Stderr,
"⚠️ OS keyring unavailable, falling back to file-based token storage",
Expand All @@ -224,8 +232,8 @@ func doInitConfig() {
default:
fmt.Fprintf(
os.Stderr,
"Error: Invalid token-store value: %s (must be auto, file, or keyring)\n",
tokenStoreMode,
"Error: Invalid token-store value: %s (must be %s, %s, or %s)\n",
tokenStoreMode, tokenStoreModeAuto, tokenStoreModeFile, tokenStoreModeKeyring,
)
os.Exit(1)
}
Expand Down Expand Up @@ -432,7 +440,6 @@ func run(ctx context.Context, d tui.Displayer) error {

// requestDeviceCode requests a device code from the OAuth server with retry logic
func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error) {
// Create request with timeout
reqCtx, cancel := context.WithTimeout(ctx, deviceCodeRequestTimeout)
defer cancel()

Expand All @@ -451,7 +458,6 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

// Execute request with retry logic
resp, err := retryClient.DoWithContext(reqCtx, req)
if err != nil {
return nil, fmt.Errorf("device code request failed: %w", err)
Expand All @@ -460,7 +466,7 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)

body, err := readResponseBody(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
return nil, err
}

if resp.StatusCode != http.StatusOK {
Expand All @@ -471,7 +477,6 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
)
}

// Parse response
var deviceResp struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
Expand Down Expand Up @@ -627,7 +632,6 @@ func exchangeDeviceCode(
ctx context.Context,
tokenURL, clientID, deviceCode string,
) (*oauth2.Token, error) {
// Create request with timeout
reqCtx, cancel := context.WithTimeout(ctx, tokenExchangeTimeout)
defer cancel()

Expand All @@ -649,24 +653,22 @@ func exchangeDeviceCode(

resp, err := retryClient.DoWithContext(reqCtx, req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
return nil, fmt.Errorf("token exchange request failed: %w", err)
}
defer resp.Body.Close()

body, err := readResponseBody(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
return nil, err
}

// Handle non-200 responses
if resp.StatusCode != http.StatusOK {
return nil, &oauth2.RetrieveError{
Response: resp,
Body: body,
}
}

// Parse successful token response
var tokenResp tokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
Expand All @@ -692,7 +694,6 @@ func exchangeDeviceCode(
}

func verifyToken(ctx context.Context, accessToken string, d tui.Displayer) error {
// Create request with timeout
reqCtx, cancel := context.WithTimeout(ctx, tokenVerificationTimeout)
defer cancel()

Expand All @@ -704,16 +705,15 @@ func verifyToken(ctx context.Context, accessToken string, d tui.Displayer) error
}
req.Header.Set("Authorization", "Bearer "+accessToken)

// Execute request with retry logic
resp, err := retryClient.DoWithContext(reqCtx, req)
if err != nil {
return fmt.Errorf("request failed: %w", err)
return fmt.Errorf("token verification request failed: %w", err)
}
defer resp.Body.Close()

body, err := readResponseBody(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
return err
}

if resp.StatusCode != http.StatusOK {
Expand All @@ -734,7 +734,6 @@ func refreshAccessToken(
refreshToken string,
d tui.Displayer,
) (credstore.Token, error) {
// Create request with timeout
reqCtx, cancel := context.WithTimeout(ctx, refreshTokenTimeout)
defer cancel()

Expand All @@ -754,7 +753,6 @@ func refreshAccessToken(
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

// Execute request with retry logic
resp, err := retryClient.DoWithContext(reqCtx, req)
if err != nil {
return credstore.Token{}, fmt.Errorf("refresh request failed: %w", err)
Expand All @@ -763,7 +761,7 @@ func refreshAccessToken(

body, err := readResponseBody(resp.Body)
if err != nil {
return credstore.Token{}, fmt.Errorf("failed to read response: %w", err)
return credstore.Token{}, err
}

if resp.StatusCode != http.StatusOK {
Expand All @@ -782,7 +780,6 @@ func refreshAccessToken(
)
}

// Parse token response
var tokenResp tokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return credstore.Token{}, fmt.Errorf("failed to parse token response: %w", err)
Expand Down Expand Up @@ -891,7 +888,7 @@ func makeAPICallWithAutoRefresh(

body, err := readResponseBody(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response: %w", err)
return err
}

if resp.StatusCode != http.StatusOK {
Expand Down
Loading