Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
1eb4718
feat: add option to open OAuth login on same tab
thepatrickchin Mar 24, 2026
248b47c
feat: validate return URL origin against CORS allowlist
thepatrickchin Mar 24, 2026
f71727f
fix: escape return_url JSON for HTML script context
thepatrickchin Mar 24, 2026
576d555
fix: display auth cancellation message when user clicks back
thepatrickchin Mar 24, 2026
4a22749
fix: patch example OAuth server to support permission decline
thepatrickchin Mar 23, 2026
a448a56
feat: support denied consent flow
thepatrickchin Mar 24, 2026
ef3fc66
feat: display cancellation reason and update tests
thepatrickchin Mar 31, 2026
8356430
docs: add description of use_popup_auth option
thepatrickchin Mar 31, 2026
33db54c
chore: rename occurrences of "same-page" to "redirect"
thepatrickchin Apr 2, 2026
d1d8481
refactor: rename use_popup_auth to use_redirect_auth and negate logic
thepatrickchin Apr 2, 2026
ce7ec06
fix: implement CodeRabbit suggestions
thepatrickchin Apr 2, 2026
c792ba3
fix: implement CodeRabbit suggestions
thepatrickchin Apr 2, 2026
1d8b5e9
fix: implement CodeRabbit suggestions
thepatrickchin Apr 2, 2026
b1a5f6a
feat: enable pre-authorization before prompt is submitted
thepatrickchin Mar 31, 2026
9f99cdb
fix: prevent auth redirect loop when user cancels or declines consent
thepatrickchin Mar 31, 2026
23af716
feat: add pre_authenticate option for authorization providers
thepatrickchin Mar 31, 2026
0effe17
docs: add description of pre_authenticate option
thepatrickchin Mar 31, 2026
61671c5
refactor: rename pre_authenticate option to use_eager_auth
thepatrickchin Apr 2, 2026
b768527
wip: optional distributed cache for token store
thepatrickchin Apr 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/components/auth/api-authentication.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ authentication:
| `redirect_uri` | The redirect URI for OAuth 2.0 authentication. Must match the registered redirect URI with the OAuth provider.|
| `scopes` | List of permissions to the API provider (e.g., `read`, `write`). |
| `use_pkce` | Whether to use PKCE (Proof Key for Code Exchange) in the OAuth 2.0 flow, defaults to `False` |
| `use_eager_auth` | Whether to trigger authentication at WebSocket connection time before the workflow requires credentials, defaults to `False`. When enabled, tokens are cached for the session to avoid re-authentication on reconnect. |
| `use_redirect_auth` | Whether to use a redirect-based flow or open the OAuth consent page in a popup window, defaults to `False` (popup) |
| `authorization_kwargs` | Additional keyword arguments to include in the authorization request. |


Expand Down
6 changes: 6 additions & 0 deletions examples/front_ends/simple_auth/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ RUN apt-get update && apt-get install -y \
# Clone the OAuth2 server example
RUN git clone https://github.com/authlib/example-oauth2-server.git oauth2-server

# Apply patches: add an explicit Cancel button to the authorize route and
# template so that declining the OAuth2 consent redirects back with
# error=access_denied instead of leaving the client app waiting indefinitely.
COPY patches/oauth2-server.patch /tmp/oauth2-server.patch
RUN patch -p1 -d /app/oauth2-server < /tmp/oauth2-server.patch

# Change to the OAuth2 server directory
WORKDIR /app/oauth2-server

Expand Down
30 changes: 30 additions & 0 deletions examples/front_ends/simple_auth/patches/oauth2-server.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
--- a/website/routes.py
+++ b/website/routes.py
@@ -105,7 +105,7 @@
if not user and "username" in request.form:
username = request.form.get("username")
user = User.query.filter_by(username=username).first()
- if request.form["confirm"]:
+ if request.form.get("confirm") == "yes":
grant_user = user
else:
grant_user = None
--- a/website/templates/authorize.html
+++ b/website/templates/authorize.html
@@ -9,14 +9,11 @@
<form action="" method="post">
- <label>
- <input type="checkbox" name="confirm">
- <span>Consent?</span>
- </label>
{% if not user %}
<p>You haven't logged in. Log in with:</p>
<div>
<input type="text" name="username">
</div>
{% endif %}
<br>
- <button>Submit</button>
+ <button type="submit" name="confirm" value="yes">Authorize</button>
+ <button type="submit" name="confirm" value="no">Cancel</button>
</form>
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ authentication:
client_id: ${NAT_OAUTH_CLIENT_ID}
client_secret: ${NAT_OAUTH_CLIENT_SECRET}
use_pkce: false
use_eager_auth: false
use_redirect_auth: false

