diff --git a/crates/buzz-agent/src/auth.rs b/crates/buzz-agent/src/auth.rs index 53b753672..5b439aa78 100644 --- a/crates/buzz-agent/src/auth.rs +++ b/crates/buzz-agent/src/auth.rs @@ -43,6 +43,24 @@ const BROWSER_AUTH_TIMEOUT: Duration = Duration::from_secs(60); #[async_trait] pub trait TokenSource: Send + Sync { async fn bearer(&self) -> Result; + + /// Force a fresh bearer after the server rejected the current one (401). + /// + /// `rejected` is the exact access token that just got the 401. Unlike + /// [`bearer`](Self::bearer), which trusts the local expiry clock, this is + /// driven by the server's verdict: the cached token looked valid to us + /// (well within its local expiry) but the provider rejected it — clock + /// skew, server-side revocation, or a node that never saw it. The clock + /// therefore can't decide whether to refresh; the caller passes the + /// rejected token so the impl can refresh unless a concurrent caller has + /// *already* replaced it. Implementations must obtain a new token without + /// any interactive step, so a headless harness never hangs. The default + /// returns the existing bearer — correct for sources that can't refresh + /// (a static key); the caller's retry then fails terminally rather than + /// looping. + async fn refresh_now(&self, _rejected: &str) -> Result { + self.bearer().await + } } /// A token that never changes for the life of the process. @@ -219,17 +237,28 @@ impl TokenSource for PkceOAuthTokenSource { async fn bearer(&self) -> Result { let mut state = self.state.lock().await; - // 1. Cache hit, still fresh. + // 1. In-memory cache hit, still fresh. if let Some(tok) = state.as_ref() { if !is_expired(tok) { return Ok(tok.access_token.clone()); } } - // 2. Cache hit, expired, but we have a refresh token. + // 2. Re-read disk — another process may have refreshed already. 
+ if let Some(disk_tok) = read_cache(&self.cache_path) { + if !is_expired(&disk_tok) { + let bearer = disk_tok.access_token.clone(); + *state = Some(disk_tok); + return Ok(bearer); + } + } + + // 3. Try refresh if we have a refresh token. Discover endpoints once + // here — deliberately hoisted above the refresh-token check so the + // browser flow at step 5 (which also needs them) reuses this call. + let endpoints = self.endpoints().await?; let refresh = state.as_ref().and_then(|t| t.refresh_token.clone()); if let Some(rt) = refresh { - let endpoints = self.endpoints().await?; match self.refresh(&endpoints, &rt).await { Ok(fresh) => { let bearer = fresh.access_token.clone(); @@ -240,15 +269,79 @@ impl TokenSource for PkceOAuthTokenSource { tracing::warn!(error = %e, "oauth refresh failed; falling back to browser flow"); } } + + // 4. Re-read disk after refresh failure — another process may have won the race. + if let Some(disk_tok) = read_cache(&self.cache_path) { + if !is_expired(&disk_tok) { + let bearer = disk_tok.access_token.clone(); + *state = Some(disk_tok); + return Ok(bearer); + } + } } - // 3. No usable cache: full browser dance. - let endpoints = self.endpoints().await?; + // 5. No usable cache: full browser dance. let fresh = browser_pkce_flow(&self.http, &self.cfg, &endpoints).await?; let bearer = fresh.access_token.clone(); self.save(&mut state, fresh)?; Ok(bearer) } + + /// Force-refresh after a 401, never touching the browser flow. + /// + /// `rejected` is the access token the server just 401'd. Coalescing keys + /// off token *identity*, not the expiry clock: a 401 means the token was + /// rejected while it still looked locally fresh, so `is_expired()` would + /// say "keep it" and no grant would ever run. Instead, under the lock we + /// compare the current cached token to `rejected` — if they differ, a + /// concurrent caller (this process or a sibling) already refreshed, so we + /// return the new token without burning a second grant. 
If they still + /// match, this is the rejected token and we run the refresh-token grant + /// unconditionally. The whole check→refresh→save runs under one lock hold + /// so concurrent callers serialize. On any failure the refresh token is + /// preserved (never nulled) and the error is terminal `LlmAuth` — no + /// browser, no hang. + async fn refresh_now(&self, rejected: &str) -> Result { + let mut state = self.state.lock().await; + + // 1. Coalesce by identity: if the cached token (in-memory, then disk) + // is no longer the one the server rejected, someone already + // refreshed it. Return that instead of grabbing another grant. + if let Some(tok) = state.as_ref() { + if tok.access_token != rejected { + return Ok(tok.access_token.clone()); + } + } + if let Some(disk_tok) = read_cache(&self.cache_path) { + if disk_tok.access_token != rejected { + let bearer = disk_tok.access_token.clone(); + *state = Some(disk_tok); + return Ok(bearer); + } + } + + // 2. The cached token is still the rejected one. Run the refresh-token + // grant unconditionally — the expiry clock can't be trusted here, a + // locally-fresh token is exactly what got 401'd. + let refresh = state.as_ref().and_then(|t| t.refresh_token.clone()); + let Some(rt) = refresh else { + return Err(AgentError::LlmAuth( + "token rejected and no refresh token available".into(), + )); + }; + let endpoints = self.endpoints().await?; + match self.refresh(&endpoints, &rt).await { + Ok(fresh) => { + let bearer = fresh.access_token.clone(); + self.save(&mut state, fresh)?; + Ok(bearer) + } + // 3. Refresh token is itself dead. Terminal — surfacing LlmAuth + // stops the retry loop instead of falling to the browser flow, + // which would hang a headless harness. 
+ Err(e) => Err(AgentError::LlmAuth(format!("token refresh failed: {e}"))), + } + } } // ---- helpers ------------------------------------------------------------- @@ -549,4 +642,88 @@ mod tests { let v: Value = serde_json::from_str(r#"{"access_token":""}"#).unwrap(); assert!(token_from_response(&v, None).is_err()); } + + #[tokio::test] + async fn test_bearer_reuses_disk_token_after_expiry() { + let dir = tempfile::tempdir().unwrap(); + let cfg = PkceOAuthConfig { + discovery_url: "https://example.com/.well-known".into(), + client_id: "test-client".into(), + scopes: vec!["offline_access".into()], + cache_namespace: "test".into(), + cache_dir_override: Some(dir.path().to_path_buf()), + }; + let source = PkceOAuthTokenSource::new(cfg).unwrap(); + + // Expire the in-memory state. + { + let mut state = source.state.lock().await; + *state = Some(CachedToken { + access_token: "stale".into(), + refresh_token: None, + expires_at: Some(0), // long expired + }); + } + + // Write a valid token to disk (simulating another process refreshing). + let future_exp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 7200; + let fresh_token = CachedToken { + access_token: "fresh-from-disk".into(), + refresh_token: Some("rt".into()), + expires_at: Some(future_exp), + }; + let body = serde_json::to_vec_pretty(&fresh_token).unwrap(); + fs::write(&source.cache_path, &body).unwrap(); + + // bearer() should pick up the disk token without any network call. 
+ let result = source.bearer().await.unwrap(); + assert_eq!(result, "fresh-from-disk"); + } + + #[tokio::test] + async fn test_bearer_falls_through_to_browser_when_disk_also_expired() { + let dir = tempfile::tempdir().unwrap(); + let cfg = PkceOAuthConfig { + discovery_url: "https://example.com/.well-known".into(), + client_id: "test-client".into(), + scopes: vec!["offline_access".into()], + cache_namespace: "test".into(), + cache_dir_override: Some(dir.path().to_path_buf()), + }; + let source = PkceOAuthTokenSource::new(cfg).unwrap(); + + // Expire the in-memory state. + { + let mut state = source.state.lock().await; + *state = Some(CachedToken { + access_token: "stale".into(), + refresh_token: None, + expires_at: Some(0), + }); + } + + // Write an expired token to disk too. + let expired_token = CachedToken { + access_token: "also-stale".into(), + refresh_token: None, + expires_at: Some(0), + }; + let body = serde_json::to_vec_pretty(&expired_token).unwrap(); + fs::write(&source.cache_path, &body).unwrap(); + + // bearer() should fall through past the disk check. + // It will fail at the endpoints() discovery call since there's no server, + // which proves it didn't short-circuit on the expired disk token. 
+ let result = source.bearer().await; + assert!(result.is_err()); + let err_msg = format!("{}", result.unwrap_err()); + assert!( + err_msg.contains("oauth discovery"), + "expected discovery error, got: {err_msg}" + ); + } } diff --git a/crates/buzz-agent/src/llm.rs b/crates/buzz-agent/src/llm.rs index e1d529511..34427414b 100644 --- a/crates/buzz-agent/src/llm.rs +++ b/crates/buzz-agent/src/llm.rs @@ -187,7 +187,6 @@ impl Llm { path: &str, body: &Value, ) -> Result { - let bearer = self.auth.bearer().await?; let (url, body_owned); let body_ref: &Value = match cfg.provider { Provider::Databricks => { @@ -204,7 +203,25 @@ impl Llm { body } }; - post(&self.http, &url, body_ref, |r| r.bearer_auth(&bearer)).await + + // A 401 or 403 can mean the local expiry clock disagreed with the + // server (skew, revocation, a node that never saw the token). On the + // first such rejection, force a refresh keyed off the rejected bearer + // and retry once. The guard is local to this call so an earlier turn's + // rejection can never suppress a later turn's legitimate retry. Both + // statuses map to `LlmAuth` in `post`: a pure-authorization 403 is indistinguishable + // from an expired-token 403 here, so we refresh once and let it propagate. + let mut bearer = self.auth.bearer().await?; + let mut refreshed = false; + loop { + match post(&self.http, &url, body_ref, |r| r.bearer_auth(&bearer)).await { + Err(AgentError::LlmAuth(_)) if !refreshed => { + refreshed = true; + bearer = self.auth.refresh_now(&bearer).await?; + } + result => return result, + } + } } /// If `err` names `/v1/responses` / "use the Responses API", latch a @@ -803,6 +820,11 @@ where } }; let status = resp.status(); + // Both 401 and 403 are treated as refreshable: a 403 can mean an + // expired or revoked token, not just a pure authorization verdict, and + // the two are indistinguishable at the HTTP-status layer.
The caller's + // retry loop keys off `LlmAuth` and refreshes once; the per-call guard + // bounds a pure-authz 403 to one wasted refresh before it propagates. if status == 401 || status == 403 { return Err(AgentError::LlmAuth(read_error_body(resp).await)); } @@ -1461,4 +1483,211 @@ mod tests { let v = serde_json::json!({"usage": {"output_tokens": 5}}); assert_eq!(sum_usage(&v, &["input_tokens", "prompt_tokens"]), None); } + + // ── one-shot refresh-on-401 in post_openai ───────────────────────── + + /// A token source whose `bearer()` always hands back the same stale + /// token and whose `refresh_now()` mints a distinct fresh one, counting + /// each refresh. Lets a test assert exactly how many forced refreshes a + /// `post_openai` call provoked. + struct CountingAuth { + refreshes: std::sync::atomic::AtomicU32, + } + + #[async_trait::async_trait] + impl TokenSource for CountingAuth { + async fn bearer(&self) -> Result { + Ok("stale".into()) + } + async fn refresh_now(&self, _rejected: &str) -> Result { + self.refreshes + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Ok("fresh".into()) + } + } + + /// Stub that answers `reject_status` to any request carrying `Bearer + /// stale` and 200 to any other request. When `always_reject` is set it rejects + /// unconditionally, simulating a token the refresh can never satisfy. + /// It keeps no request count; tests assert "one retry, not a loop" via `CountingAuth`.
+ async fn spawn_auth_stub( + always_reject: std::sync::Arc, + reject_status: u16, + ) -> String { + use std::sync::atomic::Ordering; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let reject_line = match reject_status { + 401 => "401 Unauthorized", + 403 => "403 Forbidden", + other => panic!("unsupported reject_status {other}"), + }; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let url = format!("http://{}", listener.local_addr().unwrap()); + tokio::spawn(async move { + loop { + let (mut sock, _) = match listener.accept().await { + Ok(p) => p, + Err(_) => return, + }; + let always_reject = always_reject.clone(); + tokio::spawn(async move { + let mut buf = Vec::new(); + let mut tmp = [0u8; 4096]; + while !buf.windows(4).any(|w| w == b"\r\n\r\n") { + match sock.read(&mut tmp).await { + Ok(0) | Err(_) => return, + Ok(k) => buf.extend_from_slice(&tmp[..k]), + } + } + let head = String::from_utf8_lossy(&buf).to_ascii_lowercase(); + let stale = head.contains("authorization: bearer stale"); + let resp = if always_reject.load(Ordering::SeqCst) || stale { + format!( + "HTTP/1.1 {reject_line}\r\nContent-Length: 11\r\n\ + Connection: close\r\n\r\ntoken stale" + ) + } else { + let body = "{\"ok\":true}"; + format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\ + Content-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body, + ) + }; + let _ = sock.write_all(resp.as_bytes()).await; + let _ = sock.shutdown().await; + }); + } + }); + url + } + + fn llm_with(auth: Arc) -> Llm { + Llm { + http: Client::builder() + .timeout(Duration::from_secs(5)) + .build() + .unwrap(), + auto_upgraded: std::sync::atomic::AtomicBool::new(false), + auth, + } + } + + /// A single 401 forces exactly one refresh, the retry with the fresh + /// token succeeds, and a *later* call gets its own refresh — proving the + /// one-shot guard is per-call, not stored on the source. 
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn post_openai_refreshes_once_per_call_on_401() { + use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; + + let always_401 = Arc::new(AtomicBool::new(false)); + let base = spawn_auth_stub(always_401, 401).await; + let auth = Arc::new(CountingAuth { + refreshes: AtomicU32::new(0), + }); + let llm = llm_with(auth.clone()); + let mut c = cfg(Provider::OpenAi); + c.base_url = base; + + let out = llm + .post_openai(&c, "/v1/x", &json!({})) + .await + .expect("retry with fresh token should succeed"); + assert_eq!(out, json!({ "ok": true })); + assert_eq!(auth.refreshes.load(Ordering::SeqCst), 1, "one refresh"); + + // Second call's 401 must trigger its own refresh — the guard cannot + // be a stored flag that an earlier turn already tripped. + let out2 = llm.post_openai(&c, "/v1/x", &json!({})).await.unwrap(); + assert_eq!(out2, json!({ "ok": true })); + assert_eq!( + auth.refreshes.load(Ordering::SeqCst), + 2, + "later call gets its own retry" + ); + } + + /// A persistent 401 (even the refreshed token is rejected) propagates as + /// `LlmAuth` after exactly one refresh — no infinite loop. 
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn post_openai_persistent_401_propagates_after_one_retry() { + use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; + + let always_401 = Arc::new(AtomicBool::new(true)); + let base = spawn_auth_stub(always_401, 401).await; + let auth = Arc::new(CountingAuth { + refreshes: AtomicU32::new(0), + }); + let llm = llm_with(auth.clone()); + let mut c = cfg(Provider::OpenAi); + c.base_url = base; + + let err = llm.post_openai(&c, "/v1/x", &json!({})).await.unwrap_err(); + assert!(matches!(err, AgentError::LlmAuth(_)), "got {err:?}"); + assert_eq!( + auth.refreshes.load(Ordering::SeqCst), + 1, + "exactly one refresh, then propagate" + ); + } + + /// A 403 is treated as refreshable: a persistent 403 forces exactly one + /// refresh-and-retry, then propagates as `LlmAuth`. Proves a revoked-token + /// 403 takes the same recovery path as a 401, bounded by the per-call guard. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn post_openai_persistent_403_propagates_after_one_retry() { + use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; + + let always_403 = Arc::new(AtomicBool::new(true)); + let base = spawn_auth_stub(always_403, 403).await; + let auth = Arc::new(CountingAuth { + refreshes: AtomicU32::new(0), + }); + let llm = llm_with(auth.clone()); + let mut c = cfg(Provider::OpenAi); + c.base_url = base; + + let err = llm.post_openai(&c, "/v1/x", &json!({})).await.unwrap_err(); + assert!(matches!(err, AgentError::LlmAuth(_)), "got {err:?}"); + assert_eq!( + auth.refreshes.load(Ordering::SeqCst), + 1, + "403 refreshes exactly once, then propagates" + ); + } + + /// A recoverable 403 (stale token 403s, fresh token 200s) forces exactly + /// one refresh and the retry succeeds — proving a 403 enters the refresh + /// path and a refreshed token clears it, the stale-token-403 recovery case. 
+ #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn post_openai_refreshes_once_on_403() { + use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; + + let always_403 = Arc::new(AtomicBool::new(false)); + let base = spawn_auth_stub(always_403, 403).await; + let auth = Arc::new(CountingAuth { + refreshes: AtomicU32::new(0), + }); + let llm = llm_with(auth.clone()); + let mut c = cfg(Provider::OpenAi); + c.base_url = base; + + let out = llm + .post_openai(&c, "/v1/x", &json!({})) + .await + .expect("retry with fresh token should clear the 403"); + assert_eq!(out, json!({ "ok": true })); + assert_eq!(auth.refreshes.load(Ordering::SeqCst), 1, "one refresh"); + } + + /// The default `refresh_now()` on a static source returns the static + /// token unchanged — a key that can't refresh still answers harmlessly. + #[tokio::test] + async fn static_token_source_refresh_now_returns_static_token() { + let src = StaticTokenSource::new("static-key"); + assert_eq!(src.refresh_now("rejected").await.unwrap(), "static-key"); + } } diff --git a/crates/buzz-agent/tests/databricks_oauth.rs b/crates/buzz-agent/tests/databricks_oauth.rs index 25f1e120b..1ee8039a5 100644 --- a/crates/buzz-agent/tests/databricks_oauth.rs +++ b/crates/buzz-agent/tests/databricks_oauth.rs @@ -208,6 +208,143 @@ async fn refreshed_token_is_persisted_to_disk() { assert!(on_disk["expires_at"].is_u64()); } +#[tokio::test] +async fn refresh_now_runs_grant_on_unexpired_rejected_token() { + let tmp = TempDir::new().unwrap(); + + let (base, refresh_counter) = spawn_oidc().await; + let cfg = PkceOAuthConfig { + discovery_url: format!("{base}/.well-known/oauth-authorization-server"), + client_id: "test-client".into(), + scopes: vec!["a".into()], + cache_namespace: "databricks".into(), + cache_dir_override: Some(tmp.path().to_path_buf()), + }; + + // The exact 401 case this whole change exists to fix: a token that is + // still locally *unexpired* but the server rejected it (skew, revocation, + // 
a node that never saw it). is_expired() says "keep it", so a clock-based + // gate would no-op and the agent would die. refresh_now() must instead key + // off identity — the cached token equals the rejected one — and run the + // grant anyway. The stub never serves a browser flow, so a fresh token + // here proves the refresh-token grant ran, not the interactive path. + let future = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 3600; + let path = cache_path_for(tmp.path(), &cfg); + seed_cache( + &path, + json!({ + "access_token": "rejected", + "refresh_token": "valid-refresh", + "expires_at": future, + }), + ); + + let src = PkceOAuthTokenSource::new(cfg).unwrap(); + let bearer = src.refresh_now("rejected").await.unwrap(); + assert_eq!(bearer, "fresh-token-1", "grant ran despite local freshness"); + assert_eq!(refresh_counter.load(Ordering::SeqCst), 1, "grant ran once"); + + // The refresh token was preserved (rotated, not discarded): the saved token + // still carries one, so a future 401 can refresh again instead of falling to + // the browser flow. This is what defect #1 (presumably a dropped refresh token) broke. + let on_disk: serde_json::Value = + serde_json::from_slice(&std::fs::read(&path).unwrap()).unwrap(); + assert_eq!(on_disk["access_token"], "fresh-token-1"); + assert_eq!(on_disk["refresh_token"], "rotated-refresh"); +} + +#[tokio::test] +async fn refresh_now_coalesces_when_another_caller_already_refreshed() { + let tmp = TempDir::new().unwrap(); + + let (base, refresh_counter) = spawn_oidc().await; + let cfg = PkceOAuthConfig { + discovery_url: format!("{base}/.well-known/oauth-authorization-server"), + client_id: "test-client".into(), + scopes: vec!["a".into()], + cache_namespace: "databricks".into(), + cache_dir_override: Some(tmp.path().to_path_buf()), + }; + + // A concurrent caller already replaced the rejected token: the cached + // token differs from the one we hold.
Coalesce by identity — return the + // new token without burning a second grant, so N concurrent 401s on the + // same stale token collapse onto one refresh. Note the cached token is + // *unexpired* here too, so this proves coalescing keys off identity, not + // the clock (which agrees with both the old and new token). + let future = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 3600; + let path = cache_path_for(tmp.path(), &cfg); + seed_cache( + &path, + json!({ + "access_token": "already-refreshed", + "refresh_token": "rt", + "expires_at": future, + }), + ); + + let src = PkceOAuthTokenSource::new(cfg).unwrap(); + let bearer = src.refresh_now("the-rejected-one").await.unwrap(); + assert_eq!(bearer, "already-refreshed"); + assert_eq!( + refresh_counter.load(Ordering::SeqCst), + 0, + "no grant when a sibling already refreshed the rejected token" + ); +} + +#[tokio::test] +async fn refresh_now_without_refresh_token_is_terminal() { + let tmp = TempDir::new().unwrap(); + + let (base, refresh_counter) = spawn_oidc().await; + let cfg = PkceOAuthConfig { + discovery_url: format!("{base}/.well-known/oauth-authorization-server"), + client_id: "test-client".into(), + scopes: vec!["a".into()], + cache_namespace: "databricks".into(), + cache_dir_override: Some(tmp.path().to_path_buf()), + }; + + // The rejected token is still the cached one and there's no refresh token + // to fall back on. refresh_now() must fail terminally (LlmAuth) rather + // than open a browser — the headless hang this whole change exists to + // prevent. 
+ let path = cache_path_for(tmp.path(), &cfg); + seed_cache( + &path, + json!({ + "access_token": "rejected", + "refresh_token": serde_json::Value::Null, + "expires_at": 1u64, + }), + ); + + let src = PkceOAuthTokenSource::new(cfg).unwrap(); + let err = src.refresh_now("rejected").await.unwrap_err(); + // `types::AgentError` isn't a public path; match on its Display, which + // prefixes `LlmAuth` variants with "llm auth:". A terminal LlmAuth (not + // a browser hang) is the whole point of this path. + let msg = err.to_string(); + assert!( + msg.starts_with("llm auth:"), + "expected terminal LlmAuth, got: {msg}" + ); + assert_eq!( + refresh_counter.load(Ordering::SeqCst), + 0, + "no grant attempted" + ); +} + // ──────────────────────────────────────────────────────────────────────────── // ACP-level envelope regression test. //