diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index ba31baf26..f28658866 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -37,9 +37,9 @@ def create_collection( def test_collection_info_processing( - db: Session, client: TestClient, normal_user_api_key_headers + db: Session, client: TestClient, user_api_key_header ): - headers = normal_user_api_key_headers + headers = user_api_key_header user = get_user_from_api_key(db, headers) collection = create_collection(db, user, status=CollectionStatus.processing) @@ -58,9 +58,9 @@ def test_collection_info_processing( def test_collection_info_successful( - db: Session, client: TestClient, normal_user_api_key_headers + db: Session, client: TestClient, user_api_key_header ): - headers = normal_user_api_key_headers + headers = user_api_key_header user = get_user_from_api_key(db, headers) collection = create_collection( db, user, status=CollectionStatus.successful, with_llm=True @@ -80,10 +80,8 @@ def test_collection_info_successful( assert data["llm_service_name"] == "gpt-4o" -def test_collection_info_failed( - db: Session, client: TestClient, normal_user_api_key_headers -): - headers = normal_user_api_key_headers +def test_collection_info_failed(db: Session, client: TestClient, user_api_key_header): + headers = user_api_key_header user = get_user_from_api_key(db, headers) collection = create_collection(db, user, status=CollectionStatus.failed) diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index 8a458c444..ef779211e 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -47,7 +47,7 @@ class TestCollectionRouteCreate: @openai_responses.mock() def test_create_collection_success( - self, client: TestClient, db: Session, normal_user_api_key_headers + self, client: TestClient, db: Session, user_api_key_header ): store = DocumentStore(db) documents = store.fill(self._n_documents) @@ -60,7 +60,7 @@ def test_create_collection_success( "instructions": "Test collection assistant.", "temperature": 0.1, } - headers = normal_user_api_key_headers + headers = user_api_key_header response = client.post( f"{settings.API_V1_STR}/collections/create", json=body, headers=headers diff --git a/backend/app/tests/api/routes/documents/test_route_document_upload.py b/backend/app/tests/api/routes/documents/test_route_document_upload.py index 6e9c26ded..4c6abaaa9 100644 --- a/backend/app/tests/api/routes/documents/test_route_document_upload.py +++ b/backend/app/tests/api/routes/documents/test_route_document_upload.py @@ -25,7 +25,7 @@ def put(self, route: Route, scratch: Path): with scratch.open("rb") as fp: return self.client.post( str(route), - headers=self.normal_user_api_key_headers, + headers=self.user_api_key_header, files={ "src": (str(scratch), fp, mtype), }, @@ -45,8 +45,8 @@ def route(): @pytest.fixture -def uploader(client: TestClient, normal_user_api_key_headers: dict[str, str]): - return WebUploader(client, normal_user_api_key_headers) +def uploader(client: TestClient, user_api_key_header: dict[str, str]): + return WebUploader(client, user_api_key_header) @pytest.fixture(scope="class") diff --git a/backend/app/tests/api/routes/test_assistants.py b/backend/app/tests/api/routes/test_assistants.py index 8755889de..8d2366960 100644 --- a/backend/app/tests/api/routes/test_assistants.py +++ b/backend/app/tests/api/routes/test_assistants.py @@ -30,7 +30,7 @@ def assistant_id(): def test_ingest_assistant_success( mock_fetch_assistant, client: TestClient, - normal_user_api_key_headers: dict[str, str], + user_api_key_header: dict[str, str], ): """Test successful assistant ingestion from OpenAI.""" mock_assistant = mock_openai_assistant() @@ -39,7 +39,7 @@ def test_ingest_assistant_success( response = client.post( f"/api/v1/assistant/{mock_assistant.id}/ingest", - headers=normal_user_api_key_headers, + headers=user_api_key_header, ) assert response.status_code == 201 @@ -53,7 +53,7 @@ def test_create_assistant_success( mock_verify_vector_ids, client: TestClient, assistant_create_payload: dict, - normal_user_api_key_headers: dict, + user_api_key_header: dict, ): """Test successful assistant creation with OpenAI vector store ID verification.""" @@ -62,7 +62,7 @@ def test_create_assistant_success( response = client.post( "/api/v1/assistant", json=assistant_create_payload, - headers=normal_user_api_key_headers, + headers=user_api_key_header, ) assert response.status_code == 201 @@ -92,7 +92,7 @@ def test_create_assistant_invalid_vector_store( mock_verify_vector_ids, client: TestClient, assistant_create_payload: dict, - normal_user_api_key_headers: dict, + user_api_key_header: dict, ): """Test failure when one or more vector store IDs are invalid.""" @@ -106,7 +106,7 @@ def test_create_assistant_invalid_vector_store( response = client.post( "/api/v1/assistant", json=payload, - headers=normal_user_api_key_headers, + headers=user_api_key_header, ) assert response.status_code == 400 @@ -175,7 +175,7 @@ def test_update_assistant_invalid_vector_store( def test_update_assistant_not_found( client: TestClient, - normal_user_api_key_headers: dict, + user_api_key_header: dict, ): """Test failure when updating a non-existent assistant.""" update_payload = {"name": "Updated Assistant"} @@ -185,7 +185,7 @@ def test_update_assistant_not_found( response = client.patch( f"/api/v1/assistant/{non_existent_id}", json=update_payload, - headers=normal_user_api_key_headers, + headers=user_api_key_header, ) assert response.status_code == 404 @@ -217,14 +217,14 @@ def test_get_assistant_success( def test_get_assistant_not_found( client: TestClient, - normal_user_api_key_headers: dict, + user_api_key_header: dict, ): """Test failure when fetching a non-existent assistant.""" non_existent_id = str(uuid4()) response = client.get( f"/api/v1/assistant/{non_existent_id}", - headers=normal_user_api_key_headers, + headers=user_api_key_header, ) assert response.status_code == 404 @@ -258,27 +258,27 @@ def test_list_assistants_success( def test_list_assistants_invalid_pagination( client: TestClient, - normal_user_api_key_headers: dict, + user_api_key_header: dict, ): """Test assistants list with invalid pagination parameters.""" # Test negative skip response = client.get( "/api/v1/assistant/?skip=-1&limit=10", - headers=normal_user_api_key_headers, + headers=user_api_key_header, ) assert response.status_code == 422 # Test limit too high response = client.get( "/api/v1/assistant/?skip=0&limit=101", - headers=normal_user_api_key_headers, + headers=user_api_key_header, ) assert response.status_code == 422 # Test limit too low response = client.get( "/api/v1/assistant/?skip=0&limit=0", - headers=normal_user_api_key_headers, + headers=user_api_key_header, ) assert response.status_code == 422 @@ -304,14 +304,14 @@ def test_delete_assistant_success( def test_delete_assistant_not_found( client: TestClient, - normal_user_api_key_headers: dict, + user_api_key_header: dict, ): """Test failure when deleting a non-existent assistant.""" non_existent_id = str(uuid4()) response = client.delete( f"/api/v1/assistant/{non_existent_id}", - headers=normal_user_api_key_headers, + headers=user_api_key_header, ) assert response.status_code == 404 diff --git a/backend/app/tests/api/routes/test_responses.py b/backend/app/tests/api/routes/test_responses.py index 242e6fdca..aac0a2beb 100644 --- a/backend/app/tests/api/routes/test_responses.py +++ b/backend/app/tests/api/routes/test_responses.py @@ -16,7 +16,7 @@ @patch("app.api.routes.responses.OpenAI") @patch("app.api.routes.responses.get_provider_credential") def test_responses_endpoint_success( - mock_get_credential, mock_openai, db, normal_user_api_key_headers: dict[str, str] + mock_get_credential, mock_openai, db, user_api_key_header: dict[str, str] ): """Test the /responses endpoint for successful response creation.""" # Setup mock credentials @@ -48,9 +48,7 @@ def test_responses_endpoint_success( "callback_url": "http://example.com/callback", } - response = client.post( - "/responses", json=request_data, headers=normal_user_api_key_headers - ) + response = client.post("/responses", json=request_data, headers=user_api_key_header) assert response.status_code == 200 response_json = response.json() @@ -67,7 +65,7 @@ def test_responses_endpoint_without_vector_store( mock_get_credential, mock_openai, db, - normal_user_api_key_headers, + user_api_key_header, ): """Test the /responses endpoint when assistant has no vector store configured.""" # Setup mock credentials @@ -107,9 +105,7 @@ def test_responses_endpoint_without_vector_store( "callback_url": "http://example.com/callback", } - response = client.post( - "/responses", json=request_data, headers=normal_user_api_key_headers - ) + response = client.post("/responses", json=request_data, headers=user_api_key_header) assert response.status_code == 200 response_json = response.json() assert response_json["success"] is True diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 51dde40dc..46b74241a 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -11,6 +11,7 @@ from app.core.db import engine from app.api.deps import get_db from app.main import app +from app.models import APIKeyPublic from app.tests.utils.user import authentication_token_from_email from app.tests.utils.utils import get_superuser_token_headers, get_api_key_by_email from app.seed_data.seed_data import seed_database @@ -64,12 +65,24 @@ def normal_user_token_headers(client: TestClient, db: Session) -> dict[str, str] @pytest.fixture(scope="function") -def superuser_api_key_headers(db: Session) -> dict[str, str]: +def superuser_api_key_header(db: Session) -> dict[str, str]: api_key = get_api_key_by_email(db, settings.FIRST_SUPERUSER) - return {"X-API-KEY": api_key} + return {"X-API-KEY": api_key.key} @pytest.fixture(scope="function") -def normal_user_api_key_headers(db: Session) -> dict[str, str]: +def user_api_key_header(db: Session) -> dict[str, str]: api_key = get_api_key_by_email(db, settings.EMAIL_TEST_USER) - return {"X-API-KEY": api_key} + return {"X-API-KEY": api_key.key} + + +@pytest.fixture(scope="function") +def superuser_api_key(db: Session) -> APIKeyPublic: + api_key = get_api_key_by_email(db, settings.FIRST_SUPERUSER) + return api_key + + +@pytest.fixture(scope="function") +def user_api_key(db: Session) -> APIKeyPublic: + api_key = get_api_key_by_email(db, settings.EMAIL_TEST_USER) + return api_key diff --git a/backend/app/tests/utils/document.py b/backend/app/tests/utils/document.py index ab72d2cef..e6234a1c5 100644 --- a/backend/app/tests/utils/document.py +++ b/backend/app/tests/utils/document.py @@ -113,18 +113,18 @@ def append(self, doc: Document, suffix: str = None): @dataclass class WebCrawler: client: TestClient - normal_user_api_key_headers: dict[str, str] + user_api_key_header: dict[str, str] def get(self, route: Route): return self.client.get( str(route), - headers=self.normal_user_api_key_headers, + headers=self.user_api_key_header, ) def delete(self, route: Route): return self.client.delete( str(route), - headers=self.normal_user_api_key_headers, + headers=self.user_api_key_header, ) @@ -158,5 +158,5 @@ def to_dict(self): @pytest.fixture -def crawler(client: TestClient, normal_user_api_key_headers: dict[str, str]): - return WebCrawler(client, normal_user_api_key_headers) +def crawler(client: TestClient, user_api_key_header: dict[str, str]): + return WebCrawler(client, user_api_key_header) diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index 505d3e208..d48ae5d33 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -46,11 +46,11 @@ def get_superuser_token_headers(client: TestClient) -> dict[str, str]: return headers -def get_api_key_by_email(db: Session, email: EmailStr) -> str: +def get_api_key_by_email(db: Session, email: EmailStr) -> APIKeyPublic: user = get_user_by_email(session=db, email=email) api_key = get_api_key_by_user_id(db, user_id=user.id) - return api_key.key + return api_key def get_user_id_by_email(db: Session) -> int: