diff --git a/src/authgate/__init__.py b/src/authgate/__init__.py index 9706f03..ef0ec38 100644 --- a/src/authgate/__init__.py +++ b/src/authgate/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import enum from authgate._version import __version__ @@ -12,7 +13,7 @@ from authgate.credstore import default_token_secure_store from authgate.discovery.async_client import AsyncDiscoveryClient from authgate.discovery.client import DiscoveryClient -from authgate.exceptions import AuthGateError +from authgate.exceptions import AuthFlowError, AuthGateError, NotFoundError, OAuthError from authgate.oauth.async_client import AsyncOAuthClient from authgate.oauth.client import OAuthClient from authgate.oauth.models import Token @@ -63,8 +64,13 @@ def authenticate( try: token = ts.token() return client, token - except Exception: + except (NotFoundError, AuthFlowError): pass + except OAuthError as exc: + # Only swallow errors indicating an invalid/expired refresh token. + # Re-raise unexpected OAuth errors (e.g., server_error) so they surface. + if exc.code not in ("invalid_grant", "invalid_token"): + raise # 5. No valid token — run the appropriate authentication flow if flow_mode == FlowMode.BROWSER: @@ -111,20 +117,37 @@ async def async_authenticate( # 2. Create async OAuth client client = AsyncOAuthClient(client_id, meta.to_endpoints()) - # 3. Check stored token (sync store, run in thread) + # 3. Check stored token and attempt refresh if expired store = default_token_secure_store(service_name, store_path) - ts = TokenSource(OAuthClient(client_id, meta.to_endpoints()), store=store) try: - token = ts.token() - return client, token - except Exception: + stored = await asyncio.to_thread(store.load, client_id) + from authgate.authflow.token_source import credstore_to_oauth, oauth_to_credstore + + if stored.is_valid(): + token = credstore_to_oauth(stored) + return client, token + + # Try refreshing with the stored refresh token + if stored.refresh_token: + try: + token = await client.refresh_token(stored.refresh_token) + await asyncio.to_thread(store.save, client_id, oauth_to_credstore(token, client_id)) + return client, token + except OAuthError as exc: + # Only swallow errors indicating an invalid/expired refresh token. + # Re-raise unexpected OAuth errors (e.g., server_error). + if exc.code not in ("invalid_grant", "invalid_token"): + raise + except NotFoundError: pass # 4. Run device flow (always, since auth code needs sync HTTP server) token = await async_run_device_flow(client, _scopes) # 5. Persist - ts.save_token(token) + from authgate.authflow.token_source import oauth_to_credstore + + await asyncio.to_thread(store.save, client_id, oauth_to_credstore(token, client_id)) return client, token diff --git a/src/authgate/authflow/__init__.py b/src/authgate/authflow/__init__.py index fc2757c..e93ea73 100644 --- a/src/authgate/authflow/__init__.py +++ b/src/authgate/authflow/__init__.py @@ -4,14 +4,16 @@ from authgate.authflow.browser import check_browser_availability, open_browser from authgate.authflow.device import async_run_device_flow, run_device_flow from authgate.authflow.pkce import PKCE, generate_pkce -from authgate.authflow.token_source import TokenSource +from authgate.authflow.token_source import TokenSource, credstore_to_oauth, oauth_to_credstore __all__ = [ "PKCE", "TokenSource", "async_run_device_flow", "check_browser_availability", + "credstore_to_oauth", "generate_pkce", + "oauth_to_credstore", "open_browser", "run_auth_code_flow", "run_device_flow", diff --git a/src/authgate/authflow/authcode.py b/src/authgate/authflow/authcode.py index e879343..096e016 100644 --- a/src/authgate/authflow/authcode.py +++ b/src/authgate/authflow/authcode.py @@ -5,79 +5,86 @@ import os import threading from http.server import BaseHTTPRequestHandler, HTTPServer -from urllib.parse import parse_qs, urlparse +from urllib.parse import parse_qs, urlencode, urlparse from authgate.authflow.browser import open_browser from authgate.authflow.pkce import generate_pkce -from authgate.exceptions import OAuthError +from authgate.exceptions import AuthFlowError, OAuthError from authgate.oauth.client import OAuthClient from authgate.oauth.models import Token +_CALLBACK_TIMEOUT = 300.0 # 5 minutes + def _generate_state() -> str: """Generate a cryptographically random state string for CSRF protection.""" return os.urandom(16).hex() -class _CallbackHandler(BaseHTTPRequestHandler): - """HTTP handler for the OAuth callback.""" - - result: dict[str, str] # shared via class attribute set by factory - event: threading.Event - - def do_GET(self) -> None: - parsed = urlparse(self.path) - if parsed.path != "/callback": - self.send_response(404) - self.end_headers() - return - - params = parse_qs(parsed.query) - - # Only process the first callback - if self.event.is_set(): - self.send_response(200) - self.end_headers() - self.wfile.write(b"
State mismatch. You can close this window.
" - ) - return - - code = params.get("code", [""])[0] - if not code: - self.result["error"] = params.get("error", ["no code received"])[0] - self.result["error_description"] = params.get("error_description", [""])[0] - self.event.set() +def _make_callback_handler( + result: dict[str, str], + event: threading.Event, +) -> type[BaseHTTPRequestHandler]: + """Create a handler class with per-flow state to avoid class-attribute sharing.""" + + class _CallbackHandler(BaseHTTPRequestHandler): + """HTTP handler for the OAuth callback.""" + + def do_GET(self) -> None: + parsed = urlparse(self.path) + if parsed.path != "/callback": + self.send_response(404) + self.end_headers() + return + + params = parse_qs(parsed.query) + + # Only process the first callback + if event.is_set(): + self.send_response(200) + self.end_headers() + self.wfile.write(b"State mismatch. You can close this window.
" + ) + return + + code = params.get("code", [""])[0] + if not code: + result["error"] = params.get("error", ["no code received"])[0] + result["error_description"] = params.get("error_description", [""])[0] + event.set() + self.send_response(200) + self.end_headers() + self.wfile.write( + b"You can close this window.
" + ) + return + + result["code"] = code + event.set() self.send_response(200) self.end_headers() self.wfile.write( - b"You can close this window.
" ) - return - - self.result["code"] = code - self.event.set() - self.send_response(200) - self.end_headers() - self.wfile.write( - b"You can close this window.
" - ) - def log_message(self, fmt: str, *args: object) -> None: - pass # Suppress HTTP server logs + def log_message(self, fmt: str, *args: object) -> None: + pass # Suppress HTTP server logs + + return _CallbackHandler def run_auth_code_flow( @@ -90,39 +97,43 @@ def run_auth_code_flow( pkce = generate_pkce() state = _generate_state() + # Set up per-flow shared state + result: dict[str, str] = {"expected_state": state} + event = threading.Event() + handler_cls = _make_callback_handler(result, event) + # Start the callback server - server = HTTPServer(("127.0.0.1", local_port), _CallbackHandler) + server = HTTPServer(("127.0.0.1", local_port), handler_cls) port = server.server_address[1] redirect_uri = f"http://127.0.0.1:{port}/callback" - # Set up shared state via handler class attributes - result: dict[str, str] = {"expected_state": state} - event = threading.Event() - _CallbackHandler.result = result - _CallbackHandler.event = event - server_thread = threading.Thread(target=server.serve_forever, daemon=True) server_thread.start() try: - # Build authorization URL + # Build authorization URL with proper encoding endpoints = client.endpoints - params = ( - f"response_type=code" - f"&client_id={client.client_id}" - f"&redirect_uri={redirect_uri}" - f"&scope={'+'.join(scopes or [])}" - f"&state={state}" - f"&code_challenge={pkce.challenge}" - f"&code_challenge_method={pkce.method}" + query = urlencode( + { + "response_type": "code", + "client_id": client.client_id, + "redirect_uri": redirect_uri, + "scope": " ".join(scopes or []), + "state": state, + "code_challenge": pkce.challenge, + "code_challenge_method": pkce.method, + } ) - auth_url = f"{endpoints.authorize_url}?{params}" + auth_url = f"{endpoints.authorize_url}?{query}" if not open_browser(auth_url): print(f"Open this URL in your browser:\n{auth_url}") - # Wait for callback - event.wait() + # Wait for callback with timeout + if not event.wait(timeout=_CALLBACK_TIMEOUT): + raise AuthFlowError( + f"auth code flow timed out after {_CALLBACK_TIMEOUT:.0f}s waiting for callback" + ) if "error" in result: raise OAuthError( @@ -133,6 +144,7 @@ def run_auth_code_flow( code = result["code"] finally: server.shutdown() + server.server_close() server_thread.join(timeout=5) return client.exchange_auth_code(code, redirect_uri, pkce.verifier) diff --git a/src/authgate/authflow/token_source.py b/src/authgate/authflow/token_source.py index 58ee185..0e0e9a8 100644 --- a/src/authgate/authflow/token_source.py +++ b/src/authgate/authflow/token_source.py @@ -11,7 +11,7 @@ from authgate.oauth.models import Token -def _credstore_to_oauth(t: StoredToken) -> Token: +def credstore_to_oauth(t: StoredToken) -> Token: return Token( access_token=t.access_token, refresh_token=t.refresh_token, @@ -20,7 +20,7 @@ def _credstore_to_oauth(t: StoredToken) -> Token: ) -def _oauth_to_credstore(t: Token, client_id: str) -> StoredToken: +def oauth_to_credstore(t: Token, client_id: str) -> StoredToken: return StoredToken( access_token=t.access_token, refresh_token=t.refresh_token, @@ -62,7 +62,7 @@ def token(self) -> Token: try: stored = self._store.load(self._client.client_id) if stored.is_valid(): - tok = _credstore_to_oauth(stored) + tok = credstore_to_oauth(stored) self._cached = tok return tok except NotFoundError: @@ -107,7 +107,7 @@ def _do_refresh(self) -> Token: try: stored = self._store.load(self._client.client_id) if stored.is_valid(): - tok = _credstore_to_oauth(stored) + tok = credstore_to_oauth(stored) self._cached = tok return tok @@ -133,5 +133,5 @@ def _save_token(self, token: Token) -> None: return self._store.save( self._client.client_id, - _oauth_to_credstore(token, self._client.client_id), + oauth_to_credstore(token, self._client.client_id), ) diff --git a/src/authgate/exceptions.py b/src/authgate/exceptions.py index bdca498..8d1e7b9 100644 --- a/src/authgate/exceptions.py +++ b/src/authgate/exceptions.py @@ -1,5 +1,7 @@ """AuthGate exception hierarchy.""" +from __future__ import annotations + class AuthGateError(Exception): """Base exception for all AuthGate errors.""" diff --git a/src/authgate/middleware/fastapi.py b/src/authgate/middleware/fastapi.py index 3493b3b..7f5afef 100644 --- a/src/authgate/middleware/fastapi.py +++ b/src/authgate/middleware/fastapi.py @@ -2,15 +2,13 @@ from __future__ import annotations -from collections.abc import Callable - from authgate.exceptions import OAuthError from authgate.middleware.core import ValidationMode, extract_bearer_token, validate_token from authgate.middleware.models import TokenInfo from authgate.oauth.client import OAuthClient try: - from fastapi import Depends, HTTPException, Request + from fastapi import HTTPException, Request from fastapi.security import HTTPBearer except ImportError as exc: raise ImportError( @@ -69,27 +67,3 @@ async def __call__(self, request: Request) -> TokenInfo: ) return info - - -def require_scope(*scopes: str) -> Callable[..., object]: - """FastAPI dependency that checks for additional scopes. - - Must be used after BearerAuth. - """ - - _default = Depends() - - async def dependency(info: TokenInfo = _default) -> TokenInfo: - for scope in scopes: - if not info.has_scope(scope): - raise HTTPException( - status_code=403, - detail={ - "error": "insufficient_scope", - "error_description": f"Token does not have required scope: {scope}", - }, - headers={"WWW-Authenticate": 'Bearer error="insufficient_scope"'}, - ) - return info - - return dependency diff --git a/src/authgate/oauth/async_client.py b/src/authgate/oauth/async_client.py index 1e4d16b..da5b2a9 100644 --- a/src/authgate/oauth/async_client.py +++ b/src/authgate/oauth/async_client.py @@ -123,9 +123,15 @@ async def revoke(self, token: str) -> None: """Revoke a token (RFC 7009).""" if not self._endpoints.revocation_url: raise OAuthError("invalid_request", "revocation endpoint not configured") + data: dict[str, str] = { + "token": token, + "client_id": self._client_id, + } + if self._client_secret: + data["client_secret"] = self._client_secret resp = await self._http.post( self._endpoints.revocation_url, - data={"token": token}, + data=data, ) if resp.status_code != 200: raise _parse_error_response(resp) diff --git a/src/authgate/oauth/client.py b/src/authgate/oauth/client.py index eafbf51..7707c66 100644 --- a/src/authgate/oauth/client.py +++ b/src/authgate/oauth/client.py @@ -123,9 +123,15 @@ def revoke(self, token: str) -> None: """Revoke a token (RFC 7009).""" if not self._endpoints.revocation_url: raise OAuthError("invalid_request", "revocation endpoint not configured") + data: dict[str, str] = { + "token": token, + "client_id": self._client_id, + } + if self._client_secret: + data["client_secret"] = self._client_secret resp = self._http.post( self._endpoints.revocation_url, - data={"token": token}, + data=data, ) if resp.status_code != 200: raise _parse_error_response(resp)