Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 182 additions & 5 deletions crates/buzz-agent/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ const BROWSER_AUTH_TIMEOUT: Duration = Duration::from_secs(60);
#[async_trait]
pub trait TokenSource: Send + Sync {
    /// Return the bearer token to attach to the next request.
    async fn bearer(&self) -> Result<String, AgentError>;

    /// Obtain a replacement bearer after the server answered 401 for `rejected`.
    ///
    /// [`bearer`](Self::bearer) decides freshness from the local expiry
    /// clock. This method exists for the case where that clock is wrong: the
    /// token still looked valid locally, yet the provider refused it (clock
    /// skew, server-side revocation, a node that never saw the token). The
    /// server's verdict, not the clock, is what drives the refresh.
    ///
    /// `rejected` is the exact access token that was refused. Passing it lets
    /// an implementation tell "this token is still the one I hold, refresh
    /// it" apart from "a concurrent caller already swapped it, reuse theirs".
    ///
    /// Implementations must not require any interactive step, so that a
    /// headless harness can never hang here.
    ///
    /// The default simply hands back whatever [`bearer`](Self::bearer)
    /// returns. That is the right behaviour for sources that cannot refresh
    /// (for example a static key): the caller's retry then fails terminally
    /// instead of looping.
    async fn refresh_now(&self, _rejected: &str) -> Result<String, AgentError> {
        self.bearer().await
    }
}