workflow:
_type: react_agent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ class OAuth2AuthCodeFlowProviderConfig(AuthProviderBaseConfig, name="oauth2_auth
use_pkce: bool = Field(default=False,
description="Whether to use PKCE (Proof Key for Code Exchange) in the OAuth 2.0 flow.")

use_redirect_auth: bool = Field(
default=False,
description=("When False (default), the OAuth login page opens in a popup window and the originating page "
"remains open. When True, the browser navigates to the OAuth login page directly and is "
"redirected back after authentication completes."))

use_eager_auth: bool = Field(
default=False,
description=("When False (default), authentication is deferred until the workflow first requires "
"credentials. When True, authentication is triggered at WebSocket connection time."))

authorization_kwargs: dict[str, str] | None = Field(description=("Additional keyword arguments for the "
"authorization request."),
default=None)
3 changes: 3 additions & 0 deletions packages/nvidia_nat_core/src/nat/data_models/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ class _HumanPromptOAuthConsent(HumanPromptBase):
the consent flow.
"""
input_type: typing.Literal[HumanPromptModelType.OAUTH_CONSENT] = HumanPromptModelType.OAUTH_CONSENT
use_redirect: bool = Field(default=False,
description="When False the UI should open the OAuth URL in a popup window. "
"When True the UI should navigate the current tab to the OAuth URL.")


class HumanPromptBinary(HumanPromptBase):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import time
from abc import ABC
from abc import abstractmethod

from nat.data_models.authentication import AuthenticatedContext
from nat.data_models.object_store import NoSuchKeyError
from nat.object_store.interfaces import ObjectStore
from nat.object_store.models import ObjectStoreItem

logger = logging.getLogger(__name__)

_EXPIRY_BUFFER_SECONDS = 60


class OAuthTokenCacheBase(ABC):
"""Async cache abstraction for WebSocket OAuth tokens.

A cache key encodes the (session, provider) pair so that different users and
different OAuth providers are kept isolated. Implementations may be local
(in-process dict) or distributed (e.g. Redis-backed object store) to support
multi-replica deployments.
"""

@abstractmethod
async def get(self, key: str) -> AuthenticatedContext | None:
"""Return a cached, non-expired token, or None if absent or expired."""

@abstractmethod
async def set(self, key: str, ctx: AuthenticatedContext, expires_at: float | None) -> None:
"""Store *ctx* under *key*. Overwrites any existing entry."""

@abstractmethod
async def delete(self, key: str) -> None:
"""Remove *key* from the cache. A missing key is silently ignored."""


class InMemoryOAuthTokenCache(OAuthTokenCacheBase):
"""In-process dict-backed token cache.

Suitable for single-process deployments only. All state is lost on restart
and is not shared across replicas.
"""

def __init__(self) -> None:
self._store: dict[str, tuple[AuthenticatedContext, float | None]] = {}

async def get(self, key: str) -> AuthenticatedContext | None:
entry = self._store.get(key)
if entry is None:
return None
ctx, expires_at = entry
if expires_at is not None and time.time() >= expires_at - _EXPIRY_BUFFER_SECONDS:
del self._store[key]
return None
return ctx

async def set(self, key: str, ctx: AuthenticatedContext, expires_at: float | None) -> None:
self._store[key] = (ctx, expires_at)

async def delete(self, key: str) -> None:
self._store.pop(key, None)


class ObjectStoreOAuthTokenCache(OAuthTokenCacheBase):
"""Object-store-backed token cache.

Stores tokens as JSON blobs in a NAT object store (e.g. Redis, S3, MySQL),
which makes the cache durable and shared across all replicas.
"""

def __init__(self, object_store: ObjectStore) -> None:
self._object_store = object_store

async def get(self, key: str) -> AuthenticatedContext | None:
try:
item = await self._object_store.get_object(key)
except NoSuchKeyError:
return None
except Exception:
logger.exception("Failed to read OAuth token from object store (key=%s)", key)
return None

try:
payload = json.loads(item.data)
expires_at = payload.get("expires_at")
if expires_at is not None and time.time() >= float(expires_at) - _EXPIRY_BUFFER_SECONDS:
await self.delete(key)
return None
return AuthenticatedContext.model_validate(payload["ctx"])
except Exception:
logger.exception("Failed to deserialize OAuth token from object store (key=%s)", key)
return None

async def set(self, key: str, ctx: AuthenticatedContext, expires_at: float | None) -> None:
try:
payload = json.dumps({
"ctx": ctx.model_dump(mode="json"),
"expires_at": expires_at,
}).encode("utf-8")
metadata: dict[str, str] = {}
if expires_at is not None:
metadata["expires_at"] = str(expires_at)
item = ObjectStoreItem(data=payload,
content_type="application/json",
metadata=metadata if metadata else None)
await self._object_store.upsert_object(key, item)
except Exception:
logger.exception("Failed to store OAuth token in object store (key=%s)", key)

async def delete(self, key: str) -> None:
try:
await self._object_store.delete_object(key)
except NoSuchKeyError:
pass
except Exception:
logger.exception("Failed to delete OAuth token from object store (key=%s)", key)
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
from nat.data_models.authentication import AuthenticatedContext
from nat.data_models.authentication import AuthFlowType
from nat.data_models.authentication import AuthProviderBaseConfig
from nat.data_models.interactive import _HumanPromptOAuthConsent
from nat.front_ends.fastapi.auth_flow_handlers.oauth_token_cache import OAuthTokenCacheBase
from nat.front_ends.fastapi.message_handler import WebSocketMessageHandler

logger = logging.getLogger(__name__)
Expand All @@ -42,6 +44,7 @@ class FlowState:
verifier: str | None = None
client: AsyncOAuth2Client | None = None
config: OAuth2AuthCodeFlowProviderConfig | None = None
return_url: str | None = None


class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
Expand All @@ -50,12 +53,18 @@ def __init__(self,
add_flow_cb: Callable[[str, FlowState], Awaitable[None]],
remove_flow_cb: Callable[[str], Awaitable[None]],
web_socket_message_handler: WebSocketMessageHandler,
auth_timeout_seconds: float = 300.0):
auth_timeout_seconds: float = 300.0,
return_url: str | None = None,
token_cache: OAuthTokenCacheBase | None = None,
session_id: str | None = None):

self._add_flow_cb: Callable[[str, FlowState], Awaitable[None]] = add_flow_cb
self._remove_flow_cb: Callable[[str], Awaitable[None]] = remove_flow_cb
self._web_socket_message_handler: WebSocketMessageHandler = web_socket_message_handler
self._auth_timeout_seconds: float = auth_timeout_seconds
self._return_url: str | None = return_url
self._token_cache: OAuthTokenCacheBase | None = token_cache
self._session_id: str | None = session_id

async def authenticate(
self,
Expand All @@ -66,6 +75,17 @@ async def authenticate(

raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.")

async def run_eager_auth(self, auth_providers: dict[str, AuthProviderBaseConfig]) -> None:
"""Run auth for every configured OAuth2 provider before the first user message.

Only providers with use_eager_auth option set in their config are processed.
Returns immediately if tokens are already cached. Otherwise triggers the OAuth
redirect so the user authenticates at page load rather than mid-workflow.
"""
for provider_config in auth_providers.values():
if isinstance(provider_config, OAuth2AuthCodeFlowProviderConfig) and provider_config.use_eager_auth:
await self.authenticate(provider_config, AuthFlowType.OAUTH2_AUTHORIZATION_CODE)

def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
try:
return AsyncOAuth2Client(client_id=config.client_id,
Expand Down Expand Up @@ -113,8 +133,18 @@ def _create_authorization_url(self,

async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:

if config.use_redirect_auth and self._return_url is None:
raise ValueError("Redirect-based authentication (use_redirect_auth=True) requires a return URL, "
"but none was configured. Pass return_url when constructing the flow handler.")

cached = await self._get_cached_token(config)
if cached is not None:
logger.debug("OAuth token cache hit for client_id=%s", config.client_id)
return cached

state = secrets.token_urlsafe(16)
flow_state = FlowState(config=config)
return_url = self._return_url if config.use_redirect_auth else None
flow_state = FlowState(config=config, return_url=return_url)

flow_state.client = self.create_oauth_client(config)

Expand All @@ -130,8 +160,8 @@ async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProvider
challenge=flow_state.challenge)

await self._add_flow_cb(state, flow_state)
await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url)
)
await self._web_socket_message_handler.create_websocket_message(
_HumanPromptOAuthConsent(text=authorization_url, use_redirect=config.use_redirect_auth))
try:
token = await asyncio.wait_for(flow_state.future, timeout=self._auth_timeout_seconds)
except TimeoutError as exc:
Expand All @@ -140,7 +170,30 @@ async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProvider

await self._remove_flow_cb(state)

return AuthenticatedContext(headers={"Authorization": f"Bearer {token['access_token']}"},
metadata={
"expires_at": token.get("expires_at"), "raw_token": token
})
ctx = AuthenticatedContext(headers={"Authorization": f"Bearer {token['access_token']}"},
metadata={
"expires_at": token.get("expires_at"), "raw_token": token
})
await self._store_token(config, ctx)
return ctx

def _token_cache_key(self, config: OAuth2AuthCodeFlowProviderConfig) -> str | None:
"""Return a cache key for this (session, provider) pair, or None if caching is unavailable."""
if not self._session_id or self._token_cache is None:
return None
return f"{self._session_id}:{config.client_id}:{config.token_url}"

async def _get_cached_token(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext | None:
"""Return a cached, non-expired token for *config*, or None."""
key = self._token_cache_key(config)
if key is None or self._token_cache is None:
return None
return await self._token_cache.get(key)

async def _store_token(self, config: OAuth2AuthCodeFlowProviderConfig, ctx: AuthenticatedContext) -> None:
"""Cache *ctx* for *config* if caching is available."""
key = self._token_cache_key(config)
if key is None or self._token_cache is None:
return
expires_at = ctx.metadata.get("expires_at") if ctx.metadata else None
await self._token_cache.set(key, ctx, expires_at)
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,13 @@ class CrossOriginResourceSharing(BaseModel):
"request to '/static' and files will be served from the object store. The files will be served from the "
"object store at '/static/{file_name}'."))

oauth_token_store: ObjectStoreRef | None = Field(
default=None,
description=("Object store reference used to persist WebSocket OAuth tokens across replicas. "
"When set, tokens are stored in the named object store (e.g. Redis) so that "
"re-authentication is not required after pod restarts or when requests land on "
"different replicas. When unset, an in-process dict is used (single-replica only)."))

disable_legacy_routes: bool = Field(
default=False,
description="Disable the legacy routes for the FastAPI app. If True, the legacy routes are disabled.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from nat.utils.log_utils import setup_logging

from .auth_flow_handlers.http_flow_handler import HTTPAuthenticationFlowHandler
from .auth_flow_handlers.oauth_token_cache import InMemoryOAuthTokenCache
from .auth_flow_handlers.oauth_token_cache import OAuthTokenCacheBase
from .auth_flow_handlers.oauth_token_cache import ObjectStoreOAuthTokenCache
from .auth_flow_handlers.websocket_flow_handler import FlowState
from .execution_store import ExecutionStore
from .fastapi_front_end_config import FastApiFrontEndConfig
Expand Down Expand Up @@ -202,6 +205,10 @@ def __init__(self, config: Config):
# Conversation handlers for WebSocket reconnection support
self._conversation_handlers: dict[str, WebSocketMessageHandler] = {}

# OAuth token cache — replaced with an ObjectStoreOAuthTokenCache in configure()
# when front_end.oauth_token_store is set, enabling cross-replica token sharing.
self._oauth_token_cache: OAuthTokenCacheBase = InMemoryOAuthTokenCache()

# Track session managers for each route
self._session_managers: list[SessionManager] = []

Expand Down Expand Up @@ -313,6 +320,11 @@ async def configure(self, app: FastAPI, builder: WorkflowBuilder):
# Do things like setting the base URL and global configuration options
app.root_path = self.front_end_config.root_path

if self.front_end_config.oauth_token_store is not None:
object_store = await builder.get_object_store_client(self.front_end_config.oauth_token_store)
self._oauth_token_cache = ObjectStoreOAuthTokenCache(object_store)
logger.debug("OAuth token cache backed by object store '%s'", self.front_end_config.oauth_token_store)

# Initialize evaluators for single-item evaluation
# TODO: we need config control over this as it's not always needed
await self.initialize_evaluators(self._config)
Expand Down
Loading