diff --git a/.env.example b/.env.example index 98ac7d108..6d9b21b3f 100644 --- a/.env.example +++ b/.env.example @@ -84,3 +84,12 @@ OPENAI_API_KEY="" KAAPI_GUARDRAILS_AUTH="" KAAPI_GUARDRAILS_URL="" + +SMTP_HOST= +SMTP_PORT= +SMTP_TLS=True +SMTP_USER= +SMTP_PASSWORD= +EMAILS_FROM_EMAIL= +EMAILS_FROM_NAME=Kaapi +FRONTEND_HOST= diff --git a/backend/app/api/docs/auth/invite_verify.md b/backend/app/api/docs/auth/invite_verify.md new file mode 100644 index 000000000..aaf397b02 --- /dev/null +++ b/backend/app/api/docs/auth/invite_verify.md @@ -0,0 +1,20 @@ +# Verify Invitation + +Verify an invitation token from a magic link email and log the user in. + +## Query Parameters + +- **token** (required): The invitation JWT token from the email link. + +## Behavior + +1. Validates the invitation token (checks signature, expiry, and type). +2. Looks up the user by the email embedded in the token. +3. If the user exists and is inactive (first login), activates the account. +4. Returns a JWT access token with the org/project from the invitation embedded. +5. Sets `access_token` and `refresh_token` as HTTP-only cookies. + +## Error Responses + +- **400**: Invalid or expired invitation link. +- **404**: User account not found. diff --git a/backend/app/api/docs/auth/magic_link.md b/backend/app/api/docs/auth/magic_link.md new file mode 100644 index 000000000..714bb9830 --- /dev/null +++ b/backend/app/api/docs/auth/magic_link.md @@ -0,0 +1,18 @@ +# Request Magic Link Login + +Send a magic link login email to the user's email address. + +## Request Body + +- **email** (required): The user's email address. + +## Behavior + +1. Checks if the user exists — returns 404 if not. +2. Generates a short-lived login token (15 minutes). +3. Sends an email with a "Sign In Now" button linking to the frontend. + +## Error Responses + +- **404**: No account found for this email. +- **500**: Email service is not configured or failed to send. diff --git a/backend/app/api/docs/auth/magic_link_verify.md b/backend/app/api/docs/auth/magic_link_verify.md new file mode 100644 index 000000000..e4654f0d4 --- /dev/null +++ b/backend/app/api/docs/auth/magic_link_verify.md @@ -0,0 +1,20 @@ +# Verify Magic Link + +Verify a magic link login token and log the user in. + +## Query Parameters + +- **token** (required): The login JWT token from the email link. + +## Behavior + +1. Validates the magic link token (checks signature, expiry, and type). +2. Looks up the user by the email embedded in the token. +3. Verifies the user is active. +4. If the user has exactly one project, it is auto-selected and embedded in the JWT. +5. Returns a JWT access token and sets HTTP-only cookies. + +## Error Responses + +- **400**: Invalid or expired login link. +- **404**: User account not found. diff --git a/backend/app/api/docs/credentials/delete_all_by_org_project.md b/backend/app/api/docs/credentials/delete_all_by_org_project.md new file mode 100644 index 000000000..15f8e9cc7 --- /dev/null +++ b/backend/app/api/docs/credentials/delete_all_by_org_project.md @@ -0,0 +1,7 @@ +Delete all credentials for a specific organization and project. + +Permanently removes all provider credentials associated with the specified organization and project IDs. Requires superuser access. + +### Path Parameters: +- **org_id**: Organization ID +- **project_id**: Project ID diff --git a/backend/app/api/docs/credentials/delete_provider_by_org_project.md b/backend/app/api/docs/credentials/delete_provider_by_org_project.md new file mode 100644 index 000000000..ba69b89bd --- /dev/null +++ b/backend/app/api/docs/credentials/delete_provider_by_org_project.md @@ -0,0 +1,8 @@ +Delete credentials for a specific provider within an organization and project. + +Permanently removes credentials for a specific provider from the specified organization and project. Requires superuser access. + +### Path Parameters: +- **org_id**: Organization ID +- **project_id**: Project ID +- **provider**: Provider name (e.g., `openai`, `langfuse`, `google`, `sarvamai`, `elevenlabs`) diff --git a/backend/app/api/docs/credentials/get_provider.md b/backend/app/api/docs/credentials/get_provider.md index 2f3a76920..c7ec981ce 100644 --- a/backend/app/api/docs/credentials/get_provider.md +++ b/backend/app/api/docs/credentials/get_provider.md @@ -1,3 +1,3 @@ Get credentials for a specific provider. -Retrieves decrypted credentials for a specific provider (e.g., `openai`, `langfuse`) for the current organization and project. +Retrieves credentials for a specific provider (e.g., `openai`, `langfuse`) for the current organization and project. Sensitive fields (e.g., `api_key`, `secret_key`) are masked in the response. If credentials for the provider are not configured, `null` is returned. diff --git a/backend/app/api/docs/credentials/get_provider_by_org_project.md b/backend/app/api/docs/credentials/get_provider_by_org_project.md new file mode 100644 index 000000000..accb96fab --- /dev/null +++ b/backend/app/api/docs/credentials/get_provider_by_org_project.md @@ -0,0 +1,8 @@ +Get credentials for a specific provider within an organization and project. + +Retrieves credentials for a specific provider (e.g., `openai`, `langfuse`) for the specified organization and project. Sensitive fields (e.g., `api_key`, `secret_key`) are masked in the response. If credentials for the provider are not configured, `null` is returned. Requires superuser access. + +### Path Parameters: +- **org_id**: Organization ID +- **project_id**: Project ID +- **provider**: Provider name (e.g., `openai`, `langfuse`, `google`, `sarvamai`, `elevenlabs`) diff --git a/backend/app/api/docs/credentials/list.md b/backend/app/api/docs/credentials/list.md index c660229bc..ff0612661 100644 --- a/backend/app/api/docs/credentials/list.md +++ b/backend/app/api/docs/credentials/list.md @@ -1,3 +1,3 @@ Get all credentials for current organization and project. -Returns list of all provider credentials associated with your organization and project. +Returns a list of all provider credentials associated with your organization and project. Sensitive fields (e.g., `api_key`, `secret_key`) are masked in the response. If no credentials are configured, an empty list is returned. diff --git a/backend/app/api/docs/credentials/list_by_org_project.md b/backend/app/api/docs/credentials/list_by_org_project.md new file mode 100644 index 000000000..12dad77e6 --- /dev/null +++ b/backend/app/api/docs/credentials/list_by_org_project.md @@ -0,0 +1,12 @@ +Get all credentials for a specific organization and project. + +Retrieves all provider credentials associated with the specified organization and project IDs. Sensitive fields (e.g., `api_key`, `secret_key`) are masked in the response. If no credentials are configured, an empty list is returned. Requires superuser access. + +### Path Parameters: +- **org_id**: Organization ID +- **project_id**: Project ID + +### Supported Providers: +- **LLM:** openai, sarvamai, google(gemini) +- **Observability:** langfuse +- **Audio:** elevenlabs diff --git a/backend/app/api/docs/credentials/update.md b/backend/app/api/docs/credentials/update.md index 0377f0e4b..cf08360d4 100644 --- a/backend/app/api/docs/credentials/update.md +++ b/backend/app/api/docs/credentials/update.md @@ -1,3 +1,34 @@ Update credentials for a specific provider. -Updates existing provider credentials for the current organization and project. Provider and credential fields must be provided. +Updates existing provider credentials for the current organization and project. If the credentials for the specified provider don't exist yet, they will be **created** automatically (upsert behavior). The `provider` and `credential` fields are required. + +The `credential` field accepts **two formats** (both work the same): + +### Nested format (same as create endpoint): +```json +{ + "provider": "openai", + "is_active": true, + "credential": { + "openai": { + "api_key": "sk-proj-..." + } + } +} +``` + +### Flat format: +```json +{ + "provider": "openai", + "is_active": true, + "credential": { + "api_key": "sk-proj-..." + } +} +``` + +### Supported Providers: +- **LLM:** openai, sarvamai, google(gemini) +- **Observability:** langfuse +- **Audio:** elevenlabs diff --git a/backend/app/api/docs/credentials/update_by_org_project.md b/backend/app/api/docs/credentials/update_by_org_project.md new file mode 100644 index 000000000..c010871d4 --- /dev/null +++ b/backend/app/api/docs/credentials/update_by_org_project.md @@ -0,0 +1,38 @@ +Update credentials for a specific provider within an organization and project. + +Updates existing provider credentials for the specified organization and project. If the credentials for the specified provider don't exist yet, they will be **created** automatically (upsert behavior). Requires superuser access. + +### Path Parameters: +- **org_id**: Organization ID +- **project_id**: Project ID + +The `credential` field accepts **two formats** (both work the same): + +### Nested format (same as create endpoint): +```json +{ + "provider": "openai", + "is_active": true, + "credential": { + "openai": { + "api_key": "sk-proj-..." + } + } +} +``` + +### Flat format: +```json +{ + "provider": "openai", + "is_active": true, + "credential": { + "api_key": "sk-proj-..." + } +} +``` + +### Supported Providers: +- **LLM:** openai, sarvamai, google(gemini) +- **Observability:** langfuse +- **Audio:** elevenlabs diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index a7a7465b1..4c118d73c 100644 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -1,4 +1,5 @@ import logging +from typing import Any from fastapi import APIRouter, HTTPException, Request, status from fastapi.responses import JSONResponse @@ -9,9 +10,12 @@ from app.core.config import settings from app.crud import get_user_by_email from app.crud.auth import get_user_accessible_projects +from app.crud.organization import validate_organization +from app.crud.project import validate_project from app.models import ( GoogleAuthRequest, GoogleAuthResponse, + MagicLinkRequest, Message, SelectProjectRequest, Token, @@ -20,9 +24,17 @@ build_google_auth_response, build_token_response, clear_auth_cookies, + generate_magic_link_token, validate_refresh_token, + verify_invite_token, + verify_magic_link_token, +) +from app.utils import ( + APIResponse, + generate_magic_link_email, + load_description, + send_email, ) -from app.utils import APIResponse, load_description logger = logging.getLogger(__name__) @@ -198,3 +210,159 @@ def logout() -> JSONResponse: response = JSONResponse(content=api_response.model_dump()) clear_auth_cookies(response) return response + + +@router.get( + "/invite/verify", + description=load_description("auth/invite_verify.md"), + response_model=APIResponse[Token], +) +def verify_invitation(session: SessionDep, token: str) -> JSONResponse: + """Verify an invitation token, activate the user, and log them in.""" + + invite_payload = verify_invite_token(token) + if not invite_payload: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired invitation link", + ) + + # Verify the org/project referenced by the invite still exist and are active. + # Raises 404 from validate_* if missing. + validate_organization(session=session, org_id=invite_payload.organization_id) + validate_project(session=session, project_id=invite_payload.project_id) + + user = get_user_by_email(session=session, email=invite_payload.email) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User account not found. Please contact support.", + ) + + # Activate user if not already active + if not user.is_active: + user.is_active = True + session.add(user) + session.commit() + session.refresh(user) + logger.info( + f"[verify_invitation] User activated via invite | user_id: {user.id}" + ) + + response = build_token_response( + user_id=user.id, + organization_id=invite_payload.organization_id, + project_id=invite_payload.project_id, + ) + + logger.info( + f"[verify_invitation] Invitation verified | user_id: {user.id}, project_id: {invite_payload.project_id}" + ) + return response + + +@router.post( + "/magic-link", + description=load_description("auth/magic_link.md"), + response_model=APIResponse[Message], +) +def request_magic_link(session: SessionDep, body: MagicLinkRequest) -> Any: + """Send a magic link login email to the user.""" + + user = get_user_by_email(session=session, email=body.email) + if not user: + logger.info( + f"[request_magic_link] Magic link requested for non-existent email: {body.email}" + ) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No account found for this email.", + ) + + token = generate_magic_link_token(email=body.email) + + if settings.emails_enabled: + try: + email_data = generate_magic_link_email( + email_to=body.email, + magic_link_token=token, + ) + send_email( + email_to=body.email, + subject=email_data.subject, + html_content=email_data.html_content, + ) + logger.info( + f"[request_magic_link] Magic link email sent | email: {body.email}" + ) + except Exception as e: + logger.error( + f"[request_magic_link] Failed to send magic link email | email: {body.email}, error: {e}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to send login email. Please try again later.", + ) + else: + logger.warning("[request_magic_link] Email sending is not configured") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Email service is not configured", + ) + + return APIResponse.success_response( + data=Message(message="If an account exists, a login link has been sent.") + ) + + +@router.get( + "/magic-link/verify", + description=load_description("auth/magic_link_verify.md"), + response_model=APIResponse[Token], +) +def verify_magic_link(session: SessionDep, token: str) -> JSONResponse: + """Verify a magic link token and log the user in.""" + + email = verify_magic_link_token(token) + if not email: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired login link. Please request a new one.", + ) + + user = get_user_by_email(session=session, email=email) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User account not found", + ) + + # Activate user if not already active + if not user.is_active: + user.is_active = True + session.add(user) + session.commit() + session.refresh(user) + logger.info( + f"[verify_magic_link] User activated via magic link | user_id: {user.id}" + ) + + # Get user's projects to embed in token + available_projects = get_user_accessible_projects(session=session, user_id=user.id) + + organization_id = None + project_id = None + if len(available_projects) == 1: + organization_id = available_projects[0]["organization_id"] + project_id = available_projects[0]["project_id"] + + response = build_token_response( + user_id=user.id, + organization_id=organization_id, + project_id=project_id, + ) + + logger.info( + f"[verify_magic_link] User logged in via magic link | user_id: {user.id}" + ) + return response diff --git a/backend/app/api/routes/credentials.py b/backend/app/api/routes/credentials.py index 8e1e94b41..be75b5b98 100644 --- a/backend/app/api/routes/credentials.py +++ b/backend/app/api/routes/credentials.py @@ -5,7 +5,7 @@ from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission from app.core.exception_handlers import HTTPException -from app.core.providers import validate_provider +from app.core.providers import mask_credential_fields, validate_provider from app.crud.credentials import ( get_creds_by_org, get_provider_credential, @@ -67,15 +67,13 @@ def read_credential( org_id=_current_user.organization_.id, project_id=_current_user.project_.id, ) - if not creds: - raise HTTPException(status_code=404, detail="Credentials not found") return APIResponse.success_response([cred.to_public() for cred in creds]) @router.get( "/provider/{provider}", - response_model=APIResponse[dict], + response_model=APIResponse[dict | None], description=load_description("credentials/get_provider.md"), dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], ) @@ -97,9 +95,11 @@ def read_provider_credential( project_id=_current_user.project_.id, ) if credential is None: - raise HTTPException(status_code=404, detail="Provider credentials not found") + return APIResponse.success_response(None) - return APIResponse.success_response(credential) + return APIResponse.success_response( + mask_credential_fields(provider_enum, credential) + ) @router.patch( @@ -184,3 +184,147 @@ def delete_all_credentials( return APIResponse.success_response( {"message": "All credentials deleted successfully"} ) + + +@router.get( + "/{org_id}/{project_id}", + response_model=APIResponse[list[CredsPublic]], + description=load_description("credentials/list_by_org_project.md"), + dependencies=[Depends(require_permission(Permission.SUPERUSER))], +) +def read_credentials_by_org_project( + *, + session: SessionDep, + org_id: int, + project_id: int, + _current_user: AuthContextDep, +): + creds = get_creds_by_org( + session=session, + org_id=org_id, + project_id=project_id, + ) + + return APIResponse.success_response([cred.to_public() for cred in creds]) + + +@router.get( + "/{org_id}/{project_id}/provider/{provider}", + response_model=APIResponse[dict | None], + description=load_description("credentials/get_provider_by_org_project.md"), + dependencies=[Depends(require_permission(Permission.SUPERUSER))], +) +def read_provider_credential_by_org_project( + *, + session: SessionDep, + org_id: int, + project_id: int, + provider: str, + _current_user: AuthContextDep, +): + try: + provider_enum = validate_provider(provider) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + credential = get_provider_credential( + session=session, + org_id=org_id, + provider=provider_enum, + project_id=project_id, + ) + if credential is None: + return APIResponse.success_response(None) + + return APIResponse.success_response( + mask_credential_fields(provider_enum, credential) + ) + + +@router.patch( + "/{org_id}/{project_id}", + response_model=APIResponse[list[CredsPublic]], + description=load_description("credentials/update_by_org_project.md"), + dependencies=[Depends(require_permission(Permission.SUPERUSER))], +) +def update_credential_by_org_project( + *, + session: SessionDep, + org_id: int, + project_id: int, + creds_in: CredsUpdate, + _current_user: AuthContextDep, +): + if not creds_in or not creds_in.provider or not creds_in.credential: + logger.error( + f"[update_credential_by_org_project] Invalid input | organization_id: {org_id}, project_id: {project_id}" + ) + raise HTTPException( + status_code=400, detail="Provider and credential must be provided" + ) + + updated_credential = update_creds_for_org( + session=session, + org_id=org_id, + creds_in=creds_in, + project_id=project_id, + ) + + return APIResponse.success_response( + [cred.to_public() for cred in updated_credential] + ) + + +@router.delete( + "/{org_id}/{project_id}/provider/{provider}", + response_model=APIResponse[dict], + description=load_description("credentials/delete_provider_by_org_project.md"), + dependencies=[Depends(require_permission(Permission.SUPERUSER))], +) +def delete_provider_credential_by_org_project( + *, + session: SessionDep, + org_id: int, + project_id: int, + provider: str, + _current_user: AuthContextDep, +): + try: + provider_enum = validate_provider(provider) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + remove_provider_credential( + session=session, + org_id=org_id, + provider=provider_enum, + project_id=project_id, + ) + + return APIResponse.success_response( + {"message": "Provider credentials removed successfully"} + ) + + +@router.delete( + "/{org_id}/{project_id}", + response_model=APIResponse[dict], + description=load_description("credentials/delete_all_by_org_project.md"), + dependencies=[Depends(require_permission(Permission.SUPERUSER))], +) +def delete_all_credentials_by_org_project( + *, + session: SessionDep, + org_id: int, + project_id: int, + _current_user: AuthContextDep, +): + remove_creds_for_org( + session=session, + org_id=org_id, + project_id=project_id, + ) + + return APIResponse.success_response( + {"message": "All credentials deleted successfully"} + ) diff --git a/backend/app/api/routes/user_project.py b/backend/app/api/routes/user_project.py index 5052761f7..3da8afdca 100644 --- a/backend/app/api/routes/user_project.py +++ b/backend/app/api/routes/user_project.py @@ -5,6 +5,9 @@ from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission +from app.core.config import settings +from app.crud.organization import get_organization_by_id, validate_organization +from app.crud.project import get_project_by_id, validate_project from app.crud.user_project import ( add_user_to_project, get_users_by_project, @@ -15,7 +18,13 @@ Message, UserProjectPublic, ) -from app.utils import APIResponse, load_description +from app.services.auth import generate_invite_token +from app.utils import ( + APIResponse, + generate_invite_email, + load_description, + send_email, +) logger = logging.getLogger(__name__) @@ -49,6 +58,10 @@ def add_project_users( body: AddUsersToProjectRequest, ) -> Any: """Add one or more users to a project by email.""" + # Validate org and project exist and are active before issuing any invites. + validate_organization(session=session, org_id=body.organization_id) + validate_project(session=session, project_id=body.project_id) + same_project_emails = [] different_project_emails = [] @@ -83,6 +96,37 @@ def add_project_users( session.commit() + # Send invitation emails + organization = get_organization_by_id(session=session, org_id=body.organization_id) + project = get_project_by_id(session=session, project_id=body.project_id) + + if settings.emails_enabled and organization and project: + for entry in body.users: + try: + invite_token = generate_invite_token( + email=str(entry.email), + organization_id=body.organization_id, + project_id=body.project_id, + ) + email_data = generate_invite_email( + email_to=str(entry.email), + project_name=project.name, + organization_name=organization.name, + invite_token=invite_token, + ) + send_email( + email_to=str(entry.email), + subject=email_data.subject, + html_content=email_data.html_content, + ) + logger.info( + f"[add_project_users] Invitation email sent | email: {entry.email}" + ) + except Exception as e: + logger.error( + f"[add_project_users] Failed to send invitation email | email: {entry.email}, error: {e}" + ) + # Re-fetch all users for this project to return the full list results = get_users_by_project(session=session, project_id=body.project_id) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 49cf7e5f6..ec0692997 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -57,6 +57,30 @@ class Settings(BaseSettings): # Google OAuth GOOGLE_CLIENT_ID: str = "" + # Frontend URL for magic links + FRONTEND_HOST: str = "http://localhost:3000" + + # Invitation token expiry (default 24 hours) + INVITE_TOKEN_EXPIRE_HOURS: int = 24 + + # Magic link login token expiry (default 15 minutes) + MAGIC_LINK_TOKEN_EXPIRE_MINUTES: int = 15 + + # SMTP / Email + SMTP_HOST: str = "" + SMTP_PORT: int = 587 + SMTP_USER: str = "" + SMTP_PASSWORD: str = "" + SMTP_TLS: bool = True + SMTP_SSL: bool = False + EMAILS_FROM_EMAIL: str = "" + EMAILS_FROM_NAME: str = "" + + @computed_field # type: ignore[prop-decorator] + @property + def emails_enabled(self) -> bool: + return bool(self.SMTP_HOST and self.EMAILS_FROM_EMAIL) + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 7248ea1df..793995422 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -1,7 +1,7 @@ import logging -from typing import Dict, List +from typing import Any, Dict, List from enum import Enum -from dataclasses import dataclass +from dataclasses import dataclass, field logger = logging.getLogger(__name__) @@ -21,17 +21,27 @@ class ProviderConfig: """Configuration for a provider including its required credential fields.""" required_fields: List[str] + sensitive_fields: List[str] = field(default_factory=list) # Provider configurations PROVIDER_CONFIGS: Dict[Provider, ProviderConfig] = { - Provider.OPENAI: ProviderConfig(required_fields=["api_key"]), + Provider.OPENAI: ProviderConfig( + required_fields=["api_key"], sensitive_fields=["api_key"] + ), Provider.LANGFUSE: ProviderConfig( - required_fields=["secret_key", "public_key", "host"] + required_fields=["secret_key", "public_key", "host"], + sensitive_fields=["secret_key"], + ), + Provider.GOOGLE: ProviderConfig( + required_fields=["api_key"], sensitive_fields=["api_key"] + ), + Provider.SARVAMAI: ProviderConfig( + required_fields=["api_key"], sensitive_fields=["api_key"] + ), + Provider.ELEVENLABS: ProviderConfig( + required_fields=["api_key"], sensitive_fields=["api_key"] ), - Provider.GOOGLE: ProviderConfig(required_fields=["api_key"]), - Provider.SARVAMAI: ProviderConfig(required_fields=["api_key"]), - Provider.ELEVENLABS: ProviderConfig(required_fields=["api_key"]), } @@ -86,3 +96,29 @@ def validate_provider_credentials(provider: str, credentials: Dict[str, str]) -> def get_supported_providers() -> List[str]: """Return a list of all supported provider names.""" return [p.value for p in Provider] + + +def mask_credential_fields( + provider: str, credentials: Dict[str, Any] +) -> Dict[str, Any]: + """Mask sensitive fields in a credential dict for the given provider. + + Non-sensitive fields (e.g., langfuse `public_key`, `host`) are returned as-is. + Unknown providers are returned with no masking. + """ + from app.utils import mask_string + + if not credentials: + return credentials + + try: + provider_enum = Provider(provider.lower()) + except ValueError: + return credentials + + sensitive_fields = PROVIDER_CONFIGS[provider_enum].sensitive_fields + masked = dict(credentials) + for field_name in sensitive_fields: + if field_name in masked and isinstance(masked[field_name], str): + masked[field_name] = mask_string(masked[field_name]) + return masked diff --git a/backend/app/core/security.py b/backend/app/core/security.py index e5f6ac3f8..ef2db7396 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -15,6 +15,7 @@ from typing import Any, Tuple import jwt +from jwt.exceptions import InvalidTokenError from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC @@ -67,58 +68,73 @@ def get_fernet() -> Fernet: return _fernet -def create_access_token( +def encode_jwt_token( subject: str | Any, + token_type: str, expires_delta: timedelta, - organization_id: int | None = None, - project_id: int | None = None, + extra_claims: dict[str, Any] | None = None, ) -> str: + """Encode a JWT with standard `exp`, `nbf`, `sub`, and `type` claims. + + Any additional claims (e.g. `org_id`, `project_id`) can be passed via + `extra_claims` and are merged into the payload before signing. """ - Create a JWT access token. + now = datetime.now(timezone.utc) + to_encode: dict[str, Any] = { + "exp": now + expires_delta, + "nbf": now, + "sub": str(subject), + "type": token_type, + } + if extra_claims: + to_encode.update({k: v for k, v in extra_claims.items() if v is not None}) + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) - Args: - subject: The subject of the token (typically user ID) - expires_delta: Token expiration time delta - organization_id: Optional organization ID to embed in the token - project_id: Optional project ID to embed in the token - Returns: - str: Encoded JWT token +def decode_jwt_token( + token: str, expected_type: str | None = None +) -> dict[str, Any] | None: + """Decode and verify a JWT. Returns the payload or None if invalid. + + If `expected_type` is given, the token's `type` claim must match. """ - expire = datetime.now(timezone.utc) + expires_delta - to_encode: dict[str, Any] = {"exp": expire, "sub": str(subject), "type": "access"} - if organization_id is not None: - to_encode["org_id"] = organization_id - if project_id is not None: - to_encode["project_id"] = project_id - return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + except InvalidTokenError: + return None + if expected_type is not None and payload.get("type") != expected_type: + return None + return payload -def create_refresh_token( +def create_access_token( subject: str | Any, expires_delta: timedelta, organization_id: int | None = None, project_id: int | None = None, ) -> str: - """ - Create a JWT refresh token. + """Create a JWT access token.""" + return encode_jwt_token( + subject=subject, + token_type="access", + expires_delta=expires_delta, + extra_claims={"org_id": organization_id, "project_id": project_id}, + ) - Args: - subject: The subject of the token (typically user ID) - expires_delta: Token expiration time delta - organization_id: Optional organization ID to embed in the token - project_id: Optional project ID to embed in the token - Returns: - str: Encoded JWT refresh token - """ - expire = datetime.now(timezone.utc) + expires_delta - to_encode: dict[str, Any] = {"exp": expire, "sub": str(subject), "type": "refresh"} - if organization_id is not None: - to_encode["org_id"] = organization_id - if project_id is not None: - to_encode["project_id"] = project_id - return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) +def create_refresh_token( + subject: str | Any, + expires_delta: timedelta, + organization_id: int | None = None, + project_id: int | None = None, +) -> str: + """Create a JWT refresh token.""" + return encode_jwt_token( + subject=subject, + token_type="refresh", + expires_delta=expires_delta, + extra_claims={"org_id": organization_id, "project_id": project_id}, + ) def verify_password(plain_password: str, hashed_password: str) -> bool: diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index e6c1ded6e..6853c455a 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -184,8 +184,18 @@ def update_creds_for_org( if not creds_in.provider or not creds_in.credential: raise ValueError("Provider and credential must be provided") + # Auto-unwrap nested format: {"google": {"api_key": "..."}} -> {"api_key": "..."} + # so the same payload shape works for both create and update. + credential_data = creds_in.credential + if ( + isinstance(credential_data, dict) + and creds_in.provider in credential_data + and isinstance(credential_data[creds_in.provider], dict) + ): + credential_data = credential_data[creds_in.provider] + try: - validate_provider_credentials(creds_in.provider, creds_in.credential) + validate_provider_credentials(creds_in.provider, credential_data) except ValueError as e: logger.error( f"[update_creds_for_org] Validation error | organization_id: {org_id}, project_id: {project_id}, provider: {creds_in.provider}, error: {str(e)}" @@ -193,7 +203,7 @@ def update_creds_for_org( raise HTTPException(status_code=400, detail=str(e)) # Encrypt the entire credentials object - encrypted_credentials = encrypt_credentials(creds_in.credential) + encrypted_credentials = encrypt_credentials(credential_data) statement = select(Credential).where( Credential.organization_id == org_id, @@ -203,12 +213,23 @@ def update_creds_for_org( ) creds = session.exec(statement).one_or_none() if creds is None: - logger.error( - f"[update_creds_for_org] Credentials not found | organization {org_id}, provider {creds_in.provider}, project_id {project_id}" + # Create new credential if it doesn't exist + creds = Credential( + organization_id=org_id, + project_id=project_id, + is_active=creds_in.is_active if creds_in.is_active is not None else True, + provider=creds_in.provider, + credential=encrypted_credentials, + inserted_at=now(), + updated_at=now(), ) - raise HTTPException( - status_code=404, detail="Credentials not found for this provider" + session.add(creds) + session.commit() + session.refresh(creds) + logger.info( + f"[update_creds_for_org] Created new credentials | organization_id {org_id}, provider {creds_in.provider}, project_id {project_id}" ) + return [creds] creds.credential = encrypted_credentials creds.updated_at = now() diff --git a/backend/app/email-templates/build/invite_user.html b/backend/app/email-templates/build/invite_user.html new file mode 100644 index 000000000..6371a63d1 --- /dev/null +++ b/backend/app/email-templates/build/invite_user.html @@ -0,0 +1,52 @@ + + + + + + You're invited to {{ project_name }} + + + + + + +
+ + + + +
+

