diff --git a/docs/source/components/auth/api-authentication.md b/docs/source/components/auth/api-authentication.md index a3405b1652..5ceb7182e0 100644 --- a/docs/source/components/auth/api-authentication.md +++ b/docs/source/components/auth/api-authentication.md @@ -115,6 +115,8 @@ authentication: | `redirect_uri` | The redirect URI for OAuth 2.0 authentication. Must match the registered redirect URI with the OAuth provider.| | `scopes` | List of permissions to the API provider (e.g., `read`, `write`). | | `use_pkce` | Whether to use PKCE (Proof Key for Code Exchange) in the OAuth 2.0 flow, defaults to `False` | +| `use_eager_auth` | Whether to trigger authentication at WebSocket connection time before the workflow requires credentials, defaults to `False`. When enabled, tokens are cached for the session to avoid re-authentication on reconnect. | +| `use_redirect_auth` | Whether to use a redirect-based flow or open the OAuth consent page in a popup window, defaults to `False` (popup) | | `authorization_kwargs` | Additional keyword arguments to include in the authorization request. | diff --git a/examples/front_ends/simple_auth/Dockerfile b/examples/front_ends/simple_auth/Dockerfile index f54601fdb6..8112511fc8 100644 --- a/examples/front_ends/simple_auth/Dockerfile +++ b/examples/front_ends/simple_auth/Dockerfile @@ -27,6 +27,12 @@ RUN apt-get update && apt-get install -y \ # Clone the OAuth2 server example RUN git clone https://github.com/authlib/example-oauth2-server.git oauth2-server +# Apply patches: add an explicit Cancel button to the authorize route and +# template so that declining the OAuth2 consent redirects back with +# error=access_denied instead of leaving the client app waiting indefinitely. +COPY patches/oauth2-server.patch /tmp/oauth2-server.patch +RUN patch -p1 -d /app/oauth2-server < /tmp/oauth2-server.patch + # Change to the OAuth2 server directory WORKDIR /app/oauth2-server diff --git a/examples/front_ends/simple_auth/patches/oauth2-server.patch b/examples/front_ends/simple_auth/patches/oauth2-server.patch new file mode 100644 index 0000000000..2dcf08f7dc --- /dev/null +++ b/examples/front_ends/simple_auth/patches/oauth2-server.patch @@ -0,0 +1,30 @@ +--- a/website/routes.py ++++ b/website/routes.py +@@ -105,7 +105,7 @@ + if not user and "username" in request.form: + username = request.form.get("username") + user = User.query.filter_by(username=username).first() +- if request.form["confirm"]: ++ if request.form.get("confirm") == "yes": + grant_user = user + else: + grant_user = None +--- a/website/templates/authorize.html ++++ b/website/templates/authorize.html +@@ -9,14 +9,11 @@ +
+- + {% if not user %} +

You haven't logged in. Log in with:

