Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 82 additions & 8 deletions internal/oauth/refresh_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ const (
// 0.75 means refresh at 75% of lifetime for long-lived tokens.
DefaultRefreshThreshold = 0.75

// DefaultMaxRetries is the maximum number of refresh attempts before giving up.
// Set to 0 for unlimited retries until token expiration (FR-009).
DefaultMaxRetries = 0
// DefaultMaxRetries is the maximum number of consecutive refresh attempts before giving up.
// Acts as a circuit breaker to prevent infinite retry loops (issue #310).
// With exponential backoff (10s base, 5min cap), 50 retries span roughly 3.8 hours
// (10s+20s+40s+80s+160s of ramp-up, then the remaining attempts at the 300s cap).
DefaultMaxRetries = 50

// MinRefreshInterval prevents too-frequent refresh attempts.
MinRefreshInterval = 5 * time.Second
Expand Down Expand Up @@ -672,7 +673,8 @@ func (m *RefreshManager) handleRefreshSuccess(serverName string) {
}

// handleRefreshFailure handles a failed token refresh with exponential backoff retry.
// Per FR-009: Retries continue until token expiration (unlimited retries), not a fixed count.
// Terminal errors (invalid_grant, server not found) stop immediately.
// Transient errors retry with exponential backoff until the retry count reaches
// maxRetries; a maxRetries of 0 disables that limit.
func (m *RefreshManager) handleRefreshFailure(serverName string, err error) {
m.mu.Lock()
schedule := m.schedules[serverName]
Expand Down Expand Up @@ -715,8 +717,48 @@ func (m *RefreshManager) handleRefreshFailure(serverName string, err error) {
return
}

// Check if we should continue retrying (unlimited retries until token expiration per FR-009)
// Only stop if the access token has completely expired AND no more time remains
// Check if the server is gone (removed from config or not OAuth).
// No amount of retrying will fix this - stop immediately.
if errorType == "failed_server_gone" {
m.logger.Error("OAuth refresh stopped - server no longer available",
zap.String("server", serverName),
zap.Error(err))

m.mu.Lock()
if schedule := m.schedules[serverName]; schedule != nil {
schedule.RefreshState = RefreshStateFailed
schedule.LastError = err.Error()
}
m.mu.Unlock()

if m.eventEmitter != nil {
m.eventEmitter.EmitOAuthRefreshFailed(serverName, err.Error())
}
return
}

// Check if max retries exceeded (circuit breaker to prevent infinite loops).
if m.maxRetries > 0 && retryCount >= m.maxRetries {
m.logger.Error("OAuth token refresh failed - max retries exceeded",
zap.String("server", serverName),
zap.Int("max_retries", m.maxRetries),
zap.Int("retry_count", retryCount))

m.mu.Lock()
if schedule := m.schedules[serverName]; schedule != nil {
schedule.RefreshState = RefreshStateFailed
schedule.LastError = "Max retries exceeded - re-authentication required"
}
m.mu.Unlock()

if m.eventEmitter != nil {
m.eventEmitter.EmitOAuthRefreshFailed(serverName, err.Error())
}
return
}

// Check if we should continue retrying based on token expiration.
// Only stop if the access token has completely expired AND no more time remains.
now := time.Now()
if !expiresAt.IsZero() && now.After(expiresAt) {
// Token has already expired - check if we should give up
Expand Down Expand Up @@ -748,14 +790,26 @@ func (m *RefreshManager) handleRefreshFailure(serverName string, err error) {
}

// classifyRefreshError categorizes a refresh error for metrics and error handling.
// Returns one of: "failed_network", "failed_invalid_grant", "failed_other".
// Returns "success" for a nil error, otherwise one of:
// "failed_network", "failed_invalid_grant", "failed_server_gone", "failed_other".
func classifyRefreshError(err error) string {
if err == nil {
return "success"
}

errStr := err.Error()

// Check for terminal server-gone errors (server removed from config or not OAuth).
// These should never be retried because the server no longer exists or doesn't use OAuth.
serverGoneErrors := []string{
"server not found",
"server does not use OAuth",
}
for _, pattern := range serverGoneErrors {
if stringutil.ContainsIgnoreCase(errStr, pattern) {
return "failed_server_gone"
}
}

// Check for permanent OAuth errors (refresh token invalid/expired)
permanentErrors := []string{
"invalid_grant",
Expand Down Expand Up @@ -789,15 +843,29 @@ func classifyRefreshError(err error) string {
return "failed_other"
}

// maxBackoffExponent caps the shift exponent used by calculateBackoff so that
// RetryBackoffBase (10s = 10_000_000_000 ns) * 2^exponent cannot overflow.
// time.Duration is an int64 on every platform, and 2^30 * 1e10 ns already exceeds
// math.MaxInt64 (the largest safe exponent is 29). 25 is a conservative cap:
// 10s * 2^25 = 335,544,320s, well above MaxRetryBackoff (300s), so any result near
// the cap is clamped to MaxRetryBackoff anyway.
const maxBackoffExponent = 25

// calculateBackoff calculates the exponential backoff duration for a given retry count.
// The formula is: base * 2^retryCount, capped at MaxRetryBackoff (5 minutes).
// Sequence: 10s → 20s → 40s → 80s → 160s → 300s (cap).
//
// The exponent is capped at maxBackoffExponent to prevent integer overflow
// that caused issue #310 (0s delay at high retry counts leading to infinite loops).
func (m *RefreshManager) calculateBackoff(retryCount int) time.Duration {
if retryCount < 0 {
retryCount = 0
}
// Cap the exponent to prevent integer overflow.
// Beyond this exponent, the result would exceed MaxRetryBackoff anyway.
if retryCount > maxBackoffExponent {
return MaxRetryBackoff
}
backoff := RetryBackoffBase * time.Duration(1<<uint(retryCount))
if backoff > MaxRetryBackoff {
if backoff <= 0 || backoff > MaxRetryBackoff {
backoff = MaxRetryBackoff
}
return backoff
Expand All @@ -814,7 +882,13 @@ func (m *RefreshManager) isRateLimited(schedule *RefreshSchedule) bool {
}

// rescheduleAfterDelay reschedules a refresh attempt after a delay.
// The delay is enforced to be at least MinRefreshInterval to prevent tight loops.
func (m *RefreshManager) rescheduleAfterDelay(serverName string, delay time.Duration) {
// Enforce minimum delay to prevent tight retry loops (defense-in-depth for issue #310).
if delay < MinRefreshInterval {
delay = MinRefreshInterval
}

m.mu.Lock()
defer m.mu.Unlock()

Expand Down
167 changes: 167 additions & 0 deletions internal/oauth/refresh_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,10 @@ func TestClassifyRefreshError(t *testing.T) {
{"context deadline exceeded", errors.New("context deadline exceeded"), "failed_network"},
{"generic error", errors.New("unknown server error"), "failed_other"},
{"server_error", errors.New("OAuth server_error"), "failed_other"},
// Terminal errors: server gone
{"server not found", errors.New("server not found: gcw2"), "failed_server_gone"},
{"server not found wrapped", errors.New("failed to refresh OAuth token: server not found: myserver"), "failed_server_gone"},
{"server does not use OAuth", errors.New("server does not use OAuth: gcw2"), "failed_server_gone"},
}

for _, tt := range tests {
Expand All @@ -656,3 +660,166 @@ func TestClassifyRefreshError(t *testing.T) {
})
}
}

// Test that calculateBackoff never returns zero or negative for any retry count (overflow protection).
func TestRefreshManager_BackoffOverflowProtection(t *testing.T) {
logger := zaptest.NewLogger(t)
manager := NewRefreshManager(nil, nil, nil, logger)

// Specific boundary cases that triggered the original bug
boundaryTests := []struct {
retryCount int
desc string
}{
{30, "overflow boundary on 64-bit (10s * 2^30 overflows nanoseconds)"},
{63, "1<<63 is min int64, duration becomes 0 or negative"},
{64, "1<<64 wraps to 0 on 64-bit"},
{100, "well beyond overflow"},
{1000, "extreme retry count"},
{23158728, "actual retry count from issue #310"},
}

for _, tt := range boundaryTests {
t.Run(tt.desc, func(t *testing.T) {
backoff := manager.calculateBackoff(tt.retryCount)
assert.Greater(t, backoff, time.Duration(0),
"Backoff must be positive at retryCount=%d", tt.retryCount)
assert.LessOrEqual(t, backoff, MaxRetryBackoff,
"Backoff must not exceed MaxRetryBackoff at retryCount=%d", tt.retryCount)
})
}

// Exhaustive check: no retry count from 0 to 10000 produces zero or negative backoff
for i := 0; i <= 10000; i++ {
backoff := manager.calculateBackoff(i)
if backoff <= 0 || backoff > MaxRetryBackoff {
t.Fatalf("calculateBackoff(%d) = %v, want positive and <= %v", i, backoff, MaxRetryBackoff)
}
}
}

// TestRefreshManager_ServerNotFoundIsTerminal checks that a refresh failing with a
// "server not found" error is treated as terminal: after one executeRefresh call the
// schedule is in RefreshStateFailed with RetryCount 1, and exactly one failure event
// has been recorded by the emitter.
func TestRefreshManager_ServerNotFoundIsTerminal(t *testing.T) {
logger := zaptest.NewLogger(t)
store := newMockTokenStore()
emitter := &mockEventEmitter{}

manager := NewRefreshManager(store, nil, nil, logger)
// Same wrapped message as the "server not found wrapped" case in
// TestClassifyRefreshError, which expects the "failed_server_gone" classification.
runtime := &mockRuntime{
refreshErr: errors.New("failed to refresh OAuth token: server not found: gcw2"),
}
manager.SetRuntime(runtime)
manager.SetEventEmitter(emitter)

ctx := context.Background()
err := manager.Start(ctx)
require.NoError(t, err)
defer manager.Stop()

// Register a token expiring in one hour so a schedule exists for "gcw2".
// NOTE(review): the sleep presumably lets the manager settle after OnTokenSaved — confirm.
expiresAt := time.Now().Add(1 * time.Hour)
manager.OnTokenSaved("gcw2", expiresAt)
time.Sleep(50 * time.Millisecond)

// Execute refresh which will fail with "server not found"
manager.executeRefresh("gcw2")
time.Sleep(50 * time.Millisecond)

// Should be in failed state (not retrying)
schedule := manager.GetSchedule("gcw2")
require.NotNil(t, schedule)
assert.Equal(t, RefreshStateFailed, schedule.RefreshState,
"Server not found should immediately fail, not retry")
assert.Equal(t, 1, schedule.RetryCount, "Should have exactly 1 attempt")

// Exactly one failure event is expected for the single terminal failure.
assert.Equal(t, 1, emitter.GetFailedEvents(),
"Should emit failure event for terminal server-not-found error")
}

// TestRefreshManager_MaxRetryLimit checks that repeated refresh failures beyond the
// configured MaxRetries leave the schedule in RefreshStateFailed and produce at least
// one failure event.
func TestRefreshManager_MaxRetryLimit(t *testing.T) {
logger := zaptest.NewLogger(t)
store := newMockTokenStore()
emitter := &mockEventEmitter{}

config := &RefreshManagerConfig{
Threshold: 0.1,
MaxRetries: 5, // Low limit for testing
}

manager := NewRefreshManager(store, nil, config, logger)
// TestClassifyRefreshError shows only "failed_other" for generic messages; this
// message is presumably classified the same way and treated as retryable — confirm.
runtime := &mockRuntime{
refreshErr: errors.New("some transient error"),
}
manager.SetRuntime(runtime)
manager.SetEventEmitter(emitter)

ctx := context.Background()
err := manager.Start(ctx)
require.NoError(t, err)
defer manager.Stop()

// Create a schedule
expiresAt := time.Now().Add(1 * time.Hour)
manager.OnTokenSaved("test-server", expiresAt)
time.Sleep(50 * time.Millisecond)

// Drive 7 failing attempts, two more than MaxRetries, so the max-retries check is hit.
// NOTE(review): assumes each failed executeRefresh increments the retry count — confirm.
for i := 0; i < 7; i++ {
manager.executeRefresh("test-server")
time.Sleep(10 * time.Millisecond)
}

// After exceeding max retries, should be in failed state
schedule := manager.GetSchedule("test-server")
require.NotNil(t, schedule)
assert.Equal(t, RefreshStateFailed, schedule.RefreshState,
"Should transition to failed after max retries exceeded")

// At least one failure event; the exact count is deliberately not pinned.
assert.GreaterOrEqual(t, emitter.GetFailedEvents(), 1,
"Should emit failure event when max retries exceeded")
}

// TestDefaultMaxRetries_IsNonZero guards the default retry cap: DefaultMaxRetries must
// be strictly positive so that the default configuration cannot retry forever.
func TestDefaultMaxRetries_IsNonZero(t *testing.T) {
	const lowerBound = 0
	assert.Greater(t, DefaultMaxRetries, lowerBound,
		"DefaultMaxRetries must be non-zero to prevent infinite retry loops")
}

// TestRefreshManager_MinimumDelayEnforced checks that rescheduleAfterDelay raises a zero
// or negative delay to at least MinRefreshInterval, observed through the schedule's
// RetryBackoff field.
func TestRefreshManager_MinimumDelayEnforced(t *testing.T) {
logger := zaptest.NewLogger(t)
store := newMockTokenStore()

manager := NewRefreshManager(store, nil, nil, logger)
runtime := &mockRuntime{}
manager.SetRuntime(runtime)

ctx := context.Background()
err := manager.Start(ctx)
require.NoError(t, err)
defer manager.Stop()

// Create a schedule first
expiresAt := time.Now().Add(1 * time.Hour)
manager.OnTokenSaved("test-server", expiresAt)
time.Sleep(50 * time.Millisecond)

// Call rescheduleAfterDelay with zero delay
manager.rescheduleAfterDelay("test-server", 0)

// NOTE(review): presumably RetryBackoff is where rescheduleAfterDelay records the
// delay it applied — confirm against the function body.
schedule := manager.GetSchedule("test-server")
require.NotNil(t, schedule)
assert.GreaterOrEqual(t, schedule.RetryBackoff, MinRefreshInterval,
"Delay should be at least MinRefreshInterval even when 0 is passed")

// Call rescheduleAfterDelay with negative delay
manager.rescheduleAfterDelay("test-server", -5*time.Second)

schedule = manager.GetSchedule("test-server")
require.NotNil(t, schedule)
assert.GreaterOrEqual(t, schedule.RetryBackoff, MinRefreshInterval,
"Delay should be at least MinRefreshInterval even when negative is passed")
}