+ {{ app_name }} +

+

+ {{ organization_name }} +

+ + + + +
+

+ You have been invited to join the {{ project_name }} project on {{ app_name }}. +

+

+ Click the button below to accept the invitation and get started. +

+ + + + +
+ + Accept Invitation + +
+

+ This invitation expires in {{ valid_days }} days.
+ If you did not expect this invitation, you can safely ignore this email. +

+
+
+ + diff --git a/backend/app/email-templates/build/magic_link_login.html b/backend/app/email-templates/build/magic_link_login.html new file mode 100644 index 000000000..daf376d9c --- /dev/null +++ b/backend/app/email-templates/build/magic_link_login.html @@ -0,0 +1,52 @@ + + + + + + Sign in to {{ app_name }} + + + + + + +
+ + + + +
+

+ {{ app_name }} +

+

+ Sign in to your account +

+ + + + +
+

+ We received a sign-in request for {{ email }}. +

+

+ Click the button below to sign in. +

+ + + + +
+ + Sign In Now + +
+

+ This link expires in {{ valid_minutes }} minutes.
+ If you did not request this, you can safely ignore this email. +

+
+
+ + diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 98c4f7d24..05f39032e 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -12,6 +12,8 @@ AuthContext, GoogleAuthRequest, GoogleAuthResponse, + InviteTokenPayload, + MagicLinkRequest, SelectProjectRequest, Token, TokenPayload, diff --git a/backend/app/models/auth.py b/backend/app/models/auth.py index bfbda1f6a..68bf9d89a 100644 --- a/backend/app/models/auth.py +++ b/backend/app/models/auth.py @@ -19,6 +19,13 @@ class TokenPayload(SQLModel): type: str = "access" +# Payload returned after verifying an invite JWT +class InviteTokenPayload(SQLModel): + email: str + organization_id: int + project_id: int + + # Google OAuth class GoogleAuthRequest(SQLModel): token: str @@ -37,6 +44,10 @@ class SelectProjectRequest(SQLModel): project_id: int +class MagicLinkRequest(SQLModel): + email: str + + class AuthContext(SQLModel): user: User organization: Organization | None = None diff --git a/backend/app/models/credentials.py b/backend/app/models/credentials.py index b295927c2..e6741b2e9 100644 --- a/backend/app/models/credentials.py +++ b/backend/app/models/credentials.py @@ -4,6 +4,7 @@ import sqlalchemy as sa from sqlmodel import Field, Relationship, SQLModel +from app.core.providers import mask_credential_fields from app.core.util import now from app.models.organization import Organization from app.models.project import Project @@ -113,19 +114,26 @@ class Credential(CredsBase, table=True): organization: Organization | None = Relationship(back_populates="creds") project: Project | None = Relationship(back_populates="creds") - def to_public(self) -> "CredsPublic": - """Convert the database model to a public model with decrypted credentials.""" + def to_public(self, mask: bool = True) -> "CredsPublic": + """Convert the database model to a public model with decrypted credentials. + + By default, sensitive fields (e.g., api_key, secret_key) are masked so + the response is safe to return via the API. + """ + # Local import to avoid circular dependency (security imports app.models) from app.core.security import decrypt_credentials + decrypted = decrypt_credentials(self.credential) if self.credential else None + if mask and decrypted: + decrypted = mask_credential_fields(self.provider, decrypted) + return CredsPublic( id=self.id, organization_id=self.organization_id, project_id=self.project_id, is_active=self.is_active, provider=self.provider, - credential=decrypt_credentials(self.credential) - if self.credential - else None, + credential=decrypted, inserted_at=self.inserted_at, updated_at=self.updated_at, ) diff --git a/backend/app/services/auth.py b/backend/app/services/auth.py index 5c266092e..fddd8dde5 100644 --- a/backend/app/services/auth.py +++ b/backend/app/services/auth.py @@ -11,6 +11,7 @@ from app.core.config import settings from app.models import ( GoogleAuthResponse, + InviteTokenPayload, Token, TokenPayload, User, @@ -176,3 +177,81 @@ def validate_refresh_token( raise HTTPException(status_code=403, detail="Inactive user") return user, token_data + + +def generate_email_token( + email: str, + token_type: str, + expires_delta: timedelta, + organization_id: int | None = None, + project_id: int | None = None, +) -> str: + """Generate a JWT email token (invite or magic_link). + + Args: + email: User's email address (stored as 'sub' claim) + token_type: Token type identifier (e.g. "invite", "magic_link") + expires_delta: Token expiration duration + organization_id: Optional org ID to embed (for invite tokens) + project_id: Optional project ID to embed (for invite tokens) + """ + return security.encode_jwt_token( + subject=email, + token_type=token_type, + expires_delta=expires_delta, + extra_claims={"org_id": organization_id, "project_id": project_id}, + ) + + +def verify_email_token(token: str, expected_type: str) -> dict | None: + """Verify a JWT email token and return its claims as a dict, or None if invalid.""" + payload = security.decode_jwt_token(token, expected_type=expected_type) + if not payload or "sub" not in payload: + return None + return { + "email": payload["sub"], + "organization_id": payload.get("org_id"), + "project_id": payload.get("project_id"), + } + + +def generate_invite_token(email: str, organization_id: int, project_id: int) -> str: + """Generate a JWT invitation token for a user (expires in INVITE_TOKEN_EXPIRE_HOURS).""" + return generate_email_token( + email=email, + token_type="invite", + expires_delta=timedelta(hours=settings.INVITE_TOKEN_EXPIRE_HOURS), + organization_id=organization_id, + project_id=project_id, + ) + + +def verify_invite_token(token: str) -> InviteTokenPayload | None: + """Verify an invitation token and return its payload, or None if invalid.""" + claims = verify_email_token(token, expected_type="invite") + if ( + not claims + or claims.get("organization_id") is None + or claims.get("project_id") is None + ): + return None + return InviteTokenPayload( + email=claims["email"], + organization_id=claims["organization_id"], + project_id=claims["project_id"], + ) + + +def generate_magic_link_token(email: str) -> str: + """Generate a short-lived magic link login token (expires in MAGIC_LINK_TOKEN_EXPIRE_MINUTES).""" + return generate_email_token( + email=email, + token_type="magic_link", + expires_delta=timedelta(minutes=settings.MAGIC_LINK_TOKEN_EXPIRE_MINUTES), + ) + + +def verify_magic_link_token(token: str) -> str | None: + """Verify a magic link token. Returns email string or None.""" + result = verify_email_token(token, expected_type="magic_link") + return result["email"] if result else None diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 9feffa19d..39f164ecd 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -50,7 +50,7 @@ def test_read_provider_credential_not_found( client: TestClient, user_api_key: TestAuthContext, ) -> None: - """Test reading credentials for non-existent provider returns 404.""" + """Test reading credentials for non-existent provider returns null.""" client.delete( f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", headers={"X-API-KEY": user_api_key.key}, @@ -61,8 +61,8 @@ def test_read_provider_credential_not_found( headers={"X-API-KEY": user_api_key.key}, ) - assert response.status_code == 404 - assert "Provider credentials not found" in response.json()["error"] + assert response.status_code == 200 + assert response.json()["data"] is None def test_update_credentials( @@ -100,7 +100,11 @@ def test_update_credentials( ) assert verify_response.status_code == 200 verify_data = verify_response.json().get("data", verify_response.json()) - assert verify_data["api_key"] == new_api_key + # Sensitive fields are masked in GET responses + assert verify_data["api_key"] != new_api_key + assert "*" in verify_data["api_key"] + assert verify_data["api_key"].startswith("sk-") + assert verify_data["api_key"].endswith(new_api_key[-4:]) def test_create_credential( @@ -138,7 +142,11 @@ def test_create_credential( data = create_response.json()["data"] assert len(data) == 1 assert data[0]["provider"] == Provider.OPENAI.value - assert data[0]["credential"]["api_key"] == api_key + # Sensitive fields are masked in API responses + assert data[0]["credential"]["api_key"] != api_key + assert "*" in data[0]["credential"]["api_key"] + assert data[0]["credential"]["api_key"].startswith("sk-") + assert data[0]["credential"]["api_key"].endswith(api_key[-4:]) def test_credential_encryption( @@ -163,12 +171,12 @@ def test_credential_encryption( assert decrypted_creds["api_key"].startswith("sk-") -def test_update_nonexistent_provider_returns_404( +def test_update_nonexistent_provider_upserts( client: TestClient, user_api_key: TestAuthContext, ) -> None: - """Test updating credentials for non-existent provider.""" - # Delete OpenAI first + """Test that updating a non-existent provider creates it (upsert behavior).""" + # Delete OpenAI first so no credential exists for the provider client.delete( f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", headers={"X-API-KEY": user_api_key.key}, @@ -188,8 +196,10 @@ def test_update_nonexistent_provider_returns_404( headers={"X-API-KEY": user_api_key.key}, ) - assert response.status_code == 404 - assert "Credentials not found" in response.json()["error"] + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 1 + assert data[0]["provider"] == Provider.OPENAI.value def test_create_ignores_mismatched_ids( @@ -269,7 +279,8 @@ def test_delete_all_credentials( f"{settings.API_V1_STR}/credentials/", headers={"X-API-KEY": user_api_key.key}, ) - assert get_response.status_code == 404 + assert get_response.status_code == 200 + assert get_response.json()["data"] == [] def test_delete_all_when_none_exist_returns_404( @@ -314,12 +325,13 @@ def test_delete_provider_credential( response_data["data"]["message"] == "Provider credentials removed successfully" ) - # Verify it's deleted + # Verify it's deleted (endpoint returns null when provider credential is missing) verify_response = client.get( f"{settings.API_V1_STR}/credentials/provider/{Provider.OPENAI.value}", headers={"X-API-KEY": user_api_key.key}, ) - assert verify_response.status_code == 404 + assert verify_response.status_code == 200 + assert verify_response.json()["data"] is None def test_delete_provider_credential_not_found( @@ -474,11 +486,11 @@ def test_update_credential_empty_credential( assert response.status_code in [200, 400] # Depends on implementation -def test_read_credentials_not_found( +def test_read_credentials_when_none_exist( client: TestClient, user_api_key: TestAuthContext, ) -> None: - """Test reading credentials when none exist.""" + """Test reading credentials when none exist returns an empty list.""" # Delete all credentials first client.delete( f"{settings.API_V1_STR}/credentials/", @@ -490,8 +502,8 @@ def test_read_credentials_not_found( headers={"X-API-KEY": user_api_key.key}, ) - assert response.status_code == 404 - assert "Credentials not found" in response.json()["error"] + assert response.status_code == 200 + assert response.json()["data"] == [] def test_create_multiple_providers_at_once( diff --git a/backend/app/tests/api/test_auth.py b/backend/app/tests/api/test_auth.py index 1ecafcd3c..cfa05dffb 100644 --- a/backend/app/tests/api/test_auth.py +++ b/backend/app/tests/api/test_auth.py @@ -6,6 +6,12 @@ from app.core.config import settings from app.core.security import create_access_token, create_refresh_token +from app.services.auth import ( + generate_invite_token, + generate_magic_link_token, + verify_invite_token, + verify_magic_link_token, +) from app.tests.utils.auth import TestAuthContext from app.tests.utils.user import create_random_user @@ -13,6 +19,9 @@ SELECT_PROJECT_URL = f"{settings.API_V1_STR}/auth/select-project" REFRESH_URL = f"{settings.API_V1_STR}/auth/refresh" LOGOUT_URL = f"{settings.API_V1_STR}/auth/logout" +MAGIC_LINK_URL = f"{settings.API_V1_STR}/auth/magic-link" +MAGIC_LINK_VERIFY_URL = f"{settings.API_V1_STR}/auth/magic-link/verify" +INVITE_VERIFY_URL = f"{settings.API_V1_STR}/auth/invite/verify" MOCK_GOOGLE_PROFILE = { "email": None, # set per test @@ -107,9 +116,6 @@ def test_google_auth_activates_inactive_user( resp = client.post(GOOGLE_AUTH_URL, json={"token": "fake"}) assert resp.status_code == 200 - db.refresh(user) - assert user.is_active is True - @patch("app.api.routes.auth.id_token.verify_oauth2_token") @patch("app.api.routes.auth.settings") def test_google_auth_success_no_projects( @@ -298,3 +304,212 @@ def test_logout_success(self, client: TestClient): body = resp.json() assert body["success"] is True assert body["data"]["message"] == "Logged out successfully" + + +class TestMagicLink: + """Test suite for POST /auth/magic-link endpoint.""" + + @patch("app.api.routes.auth.settings") + def test_magic_link_email_not_configured(self, mock_settings, client: TestClient): + """Test returns 500 when email is not configured.""" + mock_settings.emails_enabled = False + resp = client.post(MAGIC_LINK_URL, json={"email": "test@example.com"}) + assert resp.status_code == 500 + + def test_magic_link_nonexistent_user(self, client: TestClient): + """Test returns 404 for non-existent user.""" + resp = client.post(MAGIC_LINK_URL, json={"email": "nonexistent@example.com"}) + assert resp.status_code == 404 + assert "No account found" in resp.json()["error"] + + @patch("app.api.routes.auth.send_email") + @patch("app.api.routes.auth.settings") + def test_magic_link_inactive_user_allowed( + self, mock_settings, mock_send, db: Session, client: TestClient + ): + """Test inactive user can still request magic link to reactivate.""" + user = create_random_user(db) + user.is_active = False + db.add(user) + db.commit() + + mock_settings.emails_enabled = True + mock_settings.MAGIC_LINK_TOKEN_EXPIRE_MINUTES = 15 + mock_settings.SECRET_KEY = settings.SECRET_KEY + mock_settings.FRONTEND_HOST = "http://localhost:3000" + mock_settings.PROJECT_NAME = "Kaapi" + + resp = client.post(MAGIC_LINK_URL, json={"email": user.email}) + assert resp.status_code == 200 + mock_send.assert_called_once() + + @patch("app.api.routes.auth.send_email") + @patch("app.api.routes.auth.settings") + def test_magic_link_success( + self, mock_settings, mock_send, db: Session, client: TestClient + ): + """Test sends email for valid active user.""" + user = create_random_user(db) + + mock_settings.emails_enabled = True + mock_settings.MAGIC_LINK_TOKEN_EXPIRE_MINUTES = 15 + mock_settings.SECRET_KEY = settings.SECRET_KEY + mock_settings.FRONTEND_HOST = "http://localhost:3000" + mock_settings.PROJECT_NAME = "Kaapi" + + resp = client.post(MAGIC_LINK_URL, json={"email": user.email}) + assert resp.status_code == 200 + assert "login link has been sent" in resp.json()["data"]["message"] + mock_send.assert_called_once() + + +class TestMagicLinkVerify: + """Test suite for GET /auth/magic-link/verify endpoint.""" + + def test_verify_invalid_token(self, client: TestClient): + """Test returns 400 for invalid token.""" + resp = client.get(f"{MAGIC_LINK_VERIFY_URL}?token=invalid.token.here") + assert resp.status_code == 400 + assert "expired" in resp.json()["error"] or "Invalid" in resp.json()["error"] + + def test_verify_expired_token(self, db: Session, client: TestClient): + """Test returns 400 for expired magic link token.""" + user = create_random_user(db) + with patch("app.services.auth.settings.MAGIC_LINK_TOKEN_EXPIRE_MINUTES", -1): + token = generate_magic_link_token(email=user.email) + + resp = client.get(f"{MAGIC_LINK_VERIFY_URL}?token={token}") + assert resp.status_code == 400 + + def test_verify_user_not_found(self, client: TestClient): + """Test returns 404 when user doesn't exist.""" + token = generate_magic_link_token(email="ghost@example.com") + resp = client.get(f"{MAGIC_LINK_VERIFY_URL}?token={token}") + assert resp.status_code == 404 + + def test_verify_activates_inactive_user(self, db: Session, client: TestClient): + """Test magic link verify activates inactive user.""" + user = create_random_user(db) + user.is_active = False + db.add(user) + db.commit() + db.refresh(user) + + token = generate_magic_link_token(email=user.email) + resp = client.get(f"{MAGIC_LINK_VERIFY_URL}?token={token}") + assert resp.status_code == 200 + + db.refresh(user) + assert user.is_active is True + + def test_verify_success(self, db: Session, client: TestClient): + """Test successful magic link verification logs user in.""" + user = create_random_user(db) + token = generate_magic_link_token(email=user.email) + + resp = client.get(f"{MAGIC_LINK_VERIFY_URL}?token={token}") + assert resp.status_code == 200 + assert resp.json()["success"] is True + assert "access_token" in resp.json()["data"] + assert "access_token" in resp.cookies + + +class TestInviteVerify: + """Test suite for GET /auth/invite/verify endpoint.""" + + def test_verify_invalid_token(self, client: TestClient): + """Test returns 400 for invalid invite token.""" + resp = client.get(f"{INVITE_VERIFY_URL}?token=invalid.token") + assert resp.status_code == 400 + + def test_verify_user_not_found( + self, client: TestClient, user_api_key: TestAuthContext + ): + """Test returns 404 when invited user doesn't exist.""" + token = generate_invite_token( + email="ghost@example.com", + organization_id=user_api_key.organization.id, + project_id=user_api_key.project.id, + ) + resp = client.get(f"{INVITE_VERIFY_URL}?token={token}") + assert resp.status_code == 404 + + def test_verify_activates_inactive_user( + self, db: Session, client: TestClient, user_api_key: TestAuthContext + ): + """Test invite verification activates inactive user.""" + user = create_random_user(db) + user.is_active = False + db.add(user) + db.commit() + db.refresh(user) + + token = generate_invite_token( + email=user.email, + organization_id=user_api_key.organization.id, + project_id=user_api_key.project.id, + ) + resp = client.get(f"{INVITE_VERIFY_URL}?token={token}") + assert resp.status_code == 200 + + db.refresh(user) + assert user.is_active is True + assert "access_token" in resp.json()["data"] + + def test_verify_success_active_user( + self, db: Session, client: TestClient, user_api_key: TestAuthContext + ): + """Test invite verification works for already active user.""" + user = create_random_user(db) + token = generate_invite_token( + email=user.email, + organization_id=user_api_key.organization.id, + project_id=user_api_key.project.id, + ) + resp = client.get(f"{INVITE_VERIFY_URL}?token={token}") + assert resp.status_code == 200 + assert "access_token" in resp.json()["data"] + + +class TestTokenGeneration: + """Test suite for services/auth.py token generation functions.""" + + def test_generate_and_verify_invite_token(self): + """Test invite token roundtrip.""" + token = generate_invite_token( + email="test@example.com", organization_id=1, project_id=2 + ) + result = verify_invite_token(token) + assert result is not None + assert result.email == "test@example.com" + assert result.organization_id == 1 + assert result.project_id == 2 + + def test_verify_invite_token_wrong_type(self): + """Test invite verify rejects magic_link tokens.""" + token = generate_magic_link_token(email="test@example.com") + result = verify_invite_token(token) + assert result is None + + def test_generate_and_verify_magic_link_token(self): + """Test magic link token roundtrip.""" + token = generate_magic_link_token(email="test@example.com") + result = verify_magic_link_token(token) + assert result == "test@example.com" + + def test_verify_magic_link_token_wrong_type(self): + """Test magic link verify rejects invite tokens.""" + token = generate_invite_token( + email="test@example.com", organization_id=1, project_id=1 + ) + result = verify_magic_link_token(token) + assert result is None + + def test_verify_invalid_token_returns_none(self): + """Test both verify functions return None for garbage tokens.""" + assert verify_invite_token("garbage") is None + assert verify_magic_link_token("garbage") is None + + def test_verify_invite_token_invalid(self): + """Test invite verify returns None for garbage tokens.""" + assert verify_invite_token("garbage") is None diff --git a/backend/app/tests/api/test_user_project.py b/backend/app/tests/api/test_user_project.py index b984fd060..dd4f6cceb 100644 --- a/backend/app/tests/api/test_user_project.py +++ b/backend/app/tests/api/test_user_project.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + from fastapi.testclient import TestClient from sqlmodel import Session @@ -101,6 +103,38 @@ def test_add_single_user( emails = [u["email"] for u in data] assert email in emails + @patch("app.api.routes.user_project.send_email") + @patch("app.api.routes.user_project.settings") + def test_add_user_sends_invite_email( + self, + mock_settings, + mock_send_email, + db: Session, + client: TestClient, + superuser_token_headers: dict[str, str], + ): + """Test adding a user sends an invitation email when emails are enabled.""" + project = create_test_project(db) + email = random_email() + + mock_settings.emails_enabled = True + mock_settings.INVITE_TOKEN_EXPIRE_HOURS = 168 + mock_settings.SECRET_KEY = settings.SECRET_KEY + mock_settings.FRONTEND_HOST = "http://localhost:3000" + mock_settings.PROJECT_NAME = "Kaapi" + + resp = client.post( + f"{USER_PROJECTS_URL}/", + json={ + "organization_id": project.organization_id, + "project_id": project.id, + "users": [{"email": email}], + }, + headers=superuser_token_headers, + ) + assert resp.status_code == 201 + mock_send_email.assert_called_once() + def test_add_duplicate_user_same_project( self, db: Session, diff --git a/backend/app/utils.py b/backend/app/utils.py index 353be2152..0d9741a5f 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -6,7 +6,7 @@ import logging import tempfile from dataclasses import dataclass -from datetime import datetime, timedelta, timezone +from datetime import timedelta from pathlib import Path import requests import socket @@ -14,10 +14,8 @@ from typing import Any, Dict, Generic, Optional, TypeVar from urllib.parse import urlparse -import jwt import emails from jinja2 import Template -from jwt.exceptions import InvalidTokenError from fastapi import HTTPException from langfuse import Langfuse import openai @@ -184,27 +182,56 @@ def generate_new_account_email( return EmailData(html_content=html_content, subject=subject) +def generate_invite_email( + *, + email_to: str, + project_name: str, + organization_name: str, + invite_token: str, +) -> EmailData: + app_name = settings.PROJECT_NAME + subject = f"{app_name} - You've been invited to {project_name}" + link = f"{settings.FRONTEND_HOST}/invite?token={invite_token}" + html_content = render_email_template( + template_name="invite_user.html", + context={ + "app_name": app_name, + "project_name": project_name, + "organization_name": organization_name, + "link": link, + "valid_days": settings.INVITE_TOKEN_EXPIRE_HOURS // 24, + }, + ) + return EmailData(html_content=html_content, subject=subject) + + +def generate_magic_link_email(*, email_to: str, magic_link_token: str) -> EmailData: + app_name = settings.PROJECT_NAME + subject = f"{app_name} - Sign in to your account" + link = f"{settings.FRONTEND_HOST}/verify?token={magic_link_token}" + html_content = render_email_template( + template_name="magic_link_login.html", + context={ + "app_name": app_name, + "email": email_to, + "link": link, + "valid_minutes": settings.MAGIC_LINK_TOKEN_EXPIRE_MINUTES, + }, + ) + return EmailData(html_content=html_content, subject=subject) + + def generate_password_reset_token(email: str) -> str: - delta = timedelta(hours=settings.EMAIL_RESET_TOKEN_EXPIRE_HOURS) - now = datetime.now(timezone.utc) - expires = now + delta - exp = expires.timestamp() - encoded_jwt = jwt.encode( - {"exp": exp, "nbf": now, "sub": email}, - settings.SECRET_KEY, - algorithm=security.ALGORITHM, + return security.encode_jwt_token( + subject=email, + token_type="password_reset", + expires_delta=timedelta(hours=settings.EMAIL_RESET_TOKEN_EXPIRE_HOURS), ) - return encoded_jwt def verify_password_reset_token(token: str) -> str | None: - try: - decoded_token = jwt.decode( - token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] - ) - return str(decoded_token["sub"]) - except InvalidTokenError: - return None + payload = security.decode_jwt_token(token, expected_type="password_reset") + return str(payload["sub"]) if payload and "sub" in payload else None def mask_string(value: str, mask_char: str = "*") -> str: