Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions src/authgate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import enum

from authgate._version import __version__
Expand All @@ -12,7 +13,7 @@
from authgate.credstore import default_token_secure_store
from authgate.discovery.async_client import AsyncDiscoveryClient
from authgate.discovery.client import DiscoveryClient
from authgate.exceptions import AuthGateError
from authgate.exceptions import AuthFlowError, AuthGateError, NotFoundError, OAuthError
from authgate.oauth.async_client import AsyncOAuthClient
from authgate.oauth.client import OAuthClient
from authgate.oauth.models import Token
Expand Down Expand Up @@ -63,8 +64,13 @@ def authenticate(
try:
token = ts.token()
return client, token
except Exception:
except (NotFoundError, AuthFlowError):
pass
except OAuthError as exc:
# Only swallow errors indicating an invalid/expired refresh token.
# Re-raise unexpected OAuth errors (e.g., server_error) so they surface.
if exc.code not in ("invalid_grant", "invalid_token"):
raise

# 5. No valid token — run the appropriate authentication flow
if flow_mode == FlowMode.BROWSER:
Expand Down Expand Up @@ -111,20 +117,37 @@ async def async_authenticate(
# 2. Create async OAuth client
client = AsyncOAuthClient(client_id, meta.to_endpoints())

# 3. Check stored token (sync store, run in thread)
# 3. Check stored token and attempt refresh if expired
store = default_token_secure_store(service_name, store_path)
ts = TokenSource(OAuthClient(client_id, meta.to_endpoints()), store=store)
try:
token = ts.token()
return client, token
except Exception:
stored = await asyncio.to_thread(store.load, client_id)
from authgate.authflow.token_source import credstore_to_oauth, oauth_to_credstore

if stored.is_valid():
token = credstore_to_oauth(stored)
return client, token

# Try refreshing with the stored refresh token
if stored.refresh_token:
try:
token = await client.refresh_token(stored.refresh_token)
await asyncio.to_thread(store.save, client_id, oauth_to_credstore(token, client_id))
return client, token
except OAuthError as exc:
# Only swallow errors indicating an invalid/expired refresh token.
# Re-raise unexpected OAuth errors (e.g., server_error).
if exc.code not in ("invalid_grant", "invalid_token"):
raise
except NotFoundError:
pass

# 4. Run device flow (always, since auth code needs sync HTTP server)
token = await async_run_device_flow(client, _scopes)

# 5. Persist
ts.save_token(token)
from authgate.authflow.token_source import oauth_to_credstore

await asyncio.to_thread(store.save, client_id, oauth_to_credstore(token, client_id))

return client, token

Expand Down
4 changes: 3 additions & 1 deletion src/authgate/authflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from authgate.authflow.browser import check_browser_availability, open_browser
from authgate.authflow.device import async_run_device_flow, run_device_flow
from authgate.authflow.pkce import PKCE, generate_pkce
from authgate.authflow.token_source import TokenSource
from authgate.authflow.token_source import TokenSource, credstore_to_oauth, oauth_to_credstore

__all__ = [
"PKCE",
"TokenSource",
"async_run_device_flow",
"check_browser_availability",
"credstore_to_oauth",
"generate_pkce",
"oauth_to_credstore",
"open_browser",
"run_auth_code_flow",
"run_device_flow",
Expand Down
160 changes: 86 additions & 74 deletions src/authgate/authflow/authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,79 +5,86 @@
import os
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs, urlparse
from urllib.parse import parse_qs, urlencode, urlparse

from authgate.authflow.browser import open_browser
from authgate.authflow.pkce import generate_pkce
from authgate.exceptions import OAuthError
from authgate.exceptions import AuthFlowError, OAuthError
from authgate.oauth.client import OAuthClient
from authgate.oauth.models import Token

_CALLBACK_TIMEOUT = 300.0 # 5 minutes


def _generate_state() -> str:
"""Generate a cryptographically random state string for CSRF protection."""
return os.urandom(16).hex()


class _CallbackHandler(BaseHTTPRequestHandler):
"""HTTP handler for the OAuth callback."""

result: dict[str, str] # shared via class attribute set by factory
event: threading.Event

def do_GET(self) -> None:
parsed = urlparse(self.path)
if parsed.path != "/callback":
self.send_response(404)
self.end_headers()
return

params = parse_qs(parsed.query)

# Only process the first callback
if self.event.is_set():
self.send_response(200)
self.end_headers()
self.wfile.write(b"<html><body><h1>Already processed</h1></body></html>")
return

state = params.get("state", [""])[0]
if state != self.result.get("expected_state"):
self.result["error"] = "invalid_state"
self.result["error_description"] = "State parameter mismatch"
self.event.set()
self.send_response(200)
self.end_headers()
self.wfile.write(
b"<html><body><h1>Authentication failed</h1>"
b"<p>State mismatch. You can close this window.</p></body></html>"
)
return

code = params.get("code", [""])[0]
if not code:
self.result["error"] = params.get("error", ["no code received"])[0]
self.result["error_description"] = params.get("error_description", [""])[0]
self.event.set()
def _make_callback_handler(
result: dict[str, str],
event: threading.Event,
) -> type[BaseHTTPRequestHandler]:
"""Create a handler class with per-flow state to avoid class-attribute sharing."""

class _CallbackHandler(BaseHTTPRequestHandler):
"""HTTP handler for the OAuth callback."""

def do_GET(self) -> None:
parsed = urlparse(self.path)
if parsed.path != "/callback":
self.send_response(404)
self.end_headers()
return

params = parse_qs(parsed.query)

# Only process the first callback
if event.is_set():
self.send_response(200)
self.end_headers()
self.wfile.write(b"<html><body><h1>Already processed</h1></body></html>")
return

state = params.get("state", [""])[0]
if state != result.get("expected_state"):
result["error"] = "invalid_state"
result["error_description"] = "State parameter mismatch"
event.set()
self.send_response(200)
self.end_headers()
self.wfile.write(
b"<html><body><h1>Authentication failed</h1>"
b"<p>State mismatch. You can close this window.</p></body></html>"
)
return

code = params.get("code", [""])[0]
if not code:
result["error"] = params.get("error", ["no code received"])[0]
result["error_description"] = params.get("error_description", [""])[0]
event.set()
self.send_response(200)
self.end_headers()
self.wfile.write(
b"<html><body><h1>Authentication failed</h1>"
b"<p>You can close this window.</p></body></html>"
)
return

result["code"] = code
event.set()
self.send_response(200)
self.end_headers()
self.wfile.write(
b"<html><body><h1>Authentication failed</h1>"
b"<html><body><h1>Authentication successful</h1>"
b"<p>You can close this window.</p></body></html>"
)
return

self.result["code"] = code
self.event.set()
self.send_response(200)
self.end_headers()
self.wfile.write(
b"<html><body><h1>Authentication successful</h1>"
b"<p>You can close this window.</p></body></html>"
)

def log_message(self, fmt: str, *args: object) -> None:
pass # Suppress HTTP server logs
def log_message(self, fmt: str, *args: object) -> None:
pass # Suppress HTTP server logs

return _CallbackHandler


def run_auth_code_flow(
Expand All @@ -90,39 +97,43 @@ def run_auth_code_flow(
pkce = generate_pkce()
state = _generate_state()

# Set up per-flow shared state
result: dict[str, str] = {"expected_state": state}
event = threading.Event()
handler_cls = _make_callback_handler(result, event)

# Start the callback server
server = HTTPServer(("127.0.0.1", local_port), _CallbackHandler)
server = HTTPServer(("127.0.0.1", local_port), handler_cls)
port = server.server_address[1]
redirect_uri = f"http://127.0.0.1:{port}/callback"

# Set up shared state via handler class attributes
result: dict[str, str] = {"expected_state": state}
event = threading.Event()
_CallbackHandler.result = result
_CallbackHandler.event = event

server_thread = threading.Thread(target=server.serve_forever, daemon=True)
server_thread.start()

try:
# Build authorization URL
# Build authorization URL with proper encoding
endpoints = client.endpoints
params = (
f"response_type=code"
f"&client_id={client.client_id}"
f"&redirect_uri={redirect_uri}"
f"&scope={'+'.join(scopes or [])}"
f"&state={state}"
f"&code_challenge={pkce.challenge}"
f"&code_challenge_method={pkce.method}"
query = urlencode(
{
"response_type": "code",
"client_id": client.client_id,
"redirect_uri": redirect_uri,
"scope": " ".join(scopes or []),
"state": state,
"code_challenge": pkce.challenge,
"code_challenge_method": pkce.method,
}
)
auth_url = f"{endpoints.authorize_url}?{params}"
auth_url = f"{endpoints.authorize_url}?{query}"

if not open_browser(auth_url):
print(f"Open this URL in your browser:\n{auth_url}")

# Wait for callback
event.wait()
# Wait for callback with timeout
if not event.wait(timeout=_CALLBACK_TIMEOUT):
raise AuthFlowError(
f"auth code flow timed out after {_CALLBACK_TIMEOUT:.0f}s waiting for callback"
)

if "error" in result:
raise OAuthError(
Expand All @@ -133,6 +144,7 @@ def run_auth_code_flow(
code = result["code"]
finally:
server.shutdown()
server.server_close()
server_thread.join(timeout=5)

return client.exchange_auth_code(code, redirect_uri, pkce.verifier)
10 changes: 5 additions & 5 deletions src/authgate/authflow/token_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from authgate.oauth.models import Token


def _credstore_to_oauth(t: StoredToken) -> Token:
def credstore_to_oauth(t: StoredToken) -> Token:
return Token(
access_token=t.access_token,
refresh_token=t.refresh_token,
Expand All @@ -20,7 +20,7 @@ def _credstore_to_oauth(t: StoredToken) -> Token:
)


def _oauth_to_credstore(t: Token, client_id: str) -> StoredToken:
def oauth_to_credstore(t: Token, client_id: str) -> StoredToken:
return StoredToken(
access_token=t.access_token,
refresh_token=t.refresh_token,
Expand Down Expand Up @@ -62,7 +62,7 @@ def token(self) -> Token:
try:
stored = self._store.load(self._client.client_id)
if stored.is_valid():
tok = _credstore_to_oauth(stored)
tok = credstore_to_oauth(stored)
self._cached = tok
return tok
except NotFoundError:
Expand Down Expand Up @@ -107,7 +107,7 @@ def _do_refresh(self) -> Token:
try:
stored = self._store.load(self._client.client_id)
if stored.is_valid():
tok = _credstore_to_oauth(stored)
tok = credstore_to_oauth(stored)
self._cached = tok
return tok

Expand All @@ -133,5 +133,5 @@ def _save_token(self, token: Token) -> None:
return
self._store.save(
self._client.client_id,
_oauth_to_credstore(token, self._client.client_id),
oauth_to_credstore(token, self._client.client_id),
)
2 changes: 2 additions & 0 deletions src/authgate/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""AuthGate exception hierarchy."""

from __future__ import annotations


class AuthGateError(Exception):
"""Base exception for all AuthGate errors."""
Expand Down
Loading
Loading