+
+ +
+ {% endif %} +
+- ++ ++ +
diff --git a/examples/front_ends/simple_auth/src/nat_simple_auth/configs/config.yml b/examples/front_ends/simple_auth/src/nat_simple_auth/configs/config.yml index d5bb47bb77..0e9e4e2a5c 100644 --- a/examples/front_ends/simple_auth/src/nat_simple_auth/configs/config.yml +++ b/examples/front_ends/simple_auth/src/nat_simple_auth/configs/config.yml @@ -57,6 +57,8 @@ authentication: client_id: ${NAT_OAUTH_CLIENT_ID} client_secret: ${NAT_OAUTH_CLIENT_SECRET} use_pkce: false + use_eager_auth: false + use_redirect_auth: false workflow: _type: react_agent diff --git a/packages/nvidia_nat_core/src/nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py b/packages/nvidia_nat_core/src/nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py index 65223665f8..fabd89ec69 100644 --- a/packages/nvidia_nat_core/src/nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +++ b/packages/nvidia_nat_core/src/nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py @@ -35,6 +35,17 @@ class OAuth2AuthCodeFlowProviderConfig(AuthProviderBaseConfig, name="oauth2_auth use_pkce: bool = Field(default=False, description="Whether to use PKCE (Proof Key for Code Exchange) in the OAuth 2.0 flow.") + use_redirect_auth: bool = Field( + default=False, + description=("When False (default), the OAuth login page opens in a popup window and the originating page " + "remains open. When True, the browser navigates to the OAuth login page directly and is " + "redirected back after authentication completes.")) + + use_eager_auth: bool = Field( + default=False, + description=("When False (default), authentication is deferred until the workflow first requires " + "credentials. When True, authentication is triggered at WebSocket connection time.")) + authorization_kwargs: dict[str, str] | None = Field(description=("Additional keyword arguments for the " "authorization request."), default=None) diff --git a/packages/nvidia_nat_core/src/nat/data_models/interactive.py b/packages/nvidia_nat_core/src/nat/data_models/interactive.py index c76fcde15d..e63cc60d2d 100644 --- a/packages/nvidia_nat_core/src/nat/data_models/interactive.py +++ b/packages/nvidia_nat_core/src/nat/data_models/interactive.py @@ -161,6 +161,9 @@ class _HumanPromptOAuthConsent(HumanPromptBase): the consent flow. """ input_type: typing.Literal[HumanPromptModelType.OAUTH_CONSENT] = HumanPromptModelType.OAUTH_CONSENT + use_redirect: bool = Field(default=False, + description="When False the UI should open the OAuth URL in a popup window. " + "When True the UI should navigate the current tab to the OAuth URL.") class HumanPromptBinary(HumanPromptBase): diff --git a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/auth_flow_handlers/oauth_token_cache.py b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/auth_flow_handlers/oauth_token_cache.py new file mode 100644 index 0000000000..d9bb745ff1 --- /dev/null +++ b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/auth_flow_handlers/oauth_token_cache.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import time +from abc import ABC +from abc import abstractmethod + +from nat.data_models.authentication import AuthenticatedContext +from nat.data_models.object_store import NoSuchKeyError +from nat.object_store.interfaces import ObjectStore +from nat.object_store.models import ObjectStoreItem + +logger = logging.getLogger(__name__) + +_EXPIRY_BUFFER_SECONDS = 60 + + +class OAuthTokenCacheBase(ABC): + """Async cache abstraction for WebSocket OAuth tokens. + + A cache key encodes the (session, provider) pair so that different users and + different OAuth providers are kept isolated. Implementations may be local + (in-process dict) or distributed (e.g. Redis-backed object store) to support + multi-replica deployments. + """ + + @abstractmethod + async def get(self, key: str) -> AuthenticatedContext | None: + """Return a cached, non-expired token, or None if absent or expired.""" + + @abstractmethod + async def set(self, key: str, ctx: AuthenticatedContext, expires_at: float | None) -> None: + """Store *ctx* under *key*. Overwrites any existing entry.""" + + @abstractmethod + async def delete(self, key: str) -> None: + """Remove *key* from the cache. A missing key is silently ignored.""" + + +class InMemoryOAuthTokenCache(OAuthTokenCacheBase): + """In-process dict-backed token cache. + + Suitable for single-process deployments only. All state is lost on restart + and is not shared across replicas. + """ + + def __init__(self) -> None: + self._store: dict[str, tuple[AuthenticatedContext, float | None]] = {} + + async def get(self, key: str) -> AuthenticatedContext | None: + entry = self._store.get(key) + if entry is None: + return None + ctx, expires_at = entry + if expires_at is not None and time.time() >= expires_at - _EXPIRY_BUFFER_SECONDS: + del self._store[key] + return None + return ctx + + async def set(self, key: str, ctx: AuthenticatedContext, expires_at: float | None) -> None: + self._store[key] = (ctx, expires_at) + + async def delete(self, key: str) -> None: + self._store.pop(key, None) + + +class ObjectStoreOAuthTokenCache(OAuthTokenCacheBase): + """Object-store-backed token cache. + + Stores tokens as JSON blobs in a NAT object store (e.g. Redis, S3, MySQL), + which makes the cache durable and shared across all replicas. + """ + + def __init__(self, object_store: ObjectStore) -> None: + self._object_store = object_store + + async def get(self, key: str) -> AuthenticatedContext | None: + try: + item = await self._object_store.get_object(key) + except NoSuchKeyError: + return None + except Exception: + logger.exception("Failed to read OAuth token from object store (key=%s)", key) + return None + + try: + payload = json.loads(item.data) + expires_at = payload.get("expires_at") + if expires_at is not None and time.time() >= float(expires_at) - _EXPIRY_BUFFER_SECONDS: + await self.delete(key) + return None + return AuthenticatedContext.model_validate(payload["ctx"]) + except Exception: + logger.exception("Failed to deserialize OAuth token from object store (key=%s)", key) + return None + + async def set(self, key: str, ctx: AuthenticatedContext, expires_at: float | None) -> None: + try: + payload = json.dumps({ + "ctx": ctx.model_dump(mode="json"), + "expires_at": expires_at, + }).encode("utf-8") + metadata: dict[str, str] = {} + if expires_at is not None: + metadata["expires_at"] = str(expires_at) + item = ObjectStoreItem(data=payload, + content_type="application/json", + metadata=metadata if metadata else None) + await self._object_store.upsert_object(key, item) + except Exception: + logger.exception("Failed to store OAuth token in object store (key=%s)", key) + + async def delete(self, key: str) -> None: + try: + await self._object_store.delete_object(key) + except NoSuchKeyError: + pass + except Exception: + logger.exception("Failed to delete OAuth token from object store (key=%s)", key) diff --git a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py index f80af16110..8d6ae120b4 100644 --- a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +++ b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py @@ -29,7 +29,9 @@ from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig from nat.data_models.authentication import AuthenticatedContext from nat.data_models.authentication import AuthFlowType +from nat.data_models.authentication import AuthProviderBaseConfig from nat.data_models.interactive import _HumanPromptOAuthConsent +from nat.front_ends.fastapi.auth_flow_handlers.oauth_token_cache import OAuthTokenCacheBase from nat.front_ends.fastapi.message_handler import WebSocketMessageHandler logger = logging.getLogger(__name__) @@ -42,6 +44,7 @@ class FlowState: verifier: str | None = None client: AsyncOAuth2Client | None = None config: OAuth2AuthCodeFlowProviderConfig | None = None + return_url: str | None = None class WebSocketAuthenticationFlowHandler(FlowHandlerBase): @@ -50,12 +53,18 @@ def __init__(self, add_flow_cb: Callable[[str, FlowState], Awaitable[None]], remove_flow_cb: Callable[[str], Awaitable[None]], web_socket_message_handler: WebSocketMessageHandler, - auth_timeout_seconds: float = 300.0): + auth_timeout_seconds: float = 300.0, + return_url: str | None = None, + token_cache: OAuthTokenCacheBase | None = None, + session_id: str | None = None): self._add_flow_cb: Callable[[str, FlowState], Awaitable[None]] = add_flow_cb self._remove_flow_cb: Callable[[str], Awaitable[None]] = remove_flow_cb self._web_socket_message_handler: WebSocketMessageHandler = web_socket_message_handler self._auth_timeout_seconds: float = auth_timeout_seconds + self._return_url: str | None = return_url + self._token_cache: OAuthTokenCacheBase | None = token_cache + self._session_id: str | None = session_id async def authenticate( self, @@ -66,6 +75,17 @@ async def authenticate( raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.") + async def run_eager_auth(self, auth_providers: dict[str, AuthProviderBaseConfig]) -> None: + """Run auth for every configured OAuth2 provider before the first user message. + + Only providers with use_eager_auth option set in their config are processed. + Returns immediately if tokens are already cached. Otherwise triggers the OAuth + redirect so the user authenticates at page load rather than mid-workflow. + """ + for provider_config in auth_providers.values(): + if isinstance(provider_config, OAuth2AuthCodeFlowProviderConfig) and provider_config.use_eager_auth: + await self.authenticate(provider_config, AuthFlowType.OAUTH2_AUTHORIZATION_CODE) + def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client: try: return AsyncOAuth2Client(client_id=config.client_id, @@ -113,8 +133,18 @@ def _create_authorization_url(self, async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext: + if config.use_redirect_auth and self._return_url is None: + raise ValueError("Redirect-based authentication (use_redirect_auth=True) requires a return URL, " + "but none was configured. Pass return_url when constructing the flow handler.") + + cached = await self._get_cached_token(config) + if cached is not None: + logger.debug("OAuth token cache hit for client_id=%s", config.client_id) + return cached + state = secrets.token_urlsafe(16) - flow_state = FlowState(config=config) + return_url = self._return_url if config.use_redirect_auth else None + flow_state = FlowState(config=config, return_url=return_url) flow_state.client = self.create_oauth_client(config) @@ -130,8 +160,8 @@ async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProvider challenge=flow_state.challenge) await self._add_flow_cb(state, flow_state) - await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url) - ) + await self._web_socket_message_handler.create_websocket_message( + _HumanPromptOAuthConsent(text=authorization_url, use_redirect=config.use_redirect_auth)) try: token = await asyncio.wait_for(flow_state.future, timeout=self._auth_timeout_seconds) except TimeoutError as exc: @@ -140,7 +170,30 @@ async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProvider await self._remove_flow_cb(state) - return AuthenticatedContext(headers={"Authorization": f"Bearer {token['access_token']}"}, - metadata={ - "expires_at": token.get("expires_at"), "raw_token": token - }) + ctx = AuthenticatedContext(headers={"Authorization": f"Bearer {token['access_token']}"}, + metadata={ + "expires_at": token.get("expires_at"), "raw_token": token + }) + await self._store_token(config, ctx) + return ctx + + def _token_cache_key(self, config: OAuth2AuthCodeFlowProviderConfig) -> str | None: + """Return a cache key for this (session, provider) pair, or None if caching is unavailable.""" + if not self._session_id or self._token_cache is None: + return None + return f"{self._session_id}:{config.client_id}:{config.token_url}" + + async def _get_cached_token(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext | None: + """Return a cached, non-expired token for *config*, or None.""" + key = self._token_cache_key(config) + if key is None or self._token_cache is None: + return None + return await self._token_cache.get(key) + + async def _store_token(self, config: OAuth2AuthCodeFlowProviderConfig, ctx: AuthenticatedContext) -> None: + """Cache *ctx* for *config* if caching is available.""" + key = self._token_cache_key(config) + if key is None or self._token_cache is None: + return + expires_at = ctx.metadata.get("expires_at") if ctx.metadata else None + await self._token_cache.set(key, ctx, expires_at) diff --git a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_config.py b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_config.py index b13e6bab47..b2aaaae4ac 100644 --- a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_config.py +++ b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_config.py @@ -319,6 +319,13 @@ class CrossOriginResourceSharing(BaseModel): "request to '/static' and files will be served from the object store. The files will be served from the " "object store at '/static/{file_name}'.")) + oauth_token_store: ObjectStoreRef | None = Field( + default=None, + description=("Object store reference used to persist WebSocket OAuth tokens across replicas. " + "When set, tokens are stored in the named object store (e.g. Redis) so that " + "re-authentication is not required after pod restarts or when requests land on " + "different replicas. When unset, an in-process dict is used (single-replica only).")) + disable_legacy_routes: bool = Field( default=False, description="Disable the legacy routes for the FastAPI app. If True, the legacy routes are disabled.") diff --git a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py index e9afa4f4ea..6bd1103b1f 100644 --- a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +++ b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py @@ -36,6 +36,9 @@ from nat.utils.log_utils import setup_logging from .auth_flow_handlers.http_flow_handler import HTTPAuthenticationFlowHandler +from .auth_flow_handlers.oauth_token_cache import InMemoryOAuthTokenCache +from .auth_flow_handlers.oauth_token_cache import OAuthTokenCacheBase +from .auth_flow_handlers.oauth_token_cache import ObjectStoreOAuthTokenCache from .auth_flow_handlers.websocket_flow_handler import FlowState from .execution_store import ExecutionStore from .fastapi_front_end_config import FastApiFrontEndConfig @@ -202,6 +205,10 @@ def __init__(self, config: Config): # Conversation handlers for WebSocket reconnection support self._conversation_handlers: dict[str, WebSocketMessageHandler] = {} + # OAuth token cache — replaced with an ObjectStoreOAuthTokenCache in configure() + # when front_end.oauth_token_store is set, enabling cross-replica token sharing. + self._oauth_token_cache: OAuthTokenCacheBase = InMemoryOAuthTokenCache() + # Track session managers for each route self._session_managers: list[SessionManager] = [] @@ -313,6 +320,11 @@ async def configure(self, app: FastAPI, builder: WorkflowBuilder): # Do things like setting the base URL and global configuration options app.root_path = self.front_end_config.root_path + if self.front_end_config.oauth_token_store is not None: + object_store = await builder.get_object_store_client(self.front_end_config.oauth_token_store) + self._oauth_token_cache = ObjectStoreOAuthTokenCache(object_store) + logger.debug("OAuth token cache backed by object store '%s'", self.front_end_config.oauth_token_store) + # Initialize evaluators for single-item evaluation # TODO: we need config control over this as it's not always needed await self.initialize_evaluators(self._config) diff --git a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_cancelled.py b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_cancelled.py new file mode 100644 index 0000000000..d0f47cf67c --- /dev/null +++ b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_cancelled.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +AUTH_REDIRECT_CANCELLED_POPUP_HTML = """ + + + + Authorization Cancelled + + + +

Authorization cancelled. You may now close this window.

+ + +""" + +_AUTH_REDIRECT_CANCELLED_HTML_TEMPLATE = """\ + + + + Authorization Cancelled + + + +

Authorization cancelled. Redirecting…

+ + +""" + + +def build_auth_redirect_cancelled_html(return_url: str | None = None) -> str: + """Build the same-page authorization-cancelled HTML page. + + Args: + return_url: The URL to redirect to after cancellation. When + provided the page navigates there immediately with an ``oauth_auth_error`` + query parameter so the UI can detect the cancellation and avoid a + pre-auth redirect loop; otherwise it falls back to ``window.history.back()``. + + Returns: + An HTML string for the post-cancellation redirect page. + """ + safe_json = json.dumps(return_url).replace('<', '\\u003c').replace('>', '\\u003e').replace('/', '\\u002f') + return _AUTH_REDIRECT_CANCELLED_HTML_TEMPLATE.replace("RETURN_URL_PLACEHOLDER", safe_json) diff --git a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_error.py b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_error.py new file mode 100644 index 0000000000..1a9812139e --- /dev/null +++ b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_error.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +AUTH_REDIRECT_ERROR_HTML = """ + + + + Authentication Error + + + +

Authentication failed. You may now close this window.

+ + +""" + +_AUTH_REDIRECT_ERROR_HTML_TEMPLATE = """\ + + + + Authentication Error + + + +

Authentication failed. Redirecting…

+ + +""" + + +def build_auth_redirect_error_html(return_url: str | None = None) -> str: + """Build the redirect-based authentication error HTML page. + + Navigates back to the UI without the ``oauth_auth_completed`` query + parameter so the UI's error-message branch handles it. + + Args: + return_url: The UI origin to navigate back to. Falls back to + ``window.history.back()`` when not provided. + + Returns: + An HTML string for the post-error redirect page. + """ + safe_json = json.dumps(return_url).replace('<', '\\u003c').replace('>', '\\u003e').replace('/', '\\u002f') + return _AUTH_REDIRECT_ERROR_HTML_TEMPLATE.replace("RETURN_URL_PLACEHOLDER", safe_json) diff --git a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py index a14f18df98..4c8db89e11 100644 --- a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py +++ b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/html_snippets/auth_code_grant_success.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + AUTH_REDIRECT_SUCCESS_HTML = """ @@ -33,3 +35,43 @@ """ + +_AUTH_REDIRECT_SUCCESS_HTML_REDIRECT_TEMPLATE = """\ + + + + Authentication Complete + + + +

Authentication complete. Redirecting…

+ + +""" + + +def build_auth_redirect_success_html(return_url: str | None = None) -> str: + """Build the redirect-based authentication success HTML page. + + Args: + return_url: The URL to redirect to after successful authentication. When + provided the page navigates there immediately with an ``oauth_auth_completed`` + query parameter so the UI can distinguish a successful return from the user + pressing back; otherwise it falls back to ``window.history.back()``. + + Returns: + An HTML string for the post-authentication redirect page. + """ + safe_json = json.dumps(return_url).replace('<', '\\u003c').replace('>', '\\u003e').replace('/', '\\u002f') + return _AUTH_REDIRECT_SUCCESS_HTML_REDIRECT_TEMPLATE.replace("RETURN_URL_PLACEHOLDER", safe_json) diff --git a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/routes/auth.py b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/routes/auth.py index ce1fc01bba..89c2cf4640 100644 --- a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/routes/auth.py +++ b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/routes/auth.py @@ -23,7 +23,12 @@ from fastapi import Request from fastapi.responses import HTMLResponse +from nat.front_ends.fastapi.html_snippets.auth_code_grant_cancelled import AUTH_REDIRECT_CANCELLED_POPUP_HTML +from nat.front_ends.fastapi.html_snippets.auth_code_grant_cancelled import build_auth_redirect_cancelled_html +from nat.front_ends.fastapi.html_snippets.auth_code_grant_error import AUTH_REDIRECT_ERROR_HTML +from nat.front_ends.fastapi.html_snippets.auth_code_grant_error import build_auth_redirect_error_html from nat.front_ends.fastapi.html_snippets.auth_code_grant_success import AUTH_REDIRECT_SUCCESS_HTML +from nat.front_ends.fastapi.html_snippets.auth_code_grant_success import build_auth_redirect_success_html if TYPE_CHECKING: from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker @@ -37,6 +42,7 @@ async def add_authorization_route(worker: "FastApiFrontEndPluginWorker", app: Fa async def redirect_uri(request: Request): """Handle the redirect URI for OAuth2 authentication.""" state = request.query_params.get("state") + error = request.query_params.get("error") async with worker._outstanding_flows_lock: if not state or state not in worker._outstanding_flows: @@ -44,6 +50,43 @@ async def redirect_uri(request: Request): flow_state = worker._outstanding_flows[state] + # The OAuth provider returned an error in the redirect. Two distinct cases: + # - "access_denied": user explicitly declined consent → cancellation UX. + # - anything else: provider/server error → error UX. + if error: + error_description = request.query_params.get("error_description", "") + await worker._remove_flow(state) + + if error == "access_denied": + logger.info("OAuth authorisation denied for state %s: %s (%s)", state, error, error_description) + if not flow_state.future.done(): + flow_state.future.set_exception( + RuntimeError(f"Authorisation denied: {error} ({error_description})")) + if flow_state.config and flow_state.config.use_redirect_auth: + return HTMLResponse(content=build_auth_redirect_cancelled_html(flow_state.return_url), + status_code=200, + headers={ + "Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache" + }) + return HTMLResponse(content=AUTH_REDIRECT_CANCELLED_POPUP_HTML, + status_code=200, + headers={ + "Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache" + }) + + logger.error("OAuth error for state %s: %s (%s)", state, error, error_description) + if not flow_state.future.done(): + flow_state.future.set_exception(RuntimeError(f"OAuth error: {error} ({error_description})")) + if flow_state.config and flow_state.config.use_redirect_auth: + error_html = build_auth_redirect_error_html(flow_state.return_url) + else: + error_html = AUTH_REDIRECT_ERROR_HTML + return HTMLResponse(content=error_html, + status_code=200, + headers={ + "Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache" + }) + config = flow_state.config verifier = flow_state.verifier client = flow_state.client @@ -60,27 +103,49 @@ async def redirect_uri(request: Request): if not flow_state.future.done(): flow_state.future.set_exception( RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})")) - return HTMLResponse(f"Authorization failed: {e.error}", - status_code=502, - headers={"Cache-Control": "no-cache"}) + if flow_state.config and flow_state.config.use_redirect_auth: + error_html = build_auth_redirect_error_html(flow_state.return_url) + else: + error_html = AUTH_REDIRECT_ERROR_HTML + return HTMLResponse(content=error_html, + status_code=200, + headers={ + "Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache" + }) except httpx.HTTPError as e: logger.error("Network error during token fetch for state %s: %s", state, e) if not flow_state.future.done(): flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}")) - return HTMLResponse("Network error during token exchange. Please try again.", - status_code=502, - headers={"Cache-Control": "no-cache"}) + if flow_state.config and flow_state.config.use_redirect_auth: + error_html = build_auth_redirect_error_html(flow_state.return_url) + else: + error_html = AUTH_REDIRECT_ERROR_HTML + return HTMLResponse(content=error_html, + status_code=200, + headers={ + "Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache" + }) except Exception as e: logger.error("Unexpected error during authentication for state %s: %s", state, e) if not flow_state.future.done(): flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}")) - return HTMLResponse("Authentication failed. Please try again.", - status_code=500, - headers={"Cache-Control": "no-cache"}) + if flow_state.config and flow_state.config.use_redirect_auth: + error_html = build_auth_redirect_error_html(flow_state.return_url) + else: + error_html = AUTH_REDIRECT_ERROR_HTML + return HTMLResponse(content=error_html, + status_code=200, + headers={ + "Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache" + }) finally: await worker._remove_flow(state) - return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML, + if flow_state.config and flow_state.config.use_redirect_auth: + success_html = build_auth_redirect_success_html(flow_state.return_url) + else: + success_html = AUTH_REDIRECT_SUCCESS_HTML + return HTMLResponse(content=success_html, status_code=200, headers={ "Content-Type": "text/html; charset=utf-8", "Cache-Control": "no-cache" diff --git a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/routes/websocket.py b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/routes/websocket.py index 20016c6f9b..0c6e1b49b7 100644 --- a/packages/nvidia_nat_core/src/nat/front_ends/fastapi/routes/websocket.py +++ b/packages/nvidia_nat_core/src/nat/front_ends/fastapi/routes/websocket.py @@ -25,6 +25,7 @@ from nat.front_ends.fastapi.message_handler import WebSocketMessageHandler from nat.runtime.session import SESSION_COOKIE_NAME from nat.runtime.session import SessionManager +from nat.runtime.user_manager import UserManager logger = logging.getLogger(__name__) @@ -32,6 +33,25 @@ _SAFE_SESSION_ID_RE = re.compile(r'^[A-Za-z0-9\-_.~]+$') +def _is_origin_allowed(origin: str | None, allowed_origins: list[str], allow_origin_regex: str | None) -> bool: + """Return True if *origin* should be treated as an allowed CORS origin. + + Mirrors the three-tier check used by Starlette's CORSMiddleware: + 1. Wildcard ``"*"`` in *allowed_origins* accepts any non-empty origin. + 2. Exact membership in *allowed_origins*. + 3. Full-string match against *allow_origin_regex* (when set). + """ + if not origin: + return False + if "*" in allowed_origins: + return True + if origin in allowed_origins: + return True + if allow_origin_regex and re.fullmatch(allow_origin_regex, origin): + return True + return False + + def websocket_endpoint(*, worker: Any, session_manager: SessionManager): """Build websocket endpoint handler with auth-flow integration.""" @@ -72,8 +92,24 @@ async def _websocket_endpoint(websocket: WebSocket): websocket.scope["headers"] = headers async with WebSocketMessageHandler(websocket, session_manager, worker.get_step_adaptor(), worker) as handler: - flow_handler = WebSocketAuthenticationFlowHandler(worker._add_flow, worker._remove_flow, handler) + origin = websocket.headers.get("origin") + allowed_origins = worker.front_end_config.cors.allow_origins or [] + allow_origin_regex = worker.front_end_config.cors.allow_origin_regex + return_url = origin if _is_origin_allowed(origin, allowed_origins, allow_origin_regex) else None + nat_session_id = UserManager._get_session_cookie(websocket) + flow_handler = WebSocketAuthenticationFlowHandler(worker._add_flow, + worker._remove_flow, + handler, + return_url=return_url, + token_cache=worker._oauth_token_cache, + session_id=nat_session_id) handler.set_flow_handler(flow_handler) + skip_eager_auth = websocket.query_params.get("skip_eager_auth") == "true" + if not skip_eager_auth: + try: + await flow_handler.run_eager_auth(worker._config.authentication) + except Exception as e: + logger.info("Pre-authentication did not complete: %s", e) await handler.run() return _websocket_endpoint diff --git a/packages/nvidia_nat_core/tests/nat/builder/test_interactive.py b/packages/nvidia_nat_core/tests/nat/builder/test_interactive.py index 89b63070b3..c458016d1d 100644 --- a/packages/nvidia_nat_core/tests/nat/builder/test_interactive.py +++ b/packages/nvidia_nat_core/tests/nat/builder/test_interactive.py @@ -25,6 +25,7 @@ from nat.data_models.interactive import HumanPromptText from nat.data_models.interactive import HumanResponseText from nat.data_models.interactive import InteractionPrompt +from nat.data_models.interactive import _HumanPromptOAuthConsent # ------------------------------------------------------------------------------ # Tests for Interactive Data Models @@ -151,3 +152,29 @@ def test_human_prompt_base_timeout_validation_gt_zero(): HumanPromptText(text="x", required=True, timeout=0) with pytest.raises(ValidationError): HumanPromptText(text="x", required=True, timeout=-1) + + +# ------------------------------------------------------------------------------ +# Tests for _HumanPromptOAuthConsent +# ------------------------------------------------------------------------------ + + +def test_human_prompt_oauth_consent_defaults(): + """_HumanPromptOAuthConsent defaults: input_type is OAUTH_CONSENT and use_redirect is False.""" + prompt = _HumanPromptOAuthConsent(text="https://auth.example.com/authorize") + assert prompt.input_type == HumanPromptModelType.OAUTH_CONSENT + assert prompt.use_redirect is False + + +def test_human_prompt_oauth_consent_use_redirect_true(): + """_HumanPromptOAuthConsent accepts use_redirect=True for redirect-based auth flow.""" + prompt = _HumanPromptOAuthConsent(text="https://auth.example.com/authorize", use_redirect=True) + assert prompt.use_redirect is True + assert prompt.input_type == HumanPromptModelType.OAUTH_CONSENT + + +def test_human_prompt_oauth_consent_text_preserved(): + """_HumanPromptOAuthConsent stores the authorization URL in the text field.""" + url = "https://auth.example.com/authorize?client_id=abc&state=xyz" + prompt = _HumanPromptOAuthConsent(text=url) + assert prompt.text == url diff --git a/packages/nvidia_nat_core/tests/nat/front_ends/auth_flow_handlers/test_websocket_flow_handler.py b/packages/nvidia_nat_core/tests/nat/front_ends/auth_flow_handlers/test_websocket_flow_handler.py index cf96b5ad14..c3f87aa690 100644 --- a/packages/nvidia_nat_core/tests/nat/front_ends/auth_flow_handlers/test_websocket_flow_handler.py +++ b/packages/nvidia_nat_core/tests/nat/front_ends/auth_flow_handlers/test_websocket_flow_handler.py @@ -14,6 +14,7 @@ # limitations under the License. import socket +import time from urllib.parse import parse_qs from urllib.parse import urlparse @@ -22,9 +23,12 @@ from httpx import ASGITransport from mock_oauth2_server import MockOAuth2Server +from nat.authentication.api_key.api_key_auth_provider_config import APIKeyAuthProviderConfig from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig +from nat.data_models.authentication import AuthenticatedContext from nat.data_models.authentication import AuthFlowType from nat.data_models.config import Config +from nat.front_ends.fastapi.auth_flow_handlers.oauth_token_cache import InMemoryOAuthTokenCache from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import WebSocketAuthenticationFlowHandler from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker from nat.test.functions import EchoFunctionConfig @@ -39,6 +43,33 @@ def _free_port() -> int: return s.getsockname()[1] +async def _complete_oauth_redirect(auth_url: str, mock_server: MockOAuth2Server, outstanding_flows: dict): + """Hit the mock OAuth server, parse the redirect, fetch a token, and resolve the flow future.""" + async with httpx.AsyncClient( + transport=ASGITransport(app=mock_server._app), + base_url="http://testserver", + follow_redirects=False, + timeout=10, + ) as client: + r = await client.get(auth_url) + assert r.status_code == 302 + redirect_url = r.headers["location"] + + qs = parse_qs(urlparse(redirect_url).query) + code = qs["code"][0] + state = qs["state"][0] + + flow_state = outstanding_flows[state] + token = await flow_state.client.fetch_token( + url=flow_state.config.token_url, + code=code, + code_verifier=flow_state.verifier, + state=state, + ) + flow_state.future.set_result(token) + return flow_state + + class _AuthHandler(WebSocketAuthenticationFlowHandler): """ Override just one factory so the OAuth2 client talks to our in‑process @@ -107,6 +138,7 @@ async def test_websocket_oauth2_flow(monkeypatch, mock_server, tmp_path): # ----------------- dummy WebSocket “UI” handler --------------------- # opened: list[str] = [] + received_messages: list = [] class _DummyWSHandler: # minimal stand‑in for the UI layer @@ -115,32 +147,8 @@ def set_flow_handler(self, _): # called by worker – ignore async def create_websocket_message(self, msg): opened.append(msg.text) # record the auth URL - - # 1) ── Hit /oauth/authorize on the mock server ─────────── # - async with httpx.AsyncClient( - transport=ASGITransport(app=mock_server._app), - base_url="http://testserver", - follow_redirects=False, - timeout=10, - ) as client: - r = await client.get(msg.text) - assert r.status_code == 302 - redirect_url = r.headers["location"] - - # 2) ── Extract `code` and `state` from redirect URL ─────── # - qs = parse_qs(urlparse(redirect_url).query) - code = qs["code"][0] - state = qs["state"][0] - - # 3) ── Fetch token directly & resolve the Future in‑loop ── # - flow_state = worker._outstanding_flows[state] - token = await flow_state.client.fetch_token( - url=flow_state.config.token_url, - code=code, - code_verifier=flow_state.verifier, - state=state, - ) - flow_state.future.set_result(token) + received_messages.append(msg) + await _complete_oauth_redirect(msg.text, mock_server, worker._outstanding_flows) # ----------------- authentication handler instance ------------------ # ws_handler = _AuthHandler( @@ -168,6 +176,7 @@ async def create_websocket_message(self, msg): # ----------------- assertions -------------------------------------- # assert opened, "The authorization URL was never emitted." + assert received_messages[0].use_redirect is False, "Default use_redirect_auth should emit use_redirect=False" token_val = ctx.headers["Authorization"].split()[1] assert token_val in mock_server.tokens, "token not issued by mock server" @@ -175,6 +184,112 @@ async def create_websocket_message(self, msg): assert worker._outstanding_flows == {} +# --------------------------------------------------------------------------- # +# use_redirect_auth=True test # +# --------------------------------------------------------------------------- # +@pytest.mark.usefixtures("set_nat_config_file_env_var") +async def test_websocket_oauth2_flow_no_popup(monkeypatch, mock_server, tmp_path): + """Verify that use_redirect_auth=True sends use_redirect=True in the consent prompt and + propagates return_url into FlowState.""" + + redirect_port = _free_port() + + mock_server.register_client( + client_id="cid", + client_secret="secret", + redirect_base=f"http://localhost:{redirect_port}", + ) + + cfg_nat = Config(workflow=EchoFunctionConfig()) + worker = FastApiFrontEndPluginWorker(cfg_nat) + add_flow = worker._add_flow + remove_flow = worker._remove_flow + + received_messages: list = [] + captured_flow_states: list = [] + + class _DummyWSHandler: + + def set_flow_handler(self, _): + return + + async def create_websocket_message(self, msg): + received_messages.append(msg) + flow_state = await _complete_oauth_redirect(msg.text, mock_server, worker._outstanding_flows) + captured_flow_states.append(flow_state) + + ws_handler = _AuthHandler( + oauth_server=mock_server, + add_flow_cb=add_flow, + remove_flow_cb=remove_flow, + web_socket_message_handler=_DummyWSHandler(), + return_url="http://localhost:3000", + ) + + cfg_flow = OAuth2AuthCodeFlowProviderConfig( + client_id="cid", + client_secret="secret", + authorization_url="http://testserver/oauth/authorize", + token_url="http://testserver/oauth/token", + scopes=["read"], + use_pkce=True, + redirect_uri=f"http://localhost:{redirect_port}/auth/redirect", + use_redirect_auth=True, + ) + + monkeypatch.setattr("click.echo", lambda *_: None, raising=True) + + ctx = await ws_handler.authenticate(cfg_flow, AuthFlowType.OAUTH2_AUTHORIZATION_CODE) + + assert received_messages, "The authorization URL was never emitted." + assert received_messages[0].use_redirect is True, "use_redirect_auth=True should emit use_redirect=True" + assert captured_flow_states[0].return_url == "http://localhost:3000" + token_val = ctx.headers["Authorization"].split()[1] + assert token_val in mock_server.tokens, "token not issued by mock server" + assert worker._outstanding_flows == {} + + +# --------------------------------------------------------------------------- # +# use_redirect_auth=True without return_url guard # +# --------------------------------------------------------------------------- # +@pytest.mark.usefixtures("set_nat_config_file_env_var") +async def test_websocket_oauth2_flow_redirect_without_return_url(monkeypatch): + """Verify that use_redirect_auth=True with no return_url raises ValueError immediately.""" + + cfg_nat = Config(workflow=EchoFunctionConfig()) + worker = FastApiFrontEndPluginWorker(cfg_nat) + + class _DummyWSHandler: + + def set_flow_handler(self, _): + return + + async def create_websocket_message(self, msg): + pass + + ws_handler = WebSocketAuthenticationFlowHandler( + add_flow_cb=worker._add_flow, + remove_flow_cb=worker._remove_flow, + web_socket_message_handler=_DummyWSHandler(), + return_url=None, + ) + + cfg_flow = OAuth2AuthCodeFlowProviderConfig( + client_id="cid", + client_secret="secret", + authorization_url="http://testserver/oauth/authorize", + token_url="http://testserver/oauth/token", + scopes=["read"], + redirect_uri="http://localhost:8000/auth/redirect", + use_redirect_auth=True, + ) + + monkeypatch.setattr("click.echo", lambda *_: None, raising=True) + + with pytest.raises(ValueError, match="return URL"): + await ws_handler.authenticate(cfg_flow, AuthFlowType.OAUTH2_AUTHORIZATION_CODE) + + # --------------------------------------------------------------------------- # # Error Recovery Tests # # --------------------------------------------------------------------------- # @@ -223,3 +338,212 @@ async def create_websocket_message(self, msg): # Verify timeout RuntimeError is raised (demonstrates partial error handling) error_message = str(exc_info.value) assert "Authentication flow timed out" in error_message + + +# --------------------------------------------------------------------------- # +# Token-cache unit tests # +# --------------------------------------------------------------------------- # + + +@pytest.fixture(name="noop_handler") +def noop_handler_fixture(): + """Minimal handler with no-op callbacks for cache unit tests.""" + + async def _noop_add(state, flow_state): + pass + + async def _noop_remove(state): + pass + + class _NullWSHandler: + + async def create_websocket_message(self, msg): + pass + + return WebSocketAuthenticationFlowHandler( + add_flow_cb=_noop_add, + remove_flow_cb=_noop_remove, + web_socket_message_handler=_NullWSHandler(), + ) + + +@pytest.fixture(name="minimal_oauth_config") +def minimal_oauth_config_fixture(): + return OAuth2AuthCodeFlowProviderConfig( + client_id="test-client", + client_secret="test-secret", + authorization_url="http://auth.example.com/authorize", + token_url="http://auth.example.com/token", + redirect_uri="http://localhost:8000/callback", + ) + + +def test_token_cache_key_returns_none_without_required_attrs(noop_handler, minimal_oauth_config): + """_token_cache_key returns None when session_id or token_cache is absent.""" + noop_handler._token_cache = InMemoryOAuthTokenCache() + noop_handler._session_id = None + assert noop_handler._token_cache_key(minimal_oauth_config) is None + + noop_handler._token_cache = None + noop_handler._session_id = "sess-1" + assert noop_handler._token_cache_key(minimal_oauth_config) is None + + +def test_token_cache_key_format(noop_handler, minimal_oauth_config): + """_token_cache_key returns '{session_id}:{client_id}:{token_url}' when both are present.""" + noop_handler._token_cache = InMemoryOAuthTokenCache() + noop_handler._session_id = "sess-1" + key = noop_handler._token_cache_key(minimal_oauth_config) + assert key == f"sess-1:{minimal_oauth_config.client_id}:{minimal_oauth_config.token_url}" + + +async def test_get_cached_token_miss(noop_handler, minimal_oauth_config): + """_get_cached_token returns None when the cache has no entry for the config.""" + noop_handler._token_cache = InMemoryOAuthTokenCache() + noop_handler._session_id = "sess-1" + assert await noop_handler._get_cached_token(minimal_oauth_config) is None + + +@pytest.mark.parametrize("expires_at,expect_hit", + [ + pytest.param(None, True, id="no_expiry"), + pytest.param(time.time() + 3600, True, id="future"), + pytest.param(time.time() - 1, False, id="past"), + pytest.param(time.time() + 30, False, id="within_buffer"), + ]) +async def test_get_cached_token_expiry(noop_handler, minimal_oauth_config, expires_at, expect_hit): + """_get_cached_token returns the context when valid and evicts it when expired or within the 60s buffer.""" + ctx = AuthenticatedContext(headers={"Authorization": "Bearer tok"}, metadata={}) + cache = InMemoryOAuthTokenCache() + noop_handler._token_cache = cache + noop_handler._session_id = "sess-1" + key = noop_handler._token_cache_key(minimal_oauth_config) + cache._store[key] = (ctx, expires_at) + result = await noop_handler._get_cached_token(minimal_oauth_config) + if expect_hit: + assert result is ctx + else: + assert result is None + assert key not in cache._store + + +async def test_store_token_writes_correctly(noop_handler, minimal_oauth_config): + """_store_token writes the context to the cache under the expected key.""" + cache = InMemoryOAuthTokenCache() + noop_handler._token_cache = cache + noop_handler._session_id = "sess-1" + expires = 9999999999.0 + ctx = AuthenticatedContext(headers={"Authorization": "Bearer tok"}, metadata={"expires_at": expires}) + await noop_handler._store_token(minimal_oauth_config, ctx) + key = noop_handler._token_cache_key(minimal_oauth_config) + assert key in cache._store + stored_ctx, stored_expires = cache._store[key] + assert stored_ctx is ctx + assert stored_expires == expires + + +# --------------------------------------------------------------------------- # +# Token-cache integration: second authenticate() returns from cache # +# --------------------------------------------------------------------------- # +@pytest.mark.usefixtures("set_nat_config_file_env_var") +async def test_authenticate_second_call_uses_cache(monkeypatch, mock_server, tmp_path): + """After a successful flow the token is cached; a second call must not trigger OAuth again.""" + + redirect_port = _free_port() + mock_server.register_client( + client_id="cid", + client_secret="secret", + redirect_base=f"http://localhost:{redirect_port}", + ) + + cfg_nat = Config(workflow=EchoFunctionConfig()) + worker = FastApiFrontEndPluginWorker(cfg_nat) + message_count = [0] + + class _DummyWSHandler: + + def set_flow_handler(self, _): + return + + async def create_websocket_message(self, msg): + message_count[0] += 1 + await _complete_oauth_redirect(msg.text, mock_server, worker._outstanding_flows) + + token_cache = InMemoryOAuthTokenCache() + ws_handler = _AuthHandler( + oauth_server=mock_server, + add_flow_cb=worker._add_flow, + remove_flow_cb=worker._remove_flow, + web_socket_message_handler=_DummyWSHandler(), + token_cache=token_cache, + session_id="test-session", + ) + + cfg_flow = OAuth2AuthCodeFlowProviderConfig( + client_id="cid", + client_secret="secret", + authorization_url="http://testserver/oauth/authorize", + token_url="http://testserver/oauth/token", + scopes=["read"], + use_pkce=True, + redirect_uri=f"http://localhost:{redirect_port}/auth/redirect", + ) + + monkeypatch.setattr("click.echo", lambda *_: None, raising=True) + + ctx1 = await ws_handler.authenticate(cfg_flow, AuthFlowType.OAUTH2_AUTHORIZATION_CODE) + assert message_count[0] == 1, "OAuth flow should have run exactly once" + assert token_cache._store, "Token must be stored after first auth" + + ctx2 = await ws_handler.authenticate(cfg_flow, AuthFlowType.OAUTH2_AUTHORIZATION_CODE) + assert message_count[0] == 1, "Second authenticate() must return from cache without triggering OAuth" + assert ctx2.headers["Authorization"] == ctx1.headers["Authorization"] + + +# --------------------------------------------------------------------------- # +# run_eager_auth tests # +# --------------------------------------------------------------------------- # +async def test_run_eager_auth_skips_non_oauth2_providers(noop_handler): + """run_eager_auth is a no-op for non-OAuth2 providers such as APIKeyAuthProviderConfig.""" + api_key_config = APIKeyAuthProviderConfig(raw_key="my-api-key-value") + await noop_handler.run_eager_auth({"my_api_key": api_key_config}) + + +async def test_run_eager_auth_skips_oauth2_provider_flag_false(noop_handler, minimal_oauth_config): + """run_eager_auth does not trigger auth for OAuth2 providers with use_eager_auth=False (the default).""" + # minimal_oauth_config has use_eager_auth=False (the default); if the guard were absent this would hang + await noop_handler.run_eager_auth({"my_provider": minimal_oauth_config}) + + +async def test_run_eager_auth_uses_cached_token(minimal_oauth_config): + """run_eager_auth returns immediately without calling create_websocket_message on a cache hit.""" + + async def _noop_add(state, flow_state): + pass + + async def _noop_remove(state): + pass + + message_count = [0] + + class _CountingWSHandler: + + async def create_websocket_message(self, msg): + message_count[0] += 1 + + ctx = AuthenticatedContext(headers={"Authorization": "Bearer cached-tok"}, metadata={"expires_at": None}) + cache = InMemoryOAuthTokenCache() + # Enable use_eager_auth so the cache lookup is actually reached + active_config = minimal_oauth_config.model_copy(update={"use_eager_auth": True}) + handler = WebSocketAuthenticationFlowHandler( + add_flow_cb=_noop_add, + remove_flow_cb=_noop_remove, + web_socket_message_handler=_CountingWSHandler(), + token_cache=cache, + session_id="sess-1", + ) + key = handler._token_cache_key(active_config) + cache._store[key] = (ctx, time.time() + 3600) + + await handler.run_eager_auth({"my_provider": active_config}) + assert message_count[0] == 0, "run_eager_auth must not trigger OAuth when token is cached" diff --git a/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_code_grant_cancelled.py b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_code_grant_cancelled.py new file mode 100644 index 0000000000..98bd896134 --- /dev/null +++ b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_code_grant_cancelled.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from nat.front_ends.fastapi.html_snippets.auth_code_grant_cancelled import AUTH_REDIRECT_CANCELLED_POPUP_HTML +from nat.front_ends.fastapi.html_snippets.auth_code_grant_cancelled import build_auth_redirect_cancelled_html + + +def _safe_json(value: str | None) -> str: + """Replicate the HTML-safe JSON encoding used by the snippet builders.""" + return json.dumps(value).replace('<', '\\u003c').replace('>', '\\u003e').replace('/', '\\u002f') + + +def test_auth_redirect_cancelled_popup_html_notifies_and_closes(): + """AUTH_REDIRECT_CANCELLED_POPUP_HTML (popup variant) posts AUTH_CANCELLED and closes the window.""" + assert "AUTH_CANCELLED" in AUTH_REDIRECT_CANCELLED_POPUP_HTML + assert "window.opener?.postMessage" in AUTH_REDIRECT_CANCELLED_POPUP_HTML + assert "window.close()" in AUTH_REDIRECT_CANCELLED_POPUP_HTML + + +def test_build_auth_redirect_cancelled_html_with_return_url(): + """build_auth_redirect_cancelled_html embeds a safe JSON-encoded return URL in the script.""" + return_url = "http://localhost:3000" + result = build_auth_redirect_cancelled_html(return_url) + assert _safe_json(return_url) in result + assert "window.location.replace" in result + assert "window.history.back()" in result + + +def test_build_auth_redirect_cancelled_html_without_return_url(): + """build_auth_redirect_cancelled_html falls back to window.history.back() when return_url is None.""" + result = build_auth_redirect_cancelled_html(None) + assert _safe_json(None) in result + assert "window.history.back()" in result + + +def test_build_auth_redirect_cancelled_html_no_oauth_auth_completed_param(): + """build_auth_redirect_cancelled_html must NOT add oauth_auth_completed so the UI handles cancellation.""" + result = build_auth_redirect_cancelled_html("http://localhost:3000") + assert "oauth_auth_completed" not in result + + +def test_build_auth_redirect_cancelled_html_url_characters_escaped(): + """build_auth_redirect_cancelled_html HTML-escapes <, >, and / in the JSON value.""" + return_url = "http://example.com/path?foo=bar&baz=" + result = build_auth_redirect_cancelled_html(return_url) + assert _safe_json(return_url) in result + assert "\\u003c" in result + assert "\\u003e" in result + assert "\\u002f" in result + assert "" not in result + + +def test_build_auth_redirect_cancelled_html_script_tag_cannot_break_out(): + """A sequence in the URL cannot terminate the enclosing script block.""" + return_url = "http://evil.com/" + result = build_auth_redirect_cancelled_html(return_url) + # The injected is escaped; only the template's own closing tag remains + assert result.count("") == 1 + assert _safe_json(return_url) in result diff --git a/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_code_grant_error.py b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_code_grant_error.py new file mode 100644 index 0000000000..f990a44962 --- /dev/null +++ b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_code_grant_error.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from nat.front_ends.fastapi.html_snippets.auth_code_grant_error import AUTH_REDIRECT_ERROR_HTML +from nat.front_ends.fastapi.html_snippets.auth_code_grant_error import build_auth_redirect_error_html + + +def _safe_json(value: str | None) -> str: + """Replicate the HTML-safe JSON encoding used by the snippet builders.""" + return json.dumps(value).replace('<', '\\u003c').replace('>', '\\u003e').replace('/', '\\u002f') + + +def test_auth_redirect_error_html_popup_notifies_and_closes(): + """AUTH_REDIRECT_ERROR_HTML (popup variant) posts AUTH_ERROR and closes the window.""" + assert "AUTH_ERROR" in AUTH_REDIRECT_ERROR_HTML + assert "window.opener?.postMessage" in AUTH_REDIRECT_ERROR_HTML + assert "window.close()" in AUTH_REDIRECT_ERROR_HTML + + +def test_build_auth_redirect_error_html_with_return_url(): + """build_auth_redirect_error_html embeds a safe JSON-encoded return URL in the script.""" + return_url = "http://localhost:3000" + result = build_auth_redirect_error_html(return_url) + assert _safe_json(return_url) in result + assert "window.location.replace" in result + assert "window.history.back()" in result + + +def test_build_auth_redirect_error_html_without_return_url(): + """build_auth_redirect_error_html falls back to window.history.back() when return_url is None.""" + result = build_auth_redirect_error_html(None) + assert _safe_json(None) in result + assert "window.history.back()" in result + + +def test_build_auth_redirect_error_html_no_oauth_auth_completed_param(): + """build_auth_redirect_error_html must NOT add oauth_auth_completed so the UI handles the error.""" + result = build_auth_redirect_error_html("http://localhost:3000") + assert "oauth_auth_completed" not in result + + +def test_build_auth_redirect_error_html_url_characters_escaped(): + """build_auth_redirect_error_html HTML-escapes <, >, and / in the JSON value.""" + return_url = "http://example.com/path?foo=bar&baz=" + result = build_auth_redirect_error_html(return_url) + assert _safe_json(return_url) in result + assert "\\u003c" in result + assert "\\u003e" in result + assert "\\u002f" in result + assert "" not in result + + +def test_build_auth_redirect_error_html_script_tag_cannot_break_out(): + """A sequence in the URL cannot terminate the enclosing script block.""" + return_url = "http://evil.com/" + result = build_auth_redirect_error_html(return_url) + assert result.count("") == 1 + assert _safe_json(return_url) in result diff --git a/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_code_grant_success.py b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_code_grant_success.py new file mode 100644 index 0000000000..8c93364b63 --- /dev/null +++ b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_code_grant_success.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from nat.front_ends.fastapi.html_snippets.auth_code_grant_success import AUTH_REDIRECT_SUCCESS_HTML +from nat.front_ends.fastapi.html_snippets.auth_code_grant_success import build_auth_redirect_success_html + + +def _safe_json(value: str | None) -> str: + """Replicate the HTML-safe JSON encoding used by the snippet builders.""" + return json.dumps(value).replace('<', '\\u003c').replace('>', '\\u003e').replace('/', '\\u002f') + + +def test_auth_redirect_success_html_popup_closes_window(): + """AUTH_REDIRECT_SUCCESS_HTML (popup variant) notifies opener and closes the window.""" + assert "window.opener?.postMessage" in AUTH_REDIRECT_SUCCESS_HTML + assert "window.close()" in AUTH_REDIRECT_SUCCESS_HTML + + +def test_build_auth_redirect_success_html_with_return_url(): + """build_auth_redirect_success_html embeds a safe JSON-encoded return URL in the script.""" + return_url = "http://localhost:3000" + result = build_auth_redirect_success_html(return_url) + assert _safe_json(return_url) in result + assert "window.location.replace" in result + assert "window.history.back()" in result + assert "oauth_auth_completed" in result + + +def test_build_auth_redirect_success_html_without_return_url(): + """build_auth_redirect_success_html uses window.history.back() when return_url is None.""" + result = build_auth_redirect_success_html(None) + assert _safe_json(None) in result + assert "window.history.back()" in result + + +def test_build_auth_redirect_success_html_url_characters_escaped(): + """build_auth_redirect_success_html HTML-escapes <, >, and / in the JSON value.""" + return_url = "http://example.com/path?foo=bar&baz=" + result = build_auth_redirect_success_html(return_url) + assert _safe_json(return_url) in result + assert "\\u003c" in result + assert "\\u003e" in result + assert "\\u002f" in result + assert "" not in result + + +def test_build_auth_redirect_success_html_script_tag_cannot_break_out(): + """A sequence in the URL cannot terminate the enclosing script block.""" + return_url = "http://evil.com/" + result = build_auth_redirect_success_html(return_url) + # The injected is escaped; only the template's own closing tag remains + assert result.count("") == 1 + assert _safe_json(return_url) in result diff --git a/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_redirect_route.py b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_redirect_route.py new file mode 100644 index 0000000000..49df6a3acc --- /dev/null +++ b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_auth_redirect_route.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from fastapi import FastAPI +from httpx import ASGITransport +from httpx import AsyncClient + +from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig +from nat.front_ends.fastapi.auth_flow_handlers.websocket_flow_handler import FlowState +from nat.front_ends.fastapi.routes.auth import add_authorization_route + +_CALLBACK_PATH = "/auth/redirect" +_RETURN_URL = "http://localhost:3000" + +_POPUP_CONFIG = OAuth2AuthCodeFlowProviderConfig( + client_id="cid", + client_secret="secret", + authorization_url="http://testserver/oauth/authorize", + token_url="http://testserver/oauth/token", + redirect_uri="http://localhost:8000/auth/redirect", + use_redirect_auth=False, +) + +_REDIRECT_CONFIG = OAuth2AuthCodeFlowProviderConfig( + client_id="cid", + client_secret="secret", + authorization_url="http://testserver/oauth/authorize", + token_url="http://testserver/oauth/token", + redirect_uri="http://localhost:8000/auth/redirect", + use_redirect_auth=True, +) + + +def _make_worker(flow_state: FlowState): + flows = {"teststate": flow_state} + + class _Worker: + front_end_config = type("cfg", (), {"oauth2_callback_path": _CALLBACK_PATH})() + _outstanding_flows_lock = asyncio.Lock() + _outstanding_flows = flows + + async def _remove_flow(self, state: str) -> None: + flows.pop(state, None) + + return _Worker() + + +async def _get(worker, params: dict): + app = FastAPI() + await add_authorization_route(worker, app) + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: + return await client.get(_CALLBACK_PATH, params=params) + + +async def test_invalid_state_returns_400(): + """An unknown state is rejected with 400 before any flow processing.""" + flow_state = FlowState(config=_POPUP_CONFIG) + worker = _make_worker(flow_state) + response = await _get(worker, {"state": "no-such-state", "code": "abc"}) + assert response.status_code == 400 + + +async def test_access_denied_popup_returns_cancelled_html(): + """error=access_denied in popup mode returns the cancelled popup HTML.""" + flow_state = FlowState(config=_POPUP_CONFIG, return_url=None) + worker = _make_worker(flow_state) + response = await _get(worker, {"state": "teststate", "error": "access_denied"}) + assert response.status_code == 200 + assert "AUTH_CANCELLED" in response.text + assert "AUTH_ERROR" not in response.text + + +async def test_access_denied_popup_with_return_url_still_returns_cancelled_html(): + """error=access_denied in popup mode uses popup HTML even when return_url is set.""" + flow_state = FlowState(config=_POPUP_CONFIG, return_url=_RETURN_URL) + worker = _make_worker(flow_state) + response = await _get(worker, {"state": "teststate", "error": "access_denied"}) + assert response.status_code == 200 + assert "AUTH_CANCELLED" in response.text + assert "AUTH_ERROR" not in response.text + assert _RETURN_URL.replace("/", "\\u002f") not in response.text + assert "oauth_auth_completed" not in response.text + + +async def test_access_denied_redirect_returns_cancelled_html(): + """error=access_denied in redirect mode returns the redirect-back cancelled page.""" + flow_state = FlowState(config=_REDIRECT_CONFIG, return_url=_RETURN_URL) + worker = _make_worker(flow_state) + response = await _get(worker, {"state": "teststate", "error": "access_denied"}) + assert response.status_code == 200 + assert _RETURN_URL.replace("/", "\\u002f") in response.text + assert "AUTH_CANCELLED" not in response.text + assert "oauth_auth_completed" not in response.text + + +async def test_access_denied_sets_cancellation_exception_on_future(): + """error=access_denied resolves the future with an 'Authorisation denied' exception.""" + flow_state = FlowState(config=_POPUP_CONFIG, return_url=None) + worker = _make_worker(flow_state) + await _get(worker, {"state": "teststate", "error": "access_denied"}) + assert flow_state.future.done() + assert "Authorisation denied" in str(flow_state.future.exception()) + assert "access_denied" in str(flow_state.future.exception()) + + +async def test_provider_error_popup_returns_error_html(): + """Non-access_denied errors in popup mode return error HTML, not cancelled HTML.""" + flow_state = FlowState(config=_POPUP_CONFIG, return_url=None) + worker = _make_worker(flow_state) + response = await _get(worker, {"state": "teststate", "error": "server_error"}) + assert response.status_code == 200 + assert "AUTH_ERROR" in response.text + assert "AUTH_CANCELLED" not in response.text + + +async def test_provider_error_redirect_returns_error_html(): + """Non-access_denied errors in redirect mode redirect back without oauth_auth_completed.""" + flow_state = FlowState(config=_REDIRECT_CONFIG, return_url=_RETURN_URL) + worker = _make_worker(flow_state) + response = await _get(worker, {"state": "teststate", "error": "server_error"}) + assert response.status_code == 200 + assert _RETURN_URL.replace("/", "\\u002f") in response.text + assert "AUTH_CANCELLED" not in response.text + assert "oauth_auth_completed" not in response.text + + +async def test_provider_error_sets_oauth_error_exception_on_future(): + """Non-access_denied errors resolve the future with an 'OAuth error' exception including the code.""" + flow_state = FlowState(config=_POPUP_CONFIG, return_url=None) + worker = _make_worker(flow_state) + await _get(worker, {"state": "teststate", "error": "server_error", "error_description": "internal"}) + assert flow_state.future.done() + assert "OAuth error" in str(flow_state.future.exception()) + assert "server_error" in str(flow_state.future.exception()) + assert "internal" in str(flow_state.future.exception()) diff --git a/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_websocket_route_origin.py b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_websocket_route_origin.py new file mode 100644 index 0000000000..e1f17c0a35 --- /dev/null +++ b/packages/nvidia_nat_core/tests/nat/front_ends/fastapi/test_websocket_route_origin.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nat.front_ends.fastapi.routes.websocket import _is_origin_allowed + + +@pytest.mark.parametrize( + "origin,allowed_origins,allow_origin_regex,expected", + [ + # Exact match + ("http://localhost:3000", ["http://localhost:3000"], None, True), + # Not in list, no regex + ("http://evil.com", ["http://localhost:3000"], None, False), + # Wildcard accepts any non-empty origin + ("http://anything.example.com", ["*"], None, True), + # Wildcard with multiple entries + ("http://foo.com", ["http://bar.com", "*"], None, True), + # Regex match + ("http://app.example.com", [], r"https?://[a-z]+\.example\.com", True), + # Regex no match + ("http://evil.com", [], r"https?://[a-z]+\.example\.com", False), + # Regex with partial string does not match (fullmatch) + ("http://app.example.com/extra", [], r"https?://[a-z]+\.example\.com", False), + # None origin is always rejected + (None, ["*"], None, False), + (None, ["http://localhost:3000"], r".*", False), + # Empty allowed list, no regex + ("http://localhost:3000", [], None, False), + # Regex takes precedence when list is empty + ("http://localhost:3000", [], r"http://localhost:\d+", True), + # Both list and regex configured; list matches first + ("http://localhost:3000", ["http://localhost:3000"], r"http://other\.com", True), + ]) +def test_is_origin_allowed(origin, allowed_origins, allow_origin_regex, expected): + assert _is_origin_allowed(origin, allowed_origins, allow_origin_regex) is expected