diff --git a/.env.example b/.env.example index 8df496766..abea3f771 100644 --- a/.env.example +++ b/.env.example @@ -43,7 +43,7 @@ DOCKER_IMAGE_FRONTEND=frontend AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_DEFAULT_REGION=ap-south-1 -AWS_S3_BUCKET_PREFIX = "bucket-prefix-name" +AWS_S3_BUCKET_PREFIX="bucket-prefix-name" # OpenAI diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index 68e4b6a87..8682cb69f 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -1,26 +1,28 @@ import inspect import logging import time -import warnings from uuid import UUID, uuid4 from typing import Any, List, Optional from dataclasses import dataclass, field, fields, asdict, replace -from openai import OpenAI, OpenAIError +from openai import OpenAIError, OpenAI from fastapi import APIRouter, HTTPException, BackgroundTasks, Query from fastapi import Path as FastPath from pydantic import BaseModel, Field, HttpUrl -from sqlalchemy.exc import NoResultFound, MultipleResultsFound, SQLAlchemyError +from sqlalchemy.exc import SQLAlchemyError from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject from app.core.cloud import AmazonCloudStorage -from app.core.config import settings -from app.core.util import now, raise_from_unknown, post_callback -from app.crud import DocumentCrud, CollectionCrud, DocumentCollectionCrud +from app.core.util import now, post_callback +from app.crud import ( + DocumentCrud, + CollectionCrud, + DocumentCollectionCrud, +) from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud from app.models import Collection, Document from app.models.collection import CollectionStatus -from app.utils import APIResponse, load_description +from app.utils import APIResponse, load_description, get_openai_client logger = logging.getLogger(__name__) router = APIRouter(prefix="/collections", tags=["collections"]) @@ -180,12 +182,13 @@ def _backout(crud: OpenAIAssistantCrud, assistant_id: str): def do_create_collection( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, request: CreationRequest, payload: ResponsePayload, + client: OpenAI, ): start_time = time.time() - client = OpenAI(api_key=settings.OPENAI_API_KEY) + callback = ( SilentCallback(payload) if request.callback_url is None @@ -226,7 +229,7 @@ def do_create_collection( collection_crud._update(collection) elapsed = time.time() - start_time - logging.info( + logger.info( f"[do_create_collection] Collection created: {collection.id} | Time: {elapsed:.2f}s | " f"Files: {len(flat_docs)} | Sizes: {file_sizes_kb} KB | Types: {list(file_exts)}" ) @@ -261,6 +264,10 @@ def create_collection( request: CreationRequest, background_tasks: BackgroundTasks, ): + client = get_openai_client( + session, current_user.organization_id, current_user.project_id + ) + this = inspect.currentframe() route = router.url_path_for(this.f_code.co_name) payload = ResponsePayload("processing", route) @@ -278,11 +285,7 @@ def create_collection( # 2. Launch background task background_tasks.add_task( - do_create_collection, - session, - current_user, - request, - payload, + do_create_collection, session, current_user, request, payload, client ) logger.info( @@ -294,9 +297,10 @@ def create_collection( def do_delete_collection( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, request: DeletionRequest, payload: ResponsePayload, + client: OpenAI, ): if request.callback_url is None: callback = SilentCallback(payload) @@ -306,7 +310,7 @@ def do_delete_collection( collection_crud = CollectionCrud(session, current_user.id) try: collection = collection_crud.read_one(request.collection_id) - assistant = OpenAIAssistantCrud() + assistant = OpenAIAssistantCrud(client) data = collection_crud.delete(collection, assistant) logger.info( f"[do_delete_collection] Collection deleted successfully | {{'collection_id': '{collection.id}'}}" @@ -332,20 +336,20 @@ def do_delete_collection( ) def delete_collection( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, request: DeletionRequest, background_tasks: BackgroundTasks, ): + client = get_openai_client( + session, current_user.organization_id, current_user.project_id + ) + this = inspect.currentframe() route = router.url_path_for(this.f_code.co_name) payload = ResponsePayload("processing", route) background_tasks.add_task( - do_delete_collection, - session, - current_user, - request, - payload, + do_delete_collection, session, current_user, request, payload, client ) logger.info( diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index 03d2cd98e..3924cabf5 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -8,8 +8,8 @@ from app.crud import DocumentCrud, CollectionCrud from app.models import Document -from app.utils import APIResponse, load_description -from app.api.deps import CurrentUser, SessionDep +from app.utils import APIResponse, load_description, get_openai_client +from app.api.deps import CurrentUser, SessionDep, CurrentUserOrgProject from app.core.cloud import AmazonCloudStorage from app.crud.rag import OpenAIAssistantCrud @@ -65,10 +65,14 @@ def upload_doc( ) def remove_doc( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, doc_id: UUID = FastPath(description="Document to delete"), ): - a_crud = OpenAIAssistantCrud() + client = get_openai_client( + session, current_user.organization_id, current_user.project_id + ) + + a_crud = OpenAIAssistantCrud(client) d_crud = DocumentCrud(session, current_user.id) c_crud = CollectionCrud(session, current_user.id) @@ -84,10 +88,14 @@ def remove_doc( ) def permanent_delete_doc( session: SessionDep, - current_user: CurrentUser, + current_user: CurrentUserOrgProject, doc_id: UUID = FastPath(description="Document to permanently delete"), ): - a_crud = OpenAIAssistantCrud() + client = get_openai_client( + session, current_user.organization_id, current_user.project_id + ) + + a_crud = OpenAIAssistantCrud(client) d_crud = DocumentCrud(session, current_user.id) c_crud = CollectionCrud(session, current_user.id) storage = AmazonCloudStorage(current_user) diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index 69479888f..cdb644abc 100644 --- a/backend/app/crud/rag/open_ai.py +++ b/backend/app/crud/rag/open_ai.py @@ -90,8 +90,12 @@ def clean(self, resource): class OpenAICrud: - def __init__(self, client=None): - self.client = client or OpenAI(api_key=settings.OPENAI_API_KEY) + def __init__(self, client): + if client is None: + logger.error("[OpenAICrud] OpenAI client is not configured") + raise ValueError("OpenAI client is not configured") + + self.client = client class OpenAIVectorStoreCrud(OpenAICrud): diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index f28658866..5747f7905 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -4,12 +4,9 @@ from sqlmodel import Session from app.core.config import settings from app.models import Collection -from app.main import app from app.tests.utils.utils import get_user_from_api_key from app.models.collection import CollectionStatus -client = TestClient(app) - def create_collection( db, diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index ef779211e..2bede71aa 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -2,18 +2,16 @@ from uuid import UUID import io -import openai_responses from sqlmodel import Session from fastapi.testclient import TestClient +from unittest.mock import patch from app.core.config import settings from app.tests.utils.document import DocumentStore -from app.tests.utils.utils import openai_credentials, get_user_from_api_key -from app.main import app +from app.tests.utils.utils import get_user_from_api_key from app.crud.collection import CollectionCrud from app.models.collection import CollectionStatus - -client = TestClient(app) +from app.tests.utils.openai import get_mock_openai_client_with_vector_store @pytest.fixture(autouse=True) @@ -31,7 +29,7 @@ def stream(self, file_obj): return fake_file def get_file_size_kb(self, url: str) -> float: - return 1.0 # Simulate 1KB files + return 1.0 class FakeS3Client: def head_object(self, Bucket, Key): @@ -41,13 +39,16 @@ def head_object(self, Bucket, Key): monkeypatch.setattr("boto3.client", lambda service: FakeS3Client()) -@pytest.mark.usefixtures("openai_credentials") class TestCollectionRouteCreate: _n_documents = 5 - @openai_responses.mock() + @patch("app.api.routes.collections.get_openai_client") def test_create_collection_success( - self, client: TestClient, db: Session, user_api_key_header + self, + mock_get_openai_client, + client: TestClient, + db: Session, + user_api_key_header, ): store = DocumentStore(db) documents = store.fill(self._n_documents) @@ -60,8 +61,12 @@ def test_create_collection_success( "instructions": "Test collection assistant.", "temperature": 0.1, } + headers = user_api_key_header + mock_openai_client = get_mock_openai_client_with_vector_store() + mock_get_openai_client.return_value = mock_openai_client + response = client.post( f"{settings.API_V1_STR}/collections/create", json=body, headers=headers ) @@ -73,8 +78,8 @@ def test_create_collection_success( assert metadata["status"] == CollectionStatus.processing.value assert UUID(metadata["key"]) + # Confirm collection metadata in DB collection_id = UUID(metadata["key"]) - user = get_user_from_api_key(db, headers) collection = CollectionCrud(db, user.user_id).read_one(collection_id) diff --git a/backend/app/tests/api/routes/documents/test_route_document_info.py b/backend/app/tests/api/routes/documents/test_route_document_info.py index 3f077935a..41b3ebc52 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_info.py +++ b/backend/app/tests/api/routes/documents/test_route_document_info.py @@ -44,10 +44,7 @@ def test_info_reflects_database( assert source == target.data def test_cannot_info_unknown_document( - self, - db: Session, - route: Route, - crawler: Route, + self, db: Session, route: Route, crawler: Route ): DocumentStore.clear(db) maker = DocumentMaker(db) diff --git a/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py b/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py index 8a6d353fe..10d11f3ec 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py +++ b/backend/app/tests/api/routes/documents/test_route_document_permanent_remove.py @@ -3,11 +3,14 @@ from urllib.parse import urlparse import pytest +from unittest.mock import patch from botocore.exceptions import ClientError from moto import mock_aws from sqlmodel import Session, select +from openai import OpenAI import openai_responses +from openai_responses import OpenAIMock from app.core.cloud import AmazonCloudStorageClient from app.core.config import settings @@ -19,7 +22,6 @@ WebCrawler, crawler, ) -from app.tests.utils.utils import openai_credentials @pytest.fixture @@ -36,16 +38,23 @@ def aws_credentials(): os.environ["AWS_DEFAULT_REGION"] = settings.AWS_DEFAULT_REGION -@pytest.mark.usefixtures("openai_credentials", "aws_credentials") +@pytest.mark.usefixtures("aws_credentials") @mock_aws class TestDocumentRoutePermanentRemove: @openai_responses.mock() + @patch("app.api.routes.documents.get_openai_client") def test_permanent_delete_document_from_s3( self, + mock_get_openai_client, db: Session, route: Route, crawler: WebCrawler, ): + openai_mock = OpenAIMock() + with openai_mock.router: + client = OpenAI(api_key="sk-test-key") + mock_get_openai_client.return_value = client + # Setup AWS aws = AmazonCloudStorageClient() aws.create() diff --git a/backend/app/tests/api/routes/documents/test_route_document_remove.py b/backend/app/tests/api/routes/documents/test_route_document_remove.py index 7b01cc2f4..292b2b10a 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_remove.py +++ b/backend/app/tests/api/routes/documents/test_route_document_remove.py @@ -1,6 +1,9 @@ import pytest import openai_responses +from openai_responses import OpenAIMock +from openai import OpenAI from sqlmodel import Session, select +from unittest.mock import patch from app.models import Document from app.tests.utils.document import ( @@ -10,8 +13,6 @@ WebCrawler, crawler, ) -from app.tests.utils.collection import get_collection -from app.tests.utils.utils import openai_credentials @pytest.fixture @@ -19,47 +20,66 @@ def route(): return Route("remove") -@pytest.mark.usefixtures("openai_credentials") class TestDocumentRouteRemove: @openai_responses.mock() + @patch("app.api.routes.documents.get_openai_client") def test_response_is_success( self, + mock_get_openai_client, db: Session, route: Route, crawler: WebCrawler, ): - store = DocumentStore(db) - response = crawler.get(route.append(store.put())) + openai_mock = OpenAIMock() + with openai_mock.router: + client = OpenAI(api_key="sk-test-key") + mock_get_openai_client.return_value = client - assert response.is_success + store = DocumentStore(db) + response = crawler.get(route.append(store.put())) + + assert response.is_success @openai_responses.mock() + @patch("app.api.routes.documents.get_openai_client") def test_item_is_soft_removed( self, + mock_get_openai_client, db: Session, route: Route, crawler: WebCrawler, ): - store = DocumentStore(db) - document = store.put() + openai_mock = OpenAIMock() + with openai_mock.router: + client = OpenAI(api_key="sk-test-key") + mock_get_openai_client.return_value = client + + store = DocumentStore(db) + document = store.put() - crawler.get(route.append(document)) - db.refresh(document) - statement = select(Document).where(Document.id == document.id) - result = db.exec(statement).one() + crawler.get(route.append(document)) + db.refresh(document) + statement = select(Document).where(Document.id == document.id) + result = db.exec(statement).one() - assert result.deleted_at is not None + assert result.deleted_at is not None @openai_responses.mock() + @patch("app.api.routes.documents.get_openai_client") def test_cannot_remove_unknown_document( self, + mock_get_openai_client, db: Session, route: Route, crawler: WebCrawler, ): - DocumentStore.clear(db) + openai_mock = OpenAIMock() + with openai_mock.router: + client = OpenAI(api_key="sk-test-key") + mock_get_openai_client.return_value = client - maker = DocumentMaker(db) - response = crawler.get(route.append(next(maker))) + DocumentStore.clear(db) + maker = DocumentMaker(db) + response = crawler.get(route.append(next(maker))) - assert response.is_error + assert response.is_error diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 46b74241a..027738714 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -1,11 +1,8 @@ -from collections.abc import Generator - import pytest -import time - from fastapi.testclient import TestClient from sqlmodel import Session from sqlalchemy import event +from collections.abc import Generator from app.core.config import settings from app.core.db import engine diff --git a/backend/app/tests/crud/collections/test_crud_collection_create.py b/backend/app/tests/crud/collections/test_crud_collection_create.py index 564230b33..d6feb064f 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_create.py +++ b/backend/app/tests/crud/collections/test_crud_collection_create.py @@ -1,4 +1,3 @@ -import pytest import openai_responses from sqlmodel import Session, select @@ -6,10 +5,8 @@ from app.models import DocumentCollection from app.tests.utils.document import DocumentStore from app.tests.utils.collection import get_collection -from app.tests.utils.utils import openai_credentials -@pytest.mark.usefixtures("openai_credentials") class TestCollectionCreate: _n_documents = 10 diff --git a/backend/app/tests/crud/collections/test_crud_collection_delete.py b/backend/app/tests/crud/collections/test_crud_collection_delete.py index 58d552d11..8029c5e80 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/test_crud_collection_delete.py @@ -8,16 +8,14 @@ from app.crud.rag import OpenAIAssistantCrud from app.tests.utils.document import DocumentStore from app.tests.utils.collection import get_collection, uuid_increment -from app.tests.utils.utils import openai_credentials -@pytest.mark.usefixtures("openai_credentials") class TestCollectionDelete: _n_collections = 5 @openai_responses.mock() def test_delete_marks_deleted(self, db: Session): - client = OpenAI(api_key=settings.OPENAI_API_KEY) + client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) collection = get_collection(db, client) @@ -29,7 +27,7 @@ def test_delete_marks_deleted(self, db: Session): @openai_responses.mock() def test_delete_follows_insert(self, db: Session): - client = OpenAI(api_key=settings.OPENAI_API_KEY) + client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) collection = get_collection(db, client) @@ -41,7 +39,7 @@ def test_delete_follows_insert(self, db: Session): @openai_responses.mock() def test_cannot_delete_others_collections(self, db: Session): - client = OpenAI(api_key=settings.OPENAI_API_KEY) + client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) collection = get_collection(db, client) @@ -56,7 +54,7 @@ def test_delete_document_deletes_collections(self, db: Session): store = DocumentStore(db) documents = store.fill(1) - client = OpenAI(api_key=settings.OPENAI_API_KEY) + client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): coll = get_collection(db, client) diff --git a/backend/app/tests/crud/collections/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/test_crud_collection_read_all.py index 39eb36e99..b6f63b884 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/test_crud_collection_read_all.py @@ -4,11 +4,9 @@ from sqlmodel import Session from app.crud import CollectionCrud -from app.core.config import settings from app.models import Collection from app.tests.utils.document import DocumentStore from app.tests.utils.collection import get_collection -from app.tests.utils.utils import openai_credentials def create_collections(db: Session, n: int): @@ -18,7 +16,7 @@ def create_collections(db: Session, n: int): openai_mock = OpenAIMock() with openai_mock.router: - client = OpenAI(api_key=settings.OPENAI_API_KEY) + client = OpenAI(api_key="sk-test-key") for _ in range(n): collection = get_collection(db, client) if crud is None: @@ -34,14 +32,10 @@ def refresh(self, db: Session): db.commit() -@pytest.mark.usefixtures("openai_credentials") class TestCollectionReadAll: _ncollections = 5 - def test_number_read_is_expected( - self, - db: Session, - ): + def test_number_read_is_expected(self, db: Session): db.query(Collection).delete() owner = create_collections(db, self._ncollections) diff --git a/backend/app/tests/crud/collections/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/test_crud_collection_read_one.py index 22edba532..ed0d31ad4 100644 --- a/backend/app/tests/crud/collections/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/test_crud_collection_read_one.py @@ -8,7 +8,6 @@ from app.crud import CollectionCrud from app.tests.utils.document import DocumentStore from app.tests.utils.collection import get_collection, uuid_increment -from app.tests.utils.utils import openai_credentials def mk_collection(db: Session): @@ -17,13 +16,12 @@ def mk_collection(db: Session): openai_mock = OpenAIMock() with openai_mock.router: - client = OpenAI(api_key=settings.OPENAI_API_KEY) + client = OpenAI(api_key="sk-test-key") collection = get_collection(db, client) crud = CollectionCrud(db, collection.owner_id) return crud.create(collection, documents) -@pytest.mark.usefixtures("openai_credentials") class TestDatabaseReadOne: def test_can_select_valid_id(self, db: Session): collection = mk_collection(db) @@ -35,7 +33,6 @@ def test_can_select_valid_id(self, db: Session): def test_cannot_select_others_collections(self, db: Session): collection = mk_collection(db) - other = collection.owner_id + 1 crud = CollectionCrud(db, other) with pytest.raises(NoResultFound): diff --git a/backend/app/tests/crud/documents/test_crud_document_read_one.py b/backend/app/tests/crud/documents/test_crud_document_read_one.py index 463ea02ea..a3de8f49d 100644 --- a/backend/app/tests/crud/documents/test_crud_document_read_one.py +++ b/backend/app/tests/crud/documents/test_crud_document_read_one.py @@ -33,11 +33,7 @@ def test_cannot_select_invalid_id(self, db: Session, store: DocumentStore): assert exc_info.value.status_code == 404 assert "Document not found" in str(exc_info.value.detail) - def test_cannot_read_others_documents( - self, - db: Session, - store: DocumentStore, - ): + def test_cannot_read_others_documents(self, db: Session, store: DocumentStore): document = store.put() other = DocumentStore(db) diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py index 0a68125c5..411af9947 100644 --- a/backend/app/tests/utils/collection.py +++ b/backend/app/tests/utils/collection.py @@ -36,7 +36,7 @@ def get_collection(db: Session, client=None): ) if client is None: - client = OpenAI(api_key=settings.OPENAI_API_KEY) + client = OpenAI(api_key="test_api_key") vector_store = client.vector_stores.create() assistant = client.beta.assistants.create( diff --git a/backend/app/tests/utils/openai.py b/backend/app/tests/utils/openai.py index 778d4804f..6f11bbf51 100644 --- a/backend/app/tests/utils/openai.py +++ b/backend/app/tests/utils/openai.py @@ -1,6 +1,7 @@ from typing import Optional import time +from unittest.mock import MagicMock from openai.types.beta import Assistant as OpenAIAssistant from openai.types.beta.assistant import ToolResources, ToolResourcesFileSearch from openai.types.beta.assistant_tool import FileSearchTool @@ -37,3 +38,33 @@ def mock_openai_assistant( top_p=1.0, reasoning_effort=None, ) + + +def get_mock_openai_client_with_vector_store(): + mock_client = MagicMock() + + # Vector store + mock_vector_store = MagicMock() + mock_vector_store.id = "mock_vector_store_id" + mock_client.vector_stores.create.return_value = mock_vector_store + + # File upload + polling + mock_file_batch = MagicMock() + mock_file_batch.file_counts.completed = 2 + mock_file_batch.file_counts.total = 2 + mock_client.vector_stores.file_batches.upload_and_poll.return_value = ( + mock_file_batch + ) + + # File list + mock_client.vector_stores.files.list.return_value = {"data": []} + + # Assistant + mock_assistant = MagicMock() + mock_assistant.id = "mock_assistant_id" + mock_assistant.name = "Mock Assistant" + mock_assistant.model = "gpt-4o" + mock_assistant.instructions = "Mock instructions" + mock_client.beta.assistants.create.return_value = mock_assistant + + return mock_client diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index d48ae5d33..ae4a7beef 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -3,6 +3,7 @@ from uuid import UUID from typing import Type, TypeVar + import pytest from pydantic import EmailStr from fastapi.testclient import TestClient @@ -17,11 +18,6 @@ T = TypeVar("T") -@pytest.fixture(scope="class") -def openai_credentials(): - settings.OPENAI_API_KEY = "sk-fake123" - - def random_lower_string() -> str: return "".join(random.choices(string.ascii_lowercase, k=32))