diff --git a/.env.example b/.env.example index 3e184d8..7383309 100644 --- a/.env.example +++ b/.env.example @@ -3,6 +3,10 @@ # Copy this file to .env and fill in your values. # ───────────────────────────────────────────────── +# Comma-separated frontend origins allowed by browser CORS checks. +# Example: CORS_ORIGINS=http://localhost:3000,https://app.example.com +# CORS_ORIGINS= + # ── LLM Provider ──────────────────────────────── LLM_PROVIDER=google GEMINI_API_KEY= diff --git a/README.md b/README.md index a8bdd98..8944252 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,8 @@ docker compose up -d Fill in `.env` (see the [Configuration](#configuration) section below). The API is available at `http://localhost:8000`. Interactive docs at `http://localhost:8000/api/v1/docs`. +If you are calling the API from a browser-based frontend on another origin, set `CORS_ORIGINS` to the allowed frontend origins. + ### Ingest and query ```bash @@ -93,6 +95,7 @@ All settings are environment variables loaded from `.env` via `pydantic-settings | Variable | Default | Description | | -------------------------- | ------------------------------------- | ------------------------------------------------------------------ | +| `CORS_ORIGINS` | (none) | Comma-separated frontend origins allowed by browser CORS | | `ORG_NAME` | `MicroClub` | Organization name embedded in the AI system prompt | | `ORG_DESCRIPTION` | `A generic organization using mAIcro` | Organization description | | `GOOGLE_MODEL_NAME` | `gemini-2.5-flash` | Gemini model used for answering | diff --git a/src/core/config.py b/src/core/config.py index b140515..826457c 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -6,6 +6,7 @@ class Settings(BaseSettings): PROJECT_NAME: str = "mAIcro" VERSION: str = "0.1.0" API_V1_STR: str = "/api/v1" + CORS_ORIGINS: Optional[str] = None ORG_NAME: str = "MicroClub" ORG_DESCRIPTION: Optional[str] = "A generic organization using mAIcro" @@ -50,6 +51,17 @@ def discord_channel_id_list(self) -> List[str]: cid.strip() for cid in self.DISCORD_CHANNEL_IDS.split(",") if cid.strip() ] + @property + def cors_origin_list(self) -> List[str]: + """Parse comma-separated CORS origins from CORS_ORIGINS.""" + if not self.CORS_ORIGINS: + return [] + + return [ + origin.strip() + for origin in self.CORS_ORIGINS.split(",") + if origin.strip() + ] -settings = Settings() +settings = Settings() diff --git a/src/main.py b/src/main.py index 294c615..e0176f1 100644 --- a/src/main.py +++ b/src/main.py @@ -4,6 +4,7 @@ from contextlib import asynccontextmanager from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from api.error_handlers import register_exception_handlers from api.routes import router @@ -55,6 +56,14 @@ async def lifespan(app: FastAPI): redoc_url=f"{settings.API_V1_STR}/redoc", ) +if settings.cors_origin_list: + app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origin_list, + allow_credentials=False, + allow_methods=["*"], + allow_headers=["*"], + ) app.include_router(router, prefix=settings.API_V1_STR) register_exception_handlers(app) diff --git a/tests/api/test_routes.py b/tests/api/test_routes.py index 4c957c1..177b83a 100644 --- a/tests/api/test_routes.py +++ b/tests/api/test_routes.py @@ -1,21 +1,29 @@ import asyncio +import importlib import httpx import api.routes as routes +import main as main_module +from core.config import settings from services.qa_service import AskError from main import app -def request(method: str, path: str, **kwargs): +def request(method: str, path: str, app_instance=None, **kwargs): async def _send(): - transport = httpx.ASGITransport(app=app) + transport = httpx.ASGITransport(app=app_instance or app) async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: return await client.request(method, path, **kwargs) return asyncio.run(_send()) +def build_app_with_cors(monkeypatch, origins: str): + monkeypatch.setattr(settings, "CORS_ORIGINS", origins) + return importlib.reload(main_module).app + + def test_health_endpoint_returns_ok(): res = request("GET", "/api/v1/health") @@ -86,3 +94,41 @@ async def _fake_ingest(): assert body["documents_ingested"] == 5 assert body["details"]["channels"] == {"123": 5} assert body["details"]["errors"] == {"456": "missing access"} + + +def test_cors_preflight_allows_configured_frontend_origin(monkeypatch): + cors_app = build_app_with_cors( + monkeypatch, "http://localhost:3000,https://app.example.com" + ) + + res = request( + "OPTIONS", + "/api/v1/ask", + app_instance=cors_app, + headers={ + "Origin": "http://localhost:3000", + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "content-type", + }, + ) + + assert res.status_code == 200 + assert res.headers["access-control-allow-origin"] == "http://localhost:3000" + assert "POST" in res.headers["access-control-allow-methods"] + + +def test_cors_preflight_rejects_unconfigured_origin(monkeypatch): + cors_app = build_app_with_cors(monkeypatch, "http://localhost:3000") + + res = request( + "OPTIONS", + "/api/v1/ask", + app_instance=cors_app, + headers={ + "Origin": "http://localhost:5173", + "Access-Control-Request-Method": "POST", + }, + ) + + assert res.status_code == 400 + assert "access-control-allow-origin" not in res.headers diff --git a/tests/conftest.py b/tests/conftest.py index dc7a40f..5bd500a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,11 @@ import pytest from core.config import settings + @pytest.fixture(autouse=True) def mock_env_vars(monkeypatch): """Provide dummy environment variables for all tests to bypass strict Cloud-only checks.""" + monkeypatch.setattr(settings, "CORS_ORIGINS", None) monkeypatch.setattr(settings, "QDRANT_URL", "https://dummy.qdrant.io:6333") monkeypatch.setattr(settings, "QDRANT_API_KEY", "dummy-api-key") monkeypatch.setattr(settings, "GEMINI_API_KEY", "dummy-gemini-key")