/// A token that never changes for the life of the process.
Expand Down Expand Up @@ -219,17 +237,28 @@ impl TokenSource for PkceOAuthTokenSource {
async fn bearer(&self) -> Result<String, AgentError> {
let mut state = self.state.lock().await;

// 1. Cache hit, still fresh.
// 1. In-memory cache hit, still fresh.
if let Some(tok) = state.as_ref() {
if !is_expired(tok) {
return Ok(tok.access_token.clone());
}
}

// 2. Cache hit, expired, but we have a refresh token.
// 2. Re-read disk — another process may have refreshed already.
if let Some(disk_tok) = read_cache(&self.cache_path) {
if !is_expired(&disk_tok) {
let bearer = disk_tok.access_token.clone();
*state = Some(disk_tok);
return Ok(bearer);
}
}

// 3. Try refresh if we have a refresh token. Discover endpoints once
// here — deliberately hoisted above the refresh-token check so the
// browser flow at step 5 (which also needs them) reuses this call.
let endpoints = self.endpoints().await?;
let refresh = state.as_ref().and_then(|t| t.refresh_token.clone());
if let Some(rt) = refresh {
let endpoints = self.endpoints().await?;
match self.refresh(&endpoints, &rt).await {
Ok(fresh) => {
let bearer = fresh.access_token.clone();
Expand All @@ -240,15 +269,79 @@ impl TokenSource for PkceOAuthTokenSource {
tracing::warn!(error = %e, "oauth refresh failed; falling back to browser flow");
}
}

// 4. Re-read disk after refresh failure — another process may have won the race.
if let Some(disk_tok) = read_cache(&self.cache_path) {
if !is_expired(&disk_tok) {
let bearer = disk_tok.access_token.clone();
*state = Some(disk_tok);
return Ok(bearer);
}
}
}

// 3. No usable cache: full browser dance.
let endpoints = self.endpoints().await?;
// 5. No usable cache: full browser dance.
let fresh = browser_pkce_flow(&self.http, &self.cfg, &endpoints).await?;
let bearer = fresh.access_token.clone();
self.save(&mut state, fresh)?;
Ok(bearer)
}

/// Replace a bearer the server just rejected with 401, without ever
/// starting the browser flow.
///
/// `rejected` is the access token that was refused. The decision to refresh
/// is made by comparing token *identity*, not by consulting the expiry
/// clock: a 401 on a locally-fresh token is exactly the case where the
/// clock would say "keep it".
///
/// The state lock is held for the whole check → refresh → save sequence, so
/// concurrent callers are serialized:
///
/// 1. If the in-memory token, or failing that the on-disk cache, already
///    differs from `rejected`, somebody else (this process or a sibling)
///    has refreshed; that token is returned and no grant is spent.
/// 2. Otherwise the refresh-token grant runs unconditionally, using the
///    refresh token held in the in-memory state.
///
/// # Errors
///
/// Returns `AgentError::LlmAuth` when no refresh token is available or when
/// the grant itself fails. Endpoint discovery and cache-save errors are
/// propagated unchanged. On failure the cached state is left untouched, so
/// the refresh token is never cleared.
async fn refresh_now(&self, rejected: &str) -> Result<String, AgentError> {
    let mut guard = self.state.lock().await;

    // Someone already replaced the in-memory token: reuse theirs.
    let superseded = guard
        .as_ref()
        .filter(|cached| cached.access_token != rejected)
        .map(|cached| cached.access_token.clone());
    if let Some(token) = superseded {
        return Ok(token);
    }

    // A sibling process may have written a newer token to disk; adopt it.
    let newer_on_disk =
        read_cache(&self.cache_path).filter(|cached| cached.access_token != rejected);
    if let Some(on_disk) = newer_on_disk {
        let token = on_disk.access_token.clone();
        *guard = Some(on_disk);
        return Ok(token);
    }

    // Still holding the rejected token: a grant is required regardless of
    // what the local expiry clock says.
    let Some(refresh_token) = guard
        .as_ref()
        .and_then(|cached| cached.refresh_token.clone())
    else {
        return Err(AgentError::LlmAuth(
            "token rejected and no refresh token available".into(),
        ));
    };

    let endpoints = self.endpoints().await?;

    // A failed grant is terminal: surfacing LlmAuth ends the caller's retry
    // loop instead of dropping into the browser flow, which would hang a
    // headless harness.
    let fresh = self
        .refresh(&endpoints, &refresh_token)
        .await
        .map_err(|e| AgentError::LlmAuth(format!("token refresh failed: {e}")))?;

    let token = fresh.access_token.clone();
    self.save(&mut guard, fresh)?;
    Ok(token)
}
}

// ---- helpers -------------------------------------------------------------
Expand Down Expand Up @@ -549,4 +642,88 @@ mod tests {
let v: Value = serde_json::from_str(r#"{"access_token":""}"#).unwrap();
assert!(token_from_response(&v, None).is_err());
}

/// `bearer()` must adopt a still-valid token found in the on-disk cache when
/// the in-memory token has expired, without needing any network call.
#[tokio::test]
async fn test_bearer_reuses_disk_token_after_expiry() {
    let cache_dir = tempfile::tempdir().unwrap();
    let source = PkceOAuthTokenSource::new(PkceOAuthConfig {
        discovery_url: "https://example.com/.well-known".into(),
        client_id: "test-client".into(),
        scopes: vec!["offline_access".into()],
        cache_namespace: "test".into(),
        cache_dir_override: Some(cache_dir.path().to_path_buf()),
    })
    .unwrap();

    // Seed the in-memory state with a long-expired token that cannot be
    // refreshed. The guard is a temporary and is released at the end of the
    // statement.
    *source.state.lock().await = Some(CachedToken {
        access_token: "stale".into(),
        refresh_token: None,
        expires_at: Some(0),
    });

    // Simulate another process having refreshed: a token valid for two more
    // hours sits in the cache file.
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let on_disk = CachedToken {
        access_token: "fresh-from-disk".into(),
        refresh_token: Some("rt".into()),
        expires_at: Some(now_secs + 7200),
    };
    let serialized = serde_json::to_vec_pretty(&on_disk).unwrap();
    fs::write(&source.cache_path, &serialized).unwrap();

    // The disk token is picked up as-is.
    let bearer = source.bearer().await.unwrap();
    assert_eq!(bearer, "fresh-from-disk");
}

/// When both the in-memory token and the on-disk token are expired,
/// `bearer()` must not short-circuit on the disk copy; it has to continue on
/// to endpoint discovery.
#[tokio::test]
async fn test_bearer_falls_through_to_browser_when_disk_also_expired() {
    let cache_dir = tempfile::tempdir().unwrap();
    let source = PkceOAuthTokenSource::new(PkceOAuthConfig {
        discovery_url: "https://example.com/.well-known".into(),
        client_id: "test-client".into(),
        scopes: vec!["offline_access".into()],
        cache_namespace: "test".into(),
        cache_dir_override: Some(cache_dir.path().to_path_buf()),
    })
    .unwrap();

    // Long-expired, non-refreshable token in memory. The guard is a
    // temporary and is released at the end of the statement.
    *source.state.lock().await = Some(CachedToken {
        access_token: "stale".into(),
        refresh_token: None,
        expires_at: Some(0),
    });

    // An equally expired token in the cache file.
    let on_disk = CachedToken {
        access_token: "also-stale".into(),
        refresh_token: None,
        expires_at: Some(0),
    };
    let serialized = serde_json::to_vec_pretty(&on_disk).unwrap();
    fs::write(&source.cache_path, &serialized).unwrap();

    // With no server to answer, the call is expected to fail at discovery.
    // Seeing that particular error shows the expired disk token was not
    // returned early.
    let outcome = source.bearer().await;
    assert!(outcome.is_err());
    let err_msg = outcome.unwrap_err().to_string();
    assert!(
        err_msg.contains("oauth discovery"),
        "expected discovery error, got: {err_msg}"
    );
}
}
Loading