diff --git a/backend/app/api/docs/collections/create.md b/backend/app/api/docs/collections/create.md
index 3917d7c19..258b49fad 100644
--- a/backend/app/api/docs/collections/create.md
+++ b/backend/app/api/docs/collections/create.md
@@ -7,10 +7,12 @@ pipeline:
 * Create an OpenAI [Vector Store](https://platform.openai.com/docs/api-reference/vector-stores)
   based on those File's.
-* Attach the Vector Store to an OpenAI
+* [To be deprecated] Attach the Vector Store to an OpenAI
   [Assistant](https://platform.openai.com/docs/api-reference/assistants).
   Use parameters in the request body relevant to an Assistant to flesh out
-  its configuration.
+  its configuration. Note that an assistant will only be created when you pass both
+  "model" and "instructions" in the request body; otherwise only a vector store will be
+  created from the given documents.
 
 If any one of the OpenAI interactions fail, all OpenAI resources are
 cleaned up. If a Vector Store is unable to be created, for example,
@@ -19,9 +21,10 @@ OpenAI. Failure can occur from OpenAI being down, or some parameter value
 being invalid. It can also fail due to document types not be accepted. This
 is especially true for PDFs that may not be parseable.
 
-The immediate response from the endpoint is `collection_job` object which is
-going to contain the collection "job ID", status and action type ("CREATE").
-Once the collection has been created, information about the collection will
-be returned to the user via the callback URL. If a callback URL is not provided,
-clients can poll the `collection job info` endpoint with the `id` in the
-`collection_job` object returned as it is the `job id`, to retrieve the same information.
+The vector store/assistant is created asynchronously. The immediate response
+from this endpoint is a `collection_job` object which will contain
+the collection job ID and status. Once the collection has been created,
+information about the collection will be returned to the user via the
+callback URL. If a callback URL is not provided, clients can poll the
+`collection job info` endpoint with the `job_id` to retrieve
+information about the created collection.
diff --git a/backend/app/api/docs/collections/delete.md b/backend/app/api/docs/collections/delete.md
index 63a1e3cf4..da2c75cf8 100644
--- a/backend/app/api/docs/collections/delete.md
+++ b/backend/app/api/docs/collections/delete.md
@@ -7,7 +7,8 @@ Remove a collection from the platform. This is a two step process:
 No action is taken on the documents themselves: the contents of the
 documents that were a part of the collection remain unchanged, those
 documents can still be accessed via the documents endpoints. The response from this
-endpoint will be a `collection_job` object which will contain the collection `job ID`,
-status and action type ("DELETE"). when you take the id returned and use the collection job
-info endpoint, if the job is successful, you will get the status as successful and nothing will
-be returned as the collection as it has been deleted and marked as deleted.
+endpoint will be a `collection_job` object containing the collection `job_id` and
+status. When you poll the collection job info endpoint with the returned `job_id`,
+a successful job will report its status as successful.
+Additionally, if a `callback_url` was provided in the request body,
+you will receive a message indicating whether the deletion succeeded or failed.
diff --git a/backend/app/api/docs/collections/docs.md b/backend/app/api/docs/collections/docs.md
deleted file mode 100644
index a121bf688..000000000
--- a/backend/app/api/docs/collections/docs.md
+++ /dev/null
@@ -1,3 +0,0 @@
-List document IDs associated with a given collection. Documents
-returned are not only stored by the AI platform, but also by OpenAI
-OpenAI.
diff --git a/backend/app/api/docs/collections/info.md b/backend/app/api/docs/collections/info.md
index 4fa32e2ea..f9be78df7 100644
--- a/backend/app/api/docs/collections/info.md
+++ b/backend/app/api/docs/collections/info.md
@@ -1,4 +1,4 @@
-Retrieve detailed information about a specific collection by its ID from the collection table. Note that this endpoint CANNOT be used as a polling endpoint for collection creation because an entry will be made in the collection table only after the resource creation and association has been successful.
-
-This endpoint returns metadata for the collection, including its project, organization,
+Retrieve detailed information about a specific collection by its ID from the collection table. This endpoint returns the collection object including its project, organization,
 timestamps, and associated LLM service details (`llm_service_id`).
+
+Additionally, if the `include_docs` query parameter is true, the response also includes the list of document IDs associated with the collection. Returned documents are stored both by the AI platform and by OpenAI.
diff --git a/backend/app/api/docs/collections/job_info.md b/backend/app/api/docs/collections/job_info.md
index e785967b5..34d9a342e 100644
--- a/backend/app/api/docs/collections/job_info.md
+++ b/backend/app/api/docs/collections/job_info.md
@@ -1,12 +1,9 @@
-Retrieve information about a collection job by the collection job ID. This endpoint can be considered the polling endpoint for collection creation job. This endpoint provides detailed status and metadata for a specific collection job
-in the AI platform. It is especially useful for:
+Retrieve information about a collection job by the collection job ID. This endpoint provides detailed status and metadata for a specific collection job in the AI platform. It is especially useful for:
 
-* Fetching the collection job object containing the ID which will be collection job id, collection ID, status of the job as well as error message.
+* Fetching the collection job object, including the collection job ID, the current status, and the associated collection details.
 
 * If the job has finished, has been successful and it was a job of creation of collection then this endpoint will fetch the associated collection details from the collection table, including:
-  - `llm_service_id`: The OpenAI assistant or model used for the collection.
-  - Collection metadata such as ID, project, organization, and timestamps.
+  - `llm_service_id`: The OpenAI assistant or model used for the collection.
+  - Collection metadata such as ID, project, organization, and timestamps.
 
-* If the job of delete collection was successful, we will get the status as successful and nothing will be returned as collection.
-
-* Containing a simplified error messages in the retrieved collection job object when a job has failed.
+* If the delete-collection job succeeds, the status is set to "successful" and the `collection` field contains the ID of the collection that has been deleted.
diff --git a/backend/app/api/docs/collections/list.md b/backend/app/api/docs/collections/list.md
index c32bb31f9..eec8f312b 100644
--- a/backend/app/api/docs/collections/list.md
+++ b/backend/app/api/docs/collections/list.md
@@ -1,2 +1,6 @@
 List _active_ collections -- collections that have been created but not deleted
+
+If a vector store was created, `llm_service_name` and `llm_service_id` in the response denote the name of the vector store (e.g. 'openai vector store') and its ID.
+
+[To be deprecated] If an assistant was created, `llm_service_name` and `llm_service_id` in the response denote the name of the model used by the assistant (e.g. 'gpt-4o') and the assistant ID.
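The "both model and instructions, or neither" rule described in create.md can be sketched as a small client-side check. This is illustrative only, not backend code; the field names (`model`, `instructions`, `documents`, `callback_url`) are taken from the request models in this diff, while the concrete values are made up:

```python
def will_create_assistant(body: dict) -> bool:
    """Mirror of the route's check: an Assistant is created only when
    BOTH 'model' and 'instructions' are present and non-blank."""
    def present(key: str) -> bool:
        value = body.get(key)
        if isinstance(value, str):
            return bool(value.strip())  # blank strings count as missing
        return value is not None
    return present("model") and present("instructions")


# Hypothetical request body for POST /collections/
create_request = {
    "documents": ["9b2d..."],           # document IDs (placeholder value)
    "model": "gpt-4o",                  # optional, [To be deprecated]
    "instructions": "Answer tersely.",  # optional, [To be deprecated]
    "callback_url": None,               # omit and poll GET /collections/jobs/{job_id}
}

assert will_create_assistant(create_request)  # vector store + assistant
```

Dropping either `model` or `instructions` flips the outcome to a vector-store-only job (and, per the validator below in `AssistantOptions`, supplying exactly one of the two is rejected outright).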
diff --git a/backend/app/api/routes/collection_job.py b/backend/app/api/routes/collection_job.py index bddfe8d95..5636ed8f4 100644 --- a/backend/app/api/routes/collection_job.py +++ b/backend/app/api/routes/collection_job.py @@ -10,7 +10,12 @@ CollectionCrud, CollectionJobCrud, ) -from app.models import CollectionJobStatus, CollectionJobPublic, CollectionActionType +from app.models import ( + CollectionJobStatus, + CollectionIDPublic, + CollectionActionType, + CollectionJobPublic, +) from app.models.collection import CollectionPublic from app.utils import APIResponse, load_description from app.services.collections.helpers import extract_error_message @@ -21,7 +26,7 @@ @router.get( - "/info/jobs/{job_id}", + "/jobs/{job_id}", description=load_description("collections/job_info.md"), response_model=APIResponse[CollectionJobPublic], ) @@ -35,16 +40,21 @@ def collection_job_info( job_out = CollectionJobPublic.model_validate(collection_job) - if ( - collection_job.status == CollectionJobStatus.SUCCESSFUL - and collection_job.action_type == CollectionActionType.CREATE - and collection_job.collection_id - ): - collection_crud = CollectionCrud(session, current_user.project_id) - collection = collection_crud.read_one(collection_job.collection_id) - job_out.collection = CollectionPublic.model_validate(collection) - - if collection_job.status == CollectionJobStatus.FAILED and job_out.error_message: - job_out.error_message = extract_error_message(job_out.error_message) + if collection_job.collection_id: + if ( + collection_job.action_type == CollectionActionType.CREATE + and collection_job.status == CollectionJobStatus.SUCCESSFUL + ): + collection_crud = CollectionCrud(session, current_user.project_id) + collection = collection_crud.read_one(collection_job.collection_id) + job_out.collection = CollectionPublic.model_validate(collection) + + elif collection_job.action_type == CollectionActionType.DELETE: + job_out.collection = CollectionIDPublic(id=collection_job.collection_id) 
+ + if collection_job.status == CollectionJobStatus.FAILED: + raw_error = getattr(collection_job, "error_message", None) + error_message = extract_error_message(raw_error) + job_out.error_message = error_message return APIResponse.success_response(data=job_out) diff --git a/backend/app/api/routes/collections.py b/backend/app/api/routes/collections.py index c6210bebc..78a40ae7e 100644 --- a/backend/app/api/routes/collections.py +++ b/backend/app/api/routes/collections.py @@ -1,12 +1,10 @@ -import inspect import logging from uuid import UUID from typing import List -from fastapi import APIRouter, Query +from fastapi import APIRouter, Query, Body from fastapi import Path as FastPath - from app.api.deps import SessionDep, CurrentUserOrgProject from app.crud import ( CollectionCrud, @@ -18,15 +16,17 @@ CollectionJobStatus, CollectionActionType, CollectionJobCreate, + CollectionJobPublic, + CollectionJobImmediatePublic, + CollectionWithDocsPublic, ) from app.models.collection import ( - ResponsePayload, CreationRequest, + CallbackRequest, DeletionRequest, CollectionPublic, ) from app.utils import APIResponse, load_description -from app.services.collections.helpers import extract_error_message from app.services.collections import ( create_collection as create_service, delete_collection as delete_service, @@ -34,12 +34,47 @@ logger = logging.getLogger(__name__) + router = APIRouter(prefix="/collections", tags=["collections"]) +collection_callback_router = APIRouter() + + +@collection_callback_router.post( + "{$callback_url}", + name="collection_callback", +) +def collection_callback_notification(body: APIResponse[CollectionJobPublic]): + """ + Callback endpoint specification for collection creation/deletion. + + The callback will receive: + - On success: APIResponse with success=True and data containing CollectionJobPublic + - On failure: APIResponse with success=False and error message + - metadata field will always be included if provided in the request + """ + ... 
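The callback body documented by the stub above could look roughly like the dicts below. Field names follow `APIResponse` and `CollectionJobPublic` as defined in this diff; the exact serialization and the UUID values are assumptions for illustration:

```python
def sample_callback_payload(succeeded: bool) -> dict:
    """Illustrative shape of the webhook body POSTed to callback_url.
    Note error_message is excluded from data (see build_*_payload),
    so failures surface only via the top-level 'error' field."""
    job = {
        "job_id": "6f1c...",  # hypothetical collection job UUID
        "status": "successful" if succeeded else "failed",
        "action_type": "CREATE",
        "collection": {"id": "9a2b..."} if succeeded else None,
    }
    if succeeded:
        return {"success": True, "data": job, "error": None, "metadata": None}
    return {
        "success": False,
        "data": job,
        "error": "OpenAI vector store creation failed",  # simplified message
        "metadata": None,
    }
```

A callback receiver should therefore branch on `success` first and only read `data.collection` when the job succeeded.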
+ + +@router.get( + "/", + description=load_description("collections/list.md"), + response_model=APIResponse[List[CollectionPublic]], +) +def list_collections( + session: SessionDep, + current_user: CurrentUserOrgProject, +): + collection_crud = CollectionCrud(session, current_user.project_id) + rows = collection_crud.read_all() + + return APIResponse.success_response(rows) @router.post( - "/create", + "/", description=load_description("collections/create.md"), + response_model=APIResponse[CollectionJobImmediatePublic], + callbacks=collection_callback_router.routes, ) def create_collection( session: SessionDep, @@ -55,35 +90,52 @@ def create_collection( ) ) - this = inspect.currentframe() - route = router.url_path_for(this.f_code.co_name) - payload = ResponsePayload( - status="processing", route=route, key=str(collection_job.id) + # True iff both model and instructions were provided in the request body + with_assistant = bool( + getattr(request, "model", None) and getattr(request, "instructions", None) ) create_service.start_job( db=session, request=request, - payload=payload, collection_job_id=collection_job.id, project_id=current_user.project_id, organization_id=current_user.organization_id, + with_assistant=with_assistant, ) - return APIResponse.success_response(collection_job) + metadata = None + if not with_assistant: + metadata = { + "note": ( + "This job will create a vector store only (no Assistant). " + "Assistant creation happens when both 'model' and 'instructions' are included." 
+ ) + } + + return APIResponse.success_response( + CollectionJobImmediatePublic.model_validate(collection_job), metadata=metadata + ) -@router.post( - "/delete", +@router.delete( + "/{collection_id}", description=load_description("collections/delete.md"), + response_model=APIResponse[CollectionJobImmediatePublic], + callbacks=collection_callback_router.routes, ) def delete_collection( session: SessionDep, current_user: CurrentUserOrgProject, - request: DeletionRequest, + collection_id: UUID = FastPath(description="Collection to delete"), + request: CallbackRequest | None = Body(default=None), ): - collection_crud = CollectionCrud(session, current_user.project_id) - collection = collection_crud.read_one(request.collection_id) + _ = CollectionCrud(session, current_user.project_id).read_one(collection_id) + + deletion_request = DeletionRequest( + collection_id=collection_id, + callback_url=request.callback_url if request else None, + ) collection_job_crud = CollectionJobCrud(session, current_user.project_id) collection_job = collection_job_crud.create( @@ -91,74 +143,49 @@ def delete_collection( action_type=CollectionActionType.DELETE, project_id=current_user.project_id, status=CollectionJobStatus.PENDING, - collection_id=collection.id, + collection_id=collection_id, ) ) - this = inspect.currentframe() - route = router.url_path_for(this.f_code.co_name) - payload = ResponsePayload( - status="processing", route=route, key=str(collection_job.id) - ) - delete_service.start_job( db=session, - request=request, - payload=payload, - collection=collection, + request=deletion_request, collection_job_id=collection_job.id, project_id=current_user.project_id, organization_id=current_user.organization_id, ) - return APIResponse.success_response(collection_job) + return APIResponse.success_response( + CollectionJobImmediatePublic.model_validate(collection_job) + ) @router.get( - "/info/{collection_id}", + "/{collection_id}", description=load_description("collections/info.md"), - 
response_model=APIResponse[CollectionPublic], + response_model=APIResponse[CollectionWithDocsPublic], ) def collection_info( session: SessionDep, current_user: CurrentUserOrgProject, collection_id: UUID = FastPath(description="Collection to retrieve"), + include_docs: bool = Query( + True, + description="If true, include documents linked to this collection", + ), + skip: int = Query(0, ge=0), + limit: int = Query(100, gt=0, le=100), ): collection_crud = CollectionCrud(session, current_user.project_id) collection = collection_crud.read_one(collection_id) - return APIResponse.success_response(collection) - + collection_with_docs = CollectionWithDocsPublic.model_validate(collection) -@router.get( - "/list", - description=load_description("collections/list.md"), - response_model=APIResponse[List[CollectionPublic]], -) -def list_collections( - session: SessionDep, - current_user: CurrentUserOrgProject, -): - collection_crud = CollectionCrud(session, current_user.project_id) - rows = collection_crud.read_all() + if include_docs: + document_collection_crud = DocumentCollectionCrud(session) + docs = document_collection_crud.read(collection, skip, limit) + collection_with_docs.documents = [ + DocumentPublic.model_validate(doc) for doc in docs + ] - return APIResponse.success_response(rows) - - -@router.post( - "/docs/{collection_id}", - description=load_description("collections/docs.md"), - response_model=APIResponse[List[DocumentPublic]], -) -def collection_documents( - session: SessionDep, - current_user: CurrentUserOrgProject, - collection_id: UUID = FastPath(description="Collection to retrieve"), - skip: int = Query(0, ge=0), - limit: int = Query(100, gt=0, le=100), -): - collection_crud = CollectionCrud(session, current_user.project_id) - document_collection_crud = DocumentCollectionCrud(session) - collection = collection_crud.read_one(collection_id) - data = document_collection_crud.read(collection, skip, limit) - return APIResponse.success_response(data) + return 
APIResponse.success_response(collection_with_docs) diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index 8fad2a70c..ed5a431ce 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -23,7 +23,7 @@ resolve_transformer, ) from app.crud import CollectionCrud, DocumentCrud -from app.crud.rag import OpenAIAssistantCrud +from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud from app.models import ( Document, DocumentPublic, @@ -31,6 +31,7 @@ Message, TransformationJobInfo, ) +from app.services.collections.helpers import pick_service_for_documennt from app.utils import APIResponse, get_openai_client, load_description @@ -164,11 +165,16 @@ def remove_doc( ) a_crud = OpenAIAssistantCrud(client) + v_crud = OpenAIVectorStoreCrud(client) d_crud = DocumentCrud(session, current_user.project_id) c_crud = CollectionCrud(session, current_user.project_id) + document = d_crud.read_one(doc_id) - document = d_crud.delete(doc_id) - data = c_crud.delete(document, a_crud) + remote = pick_service_for_documennt( + session, doc_id, a_crud, v_crud + ) # assistant crud or vector store crud + c_crud.delete(document, remote) + d_crud.delete(doc_id) return APIResponse.success_response( Message(message="Document Deleted Successfully") @@ -189,13 +195,17 @@ def permanent_delete_doc( session, current_user.organization_id, current_user.project_id ) a_crud = OpenAIAssistantCrud(client) + v_crud = OpenAIVectorStoreCrud(client) d_crud = DocumentCrud(session, current_user.project_id) c_crud = CollectionCrud(session, current_user.project_id) storage = get_cloud_storage(session=session, project_id=current_user.project_id) document = d_crud.read_one(doc_id) - c_crud.delete(document, a_crud) + remote = pick_service_for_documennt( + session, doc_id, a_crud, v_crud + ) # assistant crud or vector store crud + c_crud.delete(document, remote) storage.delete(document.object_store_url) d_crud.delete(doc_id) diff --git 
a/backend/app/crud/collection/collection.py b/backend/app/crud/collection/collection.py index d218ef2a9..3c83912a4 100644 --- a/backend/app/crud/collection/collection.py +++ b/backend/app/crud/collection/collection.py @@ -6,6 +6,7 @@ from fastapi import HTTPException from sqlmodel import Session, select, and_ +from sqlalchemy.exc import IntegrityError from app.models import Document, Collection, DocumentCollection from app.core.util import now @@ -20,19 +21,6 @@ def __init__(self, session: Session, project_id: int): self.project_id = project_id def _update(self, collection: Collection): - if not collection.project_id: - collection.project_id = self.project_id - elif collection.project_id != self.project_id: - err = ( - f"Invalid collection ownership: owner_project={self.project_id} " - f"attempter={collection.project_id}" - ) - logger.error( - "[CollectionCrud._update] Permission error | " - f"{{'collection_id': '{collection.id}', 'error': '{err}'}}" - ) - raise PermissionError(err) - self.session.add(collection) self.session.commit() self.session.refresh(collection) @@ -53,29 +41,18 @@ def _exists(self, collection: Collection) -> bool: return present def create( - self, - collection: Collection, - documents: Optional[list[Document]] = None, - ): + self, collection: Collection, documents: list[Document] | None = None + ) -> Collection: + self.session.add(collection) try: - existing = self.read_one(collection.id) - except HTTPException as e: - if e.status_code == 404: - self.session.add(collection) - self.session.commit() - self.session.refresh(collection) - else: - raise - else: - logger.warning( - "[CollectionCrud.create] Collection already present | " - f"{{'collection_id': '{collection.id}'}}" - ) - return existing + self.session.commit() + except IntegrityError: + self.session.rollback() + return self.read_one(collection.id) + self.session.refresh(collection) if documents: - dc_crud = DocumentCollectionCrud(self.session) - dc_crud.create(collection, documents) 
+ DocumentCollectionCrud(self.session).create(collection, documents) return collection @@ -116,6 +93,12 @@ def read_all(self): collections = self.session.exec(statement).all() return collections + def delete_by_id(self, collection_id: UUID) -> Collection: + coll = self.read_one(collection_id) + coll.deleted_at = now() + + return self._update(coll) + @ft.singledispatchmethod def delete(self, model, remote): # remote should be an OpenAICrud try: @@ -145,7 +128,10 @@ def _(self, model: Document, remote): DocumentCollection, DocumentCollection.collection_id == Collection.id, ) - .where(DocumentCollection.document_id == model.id) + .where( + DocumentCollection.document_id == model.id, + Collection.deleted_at.is_(None), + ) .distinct() ) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 6a9f41852..b2f294025 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -6,15 +6,20 @@ from .assistants import Assistant, AssistantBase, AssistantCreate, AssistantUpdate -from .collection import Collection, CollectionPublic +from .collection import ( + Collection, + CollectionPublic, + CollectionIDPublic, + CollectionWithDocsPublic, +) from .collection_job import ( CollectionActionType, CollectionJob, - CollectionJobBase, CollectionJobStatus, CollectionJobUpdate, CollectionJobPublic, CollectionJobCreate, + CollectionJobImmediatePublic, ) from .credentials import ( Credential, diff --git a/backend/app/models/collection.py b/backend/app/models/collection.py index 9e5f866fd..e09f56226 100644 --- a/backend/app/models/collection.py +++ b/backend/app/models/collection.py @@ -3,9 +3,10 @@ from typing import Any, Optional from sqlmodel import Field, Relationship, SQLModel -from pydantic import HttpUrl +from pydantic import HttpUrl, model_validator from app.core.util import now +from app.models.document import DocumentPublic from .organization import Organization from .project import Project @@ -36,22 +37,7 @@ class 
Collection(SQLModel, table=True): project: Project = Relationship(back_populates="collections") -class ResponsePayload(SQLModel): - """Response metadata for background jobs—gives status, route, a UUID key, - and creation time.""" - - status: str - route: str - key: str = Field(default_factory=lambda: str(uuid4())) - time: datetime = Field(default_factory=now) - - @classmethod - def now(cls): - """Returns current UTC time without timezone info""" - return now() - - -# pydantic models - +# Request models class DocumentOptions(SQLModel): documents: list[UUID] = Field( description="List of document IDs", @@ -73,27 +59,57 @@ class AssistantOptions(SQLModel): # Fields to be passed along to OpenAI. They must be a subset of # parameters accepted by the OpenAI.clien.beta.assistants.create # API. - model: str = Field( + model: Optional[str] = Field( + default=None, description=( + "**[To Be Deprecated]** " "OpenAI model to attach to this assistant. The model " "must be compatable with the assistants API; see the " "OpenAI [model documentation](https://platform.openai.com/docs/models/compare) for more." ), ) - instructions: str = Field( + + instructions: Optional[str] = Field( + default=None, description=( - "Assistant instruction. Sometimes referred to as the " '"system" prompt.' + "**[To Be Deprecated]** " + "Assistant instruction. Sometimes referred to as the " + '"system" prompt.' ), ) temperature: float = Field( default=1e-6, description=( + "**[To Be Deprecated]** " "Model temperature. The default is slightly " "greater-than zero because it is [unknown how OpenAI " "handles zero](https://community.openai.com/t/clarifications-on-setting-temperature-0/886447/5)." 
), ) + @model_validator(mode="before") + def _assistant_fields_all_or_none(cls, values: dict[str, Any]) -> dict[str, Any]: + def norm(x: Any) -> Any: + if x is None: + return None + if isinstance(x, str): + s = x.strip() + return s if s else None + return x # let Pydantic handle non-strings + + model = norm(values.get("model")) + instructions = norm(values.get("instructions")) + + if (model is None) ^ (instructions is None): + raise ValueError( + "To create an Assistant, provide BOTH 'model' and 'instructions'. " + "If you only want a vector store, remove both fields." + ) + + values["model"] = model + values["instructions"] = instructions + return values + class CallbackRequest(SQLModel): callback_url: Optional[HttpUrl] = Field( @@ -108,7 +124,7 @@ class CreationRequest( CallbackRequest, ): def extract_super_type(self, cls: "CreationRequest"): - for field_name in cls.__fields__.keys(): + for field_name in cls.model_fields.keys(): field_value = getattr(self, field_name) yield (field_name, field_value) @@ -117,6 +133,13 @@ class DeletionRequest(CallbackRequest): collection_id: UUID = Field(description="Collection to delete") +# Response models + + +class CollectionIDPublic(SQLModel): + id: UUID + + class CollectionPublic(SQLModel): id: UUID llm_service_id: str @@ -127,3 +150,7 @@ class CollectionPublic(SQLModel): inserted_at: datetime updated_at: datetime deleted_at: datetime | None = None + + +class CollectionWithDocsPublic(CollectionPublic): + documents: list[DocumentPublic] | None = None diff --git a/backend/app/models/collection_job.py b/backend/app/models/collection_job.py index af7eda6eb..4739b16c0 100644 --- a/backend/app/models/collection_job.py +++ b/backend/app/models/collection_job.py @@ -3,9 +3,10 @@ from datetime import datetime from sqlmodel import Field, SQLModel, Column, Text +from pydantic import ConfigDict from app.core.util import now -from app.models.collection import CollectionPublic +from app.models.collection import CollectionPublic, 
CollectionIDPublic class CollectionJobStatus(str, Enum): @@ -20,31 +21,26 @@ class CollectionActionType(str, Enum): DELETE = "DELETE" -class CollectionJobBase(SQLModel): - action_type: CollectionActionType = Field( - nullable=False, description="Type of operation" - ) - collection_id: UUID | None = Field( - foreign_key="collection.id", nullable=True, ondelete="CASCADE" - ) - project_id: int = Field( - foreign_key="project.id", nullable=False, ondelete="CASCADE" - ) - - -class CollectionJob(CollectionJobBase, table=True): +class CollectionJob(SQLModel, table=True): """Database model for tracking collection operations.""" __tablename__ = "collection_jobs" id: UUID = Field(default_factory=uuid4, primary_key=True) - status: CollectionJobStatus = Field( default=CollectionJobStatus.PENDING, nullable=False, description="Current job status", ) - + action_type: CollectionActionType = Field( + nullable=False, description="Type of operation" + ) + collection_id: UUID | None = Field( + foreign_key="collection.id", nullable=True, ondelete="CASCADE" + ) + project_id: int = Field( + foreign_key="project.id", nullable=False, ondelete="CASCADE" + ) task_id: str = Field(nullable=True) trace_id: str | None = Field( default=None, description="Tracing ID for correlating logs and traces." 
@@ -63,7 +59,20 @@ class CollectionJob(CollectionJobBase, table=True): description="Last time the job record was updated", ) + @property + def job_id(self) -> UUID: + return self.id + + @property + def job_inserted_at(self) -> datetime: + return self.inserted_at + + @property + def job_updated_at(self) -> datetime: + return self.updated_at + +# Request models class CollectionJobCreate(SQLModel): collection_id: UUID | None = None status: CollectionJobStatus @@ -79,13 +88,18 @@ class CollectionJobUpdate(SQLModel): trace_id: str | None = None -class CollectionJobPublic(SQLModel): - id: UUID - action_type: CollectionActionType - collection_id: UUID | None = None +##Response models +class CollectionJobBasePublic(SQLModel): + job_id: UUID status: CollectionJobStatus - error_message: str | None = None - inserted_at: datetime - updated_at: datetime - collection: CollectionPublic | None = None + +class CollectionJobImmediatePublic(CollectionJobBasePublic): + job_inserted_at: datetime + job_updated_at: datetime + + +class CollectionJobPublic(CollectionJobBasePublic): + action_type: CollectionActionType + collection: CollectionPublic | CollectionIDPublic | None = None + error_message: str | None = None diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index d424c5333..ed83e4a89 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -20,20 +20,22 @@ CollectionJob, Collection, CollectionJobUpdate, + CollectionPublic, + CollectionJobPublic, ) from app.models.collection import ( - ResponsePayload, CreationRequest, AssistantOptions, ) from app.services.collections.helpers import ( _backout, batch_documents, - SilentCallback, - WebHookCallback, + extract_error_message, + OPENAI_VECTOR_STORE, ) from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client +from app.utils import get_openai_client, 
send_callback, APIResponse + logger = logging.getLogger(__name__) @@ -41,9 +43,9 @@ def start_job( db: Session, request: CreationRequest, - payload: ResponsePayload, project_id: int, collection_job_id: UUID, + with_assistant: bool, organization_id: int, ) -> str: trace_id = correlation_id.get() or "N/A" @@ -57,9 +59,9 @@ def start_job( function_path="app.services.collections.create_collection.execute_job", project_id=project_id, job_id=str(collection_job_id), - payload=payload.model_dump(), trace_id=trace_id, - request=request.model_dump(), + with_assistant=with_assistant, + request=request.model_dump(mode="json"), organization_id=organization_id, ) @@ -71,142 +73,246 @@ def start_job( return collection_job_id +def build_success_payload( + collection_job: CollectionJob, collection: Collection +) -> dict: + """ + { + "success": true, + "data": { job fields + full collection }, + "error": null, + "metadata": null + } + """ + collection_public = CollectionPublic.model_validate(collection) + job_public = CollectionJobPublic.model_validate( + collection_job, + update={"collection": collection_public}, + ) + return APIResponse.success_response(job_public).model_dump( + mode="json", exclude={"data": {"error_message"}} + ) + + +def build_failure_payload(collection_job: CollectionJob, error_message: str) -> dict: + """ + { + "success": false, + "data": { job fields, collection: null }, + "error": "something went wrong", + "metadata": null + } + """ + # ensure `collection` is explicitly null in the payload + job_public = CollectionJobPublic.model_validate( + collection_job, + update={"collection": None}, + ) + return APIResponse.failure_response( + extract_error_message(error_message), job_public + ).model_dump( + mode="json", + exclude={"data": {"error_message"}}, + ) + + +def _cleanup_remote_resources( + assistant, + assistant_crud, + vector_store, + vector_store_crud, +) -> None: + """Best-effort cleanup of partially created remote resources.""" + try: + if assistant is 
not None and assistant_crud is not None: + _backout(assistant_crud, assistant.id) + elif vector_store is not None and vector_store_crud is not None: + _backout(vector_store_crud, vector_store.id) + else: + logger.warning( + "[create_collection._backout] Skipping: no resource/crud available" + ) + except Exception: + logger.warning("[create_collection.execute_job] Backout failed") + + +def _mark_job_failed( + project_id: int, + job_id: str, + err: Exception, + collection_job: CollectionJob | None, +) -> CollectionJob | None: + """Update job row to FAILED with error_message; return latest job or None.""" + try: + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) + if collection_job is None: + collection_job = collection_job_crud.read_one(UUID(job_id)) + collection_job = collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.FAILED, + error_message=str(err), + ), + ) + return collection_job + except Exception: + logger.warning("[create_collection.execute_job] Failed to mark job as FAILED") + return None + + def execute_job( request: dict, project_id: int, organization_id: int, - payload: dict, task_id: str, job_id: str, + with_assistant: bool, task_instance, ) -> None: """ Worker entrypoint scheduled by start_job. + Orchestrates: job state, client/storage init, batching, vector-store upload, + optional assistant creation, collection persistence, linking, callbacks, and cleanup. 
""" start_time = time.time() - try: - with Session(engine) as session: - creation_request = CreationRequest(**request) - payload = ResponsePayload(**payload) + # Keep references for potential backout/cleanup on failure + assistant = None + assistant_crud = None + vector_store = None + vector_store_crud = None + collection_job = None - job_id = UUID(job_id) + try: + creation_request = CreationRequest(**request) + job_uuid = UUID(job_id) + with Session(engine) as session: collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_id) + collection_job = collection_job_crud.read_one(job_uuid) collection_job = collection_job_crud.update( - job_id, + job_uuid, CollectionJobUpdate( - task_id=task_id, status=CollectionJobStatus.PROCESSING + task_id=task_id, + status=CollectionJobStatus.PROCESSING, ), ) client = get_openai_client(session, organization_id, project_id) + storage = get_cloud_storage(session=session, project_id=project_id) - callback = ( - SilentCallback(payload) - if creation_request.callback_url is None - else WebHookCallback(creation_request.callback_url, payload) + # Batch documents for upload, and flatten for linking/metrics later + document_crud = DocumentCrud(session, project_id) + docs_batches = batch_documents( + document_crud, + creation_request.documents, + creation_request.batch_size, ) + flat_docs = [doc for batch in docs_batches for doc in batch] - storage = get_cloud_storage(session=session, project_id=project_id) - document_crud = DocumentCrud(session, project_id) + vector_store_crud = OpenAIVectorStoreCrud(client) + vector_store = vector_store_crud.create() + list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + + # if with_assistant is true, create assistant backed by the vector store + if with_assistant: assistant_crud = OpenAIAssistantCrud(client) - vector_store_crud = OpenAIVectorStoreCrud(client) - - try: - vector_store = vector_store_crud.create() - - docs_batches = 
batch_documents( - document_crud, - creation_request.documents, - creation_request.batch_size, - ) - flat_docs = [doc for batch in docs_batches for doc in batch] - - file_exts = { - doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname - } - file_sizes_kb = [ - storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs - ] - - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) - - assistant_options = dict( - creation_request.extract_super_type(AssistantOptions) - ) - assistant = assistant_crud.create(vector_store.id, **assistant_options) - - collection_id = uuid4() - collection_crud = CollectionCrud(session, project_id) - collection = Collection( - id=collection_id, - project_id=project_id, - organization_id=organization_id, - llm_service_id=assistant.id, - llm_service_name=creation_request.model, - ) - - collection_crud.create(collection) - collection_data = collection_crud.read_one(collection.id) - - if flat_docs: - DocumentCollectionCrud(session).create(collection_data, flat_docs) - - collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.SUCCESSFUL, - collection_id=collection.id, - ), - ) - - elapsed = time.time() - start_time - logger.info( - "[create_collection.execute_job] Collection created: %s | Time: %.2fs | Files: %d | Sizes: %s KB | Types: %s", - collection_id, - elapsed, - len(flat_docs), - file_sizes_kb, - list(file_exts), - ) - - callback.success(collection.model_dump(mode="json")) - - except Exception as err: - logger.error( - "[create_collection.execute_job] Collection Creation Failed | " - "{'collection_job_id': '%s', 'error': '%s'}", - job_id, - str(err), - exc_info=True, - ) - - if "assistant" in locals(): - _backout(assistant_crud, assistant.id) - - collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.FAILED, - error_message=str(err), - ), - ) - - callback.fail(str(err)) - - except Exception as outer_err: + + # 
Filter out None to avoid sending unset options + assistant_options = dict( + creation_request.extract_super_type(AssistantOptions) + ) + assistant_options = { + k: v for k, v in assistant_options.items() if v is not None + } + + assistant = assistant_crud.create(vector_store.id, **assistant_options) + llm_service_id = assistant.id + llm_service_name = assistant_options.get("model") or "assistant" + + logger.info( + "[execute_job] Assistant created | assistant_id=%s, vector_store_id=%s", + assistant.id, + vector_store.id, + ) + else: + # If no assistant, the collection points directly at the vector store + llm_service_id = vector_store.id + llm_service_name = OPENAI_VECTOR_STORE + logger.info( + "[execute_job] Skipping assistant creation | with_assistant=False" + ) + + file_exts = {doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname} + file_sizes_kb = [ + storage.get_file_size_kb(doc.object_store_url) for doc in flat_docs + ] + + with Session(engine) as session: + collection_crud = CollectionCrud(session, project_id) + + collection_id = uuid4() + collection = Collection( + id=collection_id, + project_id=project_id, + organization_id=organization_id, + llm_service_id=llm_service_id, + llm_service_name=llm_service_name, + ) + collection_crud.create(collection) + collection = collection_crud.read_one(collection.id) + + # Link documents to the new collection + if flat_docs: + DocumentCollectionCrud(session).create(collection, flat_docs) + + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.SUCCESSFUL, + collection_id=collection.id, + ), + ) + + success_payload = build_success_payload(collection_job, collection) + + elapsed = time.time() - start_time + logger.info( + "[create_collection.execute_job] Collection created: %s | Time: %.2fs | Files: %d | Sizes: %s KB | Types: %s", + collection_id, + elapsed, + len(flat_docs), + 
file_sizes_kb, + list(file_exts), + ) + + if creation_request.callback_url: + send_callback(creation_request.callback_url, success_payload) + + except Exception as err: logger.error( - "[create_collection.execute_job] Unexpected Error during collection creation: %s", - str(outer_err), + "[create_collection.execute_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", + job_id, + str(err), exc_info=True, ) - collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.FAILED, - error_message=str(outer_err), - ), + _cleanup_remote_resources( + assistant=assistant, + assistant_crud=assistant_crud, + vector_store=vector_store, + vector_store_crud=vector_store_crud, ) + + collection_job = _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=err, + collection_job=collection_job, + ) + + if creation_request and creation_request.callback_url and collection_job: + failure_payload = build_failure_payload(collection_job, str(err)) + send_callback(creation_request.callback_url, failure_payload) diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index 088647c31..ca337b796 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -3,20 +3,21 @@ from sqlmodel import Session from asgi_correlation_id import correlation_id -from sqlalchemy.exc import SQLAlchemyError from app.core.db import engine from app.crud import CollectionCrud, CollectionJobCrud -from app.crud.rag import OpenAIAssistantCrud -from app.models import CollectionJobStatus, CollectionJobUpdate -from app.models.collection import Collection, DeletionRequest -from app.services.collections.helpers import ( - SilentCallback, - WebHookCallback, - ResponsePayload, +from app.crud.rag import OpenAIAssistantCrud, OpenAIVectorStoreCrud +from app.models import ( + CollectionJobStatus, + CollectionJobUpdate, + 
CollectionJob, + CollectionJobPublic, + CollectionIDPublic, ) +from app.models.collection import DeletionRequest +from app.services.collections.helpers import extract_error_message, OPENAI_VECTOR_STORE from app.celery.utils import start_low_priority_job -from app.utils import get_openai_client +from app.utils import get_openai_client, send_callback, APIResponse logger = logging.getLogger(__name__) @@ -25,10 +26,8 @@ def start_job( db: Session, request: DeletionRequest, - collection: Collection, project_id: int, collection_job_id: UUID, - payload: ResponsePayload, organization_id: int, ) -> str: trace_id = correlation_id.get() or "N/A" @@ -42,23 +41,105 @@ def start_job( function_path="app.services.collections.delete_collection.execute_job", project_id=project_id, job_id=str(collection_job_id), - collection_id=str(collection.id), + collection_id=str(request.collection_id), trace_id=trace_id, - request=request.model_dump(), - payload=payload.model_dump(), + request=request.model_dump(mode="json"), organization_id=organization_id, ) logger.info( "[delete_collection.start_job] Job scheduled to delete collection | " - f"Job_id={collection_job_id}, project_id={project_id}, task_id={task_id}, collection_id={collection.id}" + f"Job_id={collection_job_id}, project_id={project_id}, task_id={task_id}, collection_id={request.collection_id}" ) return collection_job_id +def build_success_payload(collection_job: CollectionJob, collection_id: UUID) -> dict: + """ + success: true + data: { job_id, status, collection: { id } } + error: null + metadata: null + """ + collection_public = CollectionIDPublic(id=collection_id) + job_public = CollectionJobPublic.model_validate( + collection_job, + update={"collection": collection_public}, + ) + return APIResponse.success_response(job_public).model_dump( + mode="json", + exclude_none=True, + ) + + +def build_failure_payload( + collection_job: CollectionJob, collection_id: UUID, error_message: str +) -> dict: + """ + success: false + data: { 
job_id, status, collection: { id } } + error: "something went wrong" + metadata: null + """ + collection_public = CollectionIDPublic(id=collection_id) + job_public = CollectionJobPublic.model_validate( + collection_job, + update={"collection": collection_public}, + ) + return APIResponse.failure_response( + extract_error_message(error_message), job_public + ).model_dump(mode="json", exclude={"data": {"error_message"}}) + + +def _mark_job_failed_and_callback( + *, + project_id: int, + collection_id: UUID, + job_id: UUID, + err: Exception, + callback_url: str | None, +) -> None: + """ + Common failure handler: + - mark job as FAILED with error_message + - log error + - send failure callback (if configured) + """ + collection_job = None + try: + with Session(engine) as session: + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job_crud.update( + job_id, + CollectionJobUpdate( + status=CollectionJobStatus.FAILED, + error_message=str(err), + ), + ) + collection_job = collection_job_crud.read_one(job_id) + except Exception: + logger.warning("[delete_collection.execute_job] Failed to mark job as FAILED") + + logger.error( + "[delete_collection.execute_job] deletion failed | " + "{'collection_id': '%s', 'error': '%s', 'job_id': '%s'}", + str(collection_id), + str(err), + str(job_id), + exc_info=True, + ) + + if callback_url and collection_job: + failure_payload = build_failure_payload( + collection_job=collection_job, + collection_id=collection_id, + error_message=str(err), + ) + send_callback(callback_url, failure_payload) + + def execute_job( request: dict, - payload: dict, project_id: int, organization_id: int, task_id: str, @@ -66,90 +147,74 @@ def execute_job( collection_id: str, task_instance, ) -> None: + """Celery worker entrypoint for deleting a collection (both remote and local).""" + deletion_request = DeletionRequest(**request) - payload = ResponsePayload(**payload) - callback = ( - SilentCallback(payload) - if 
deletion_request.callback_url is None - else WebHookCallback(deletion_request.callback_url, payload) - ) + collection_id = UUID(collection_id) + job_uuid = UUID(job_id) - if not isinstance(collection_id, UUID): - collection_id = UUID(str(collection_id)) - if not isinstance(job_id, UUID): - job_id = UUID(str(job_id)) + collection_job = None + client = None try: with Session(engine) as session: - client = get_openai_client(session, organization_id, project_id) - collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_id) + collection_job = collection_job_crud.read_one(job_uuid) collection_job = collection_job_crud.update( - job_id, + job_uuid, CollectionJobUpdate( - task_id=task_id, status=CollectionJobStatus.PROCESSING + task_id=task_id, + status=CollectionJobStatus.PROCESSING, ), ) - assistant_crud = OpenAIAssistantCrud(client) - collection_crud = CollectionCrud(session, project_id) - - collection = collection_crud.read_one(collection_id) - - try: - result = collection_crud.delete(collection, assistant_crud) - - collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.SUCCESSFUL, - error_message=None, - ), - ) - - logger.info( - "[delete_collection.execute_job] Collection deleted successfully | {'collection_id': '%s', 'job_id': '%s'}", - str(collection.id), - str(job_id), - ) - callback.success(result.model_dump(mode="json")) - - except (ValueError, PermissionError, SQLAlchemyError) as err: - collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.FAILED, - error_message=str(err), - ), - ) - - logger.error( - "[delete_collection.execute_job] Failed to delete collection | {'collection_id': '%s', 'error': '%s', 'job_id': '%s'}", - str(collection.id), - str(err), - str(job_id), - exc_info=True, - ) - callback.fail(str(err)) + client = get_openai_client(session, organization_id, project_id) - except Exception as err: - 
collection_job_crud.update( - collection_job.id, - CollectionJobUpdate( - status=CollectionJobStatus.FAILED, - error_message=str(err), - ), + collection = CollectionCrud(session, project_id).read_one(collection_id) + + # Identify which external service (assistant/vector store) this collection belongs to + service = (collection.llm_service_name or "").strip().lower() + is_vector = service == OPENAI_VECTOR_STORE + llm_service_id = collection.llm_service_id + + # Delete the corresponding OpenAI resource (vector store or assistant) + if is_vector: + OpenAIVectorStoreCrud(client).delete(llm_service_id) + else: + OpenAIAssistantCrud(client).delete(llm_service_id) + + with Session(engine) as session: + CollectionCrud(session, project_id).delete_by_id(collection_id) + + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job_crud.update( + collection_job.id, + CollectionJobUpdate( + status=CollectionJobStatus.SUCCESSFUL, + error_message=None, + ), + ) + collection_job = collection_job_crud.read_one(collection_job.id) + + logger.info( + "[delete_collection.execute_job] Collection deleted successfully | " + "{'collection_id': '%s', 'job_id': '%s'}", + str(collection_id), + str(job_uuid), ) + if deletion_request.callback_url and collection_job: + success_payload = build_success_payload( + collection_job=collection_job, + collection_id=collection_id, + ) + send_callback(deletion_request.callback_url, success_payload) - logger.error( - "[delete_collection.execute_job] Unexpected error during deletion | " - "{'collection_id': '%s', 'error': '%s', 'error_type': '%s', 'job_id': '%s'}", - str(collection.id), - str(err), - type(err).__name__, - str(job_id), - exc_info=True, + except Exception as err: + _mark_job_failed_and_callback( + project_id=project_id, + collection_id=collection_id, + job_id=job_uuid, + err=err, + callback_url=deletion_request.callback_url, ) - callback.fail(str(err)) diff --git a/backend/app/services/collections/helpers.py 
b/backend/app/services/collections/helpers.py index 158994c69..5ca1c6759 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -4,20 +4,19 @@ import re from uuid import UUID from typing import List -from dataclasses import asdict, replace -from pydantic import HttpUrl +from sqlmodel import select from openai import OpenAIError -from app.core.util import post_callback from app.crud.document import DocumentCrud -from app.models.collection import ResponsePayload -from app.crud.rag import OpenAIAssistantCrud -from app.utils import APIResponse +from app.models import DocumentCollection, Collection logger = logging.getLogger(__name__) +# llm service name for when only an openai vector store is being made +OPENAI_VECTOR_STORE = "openai vector store" + def extract_error_message(err: Exception) -> str: """Extract a concise, user-facing message from an exception, preferring `error.message` @@ -70,60 +69,36 @@ def batch_documents( return docs_batches -# functions related to callback handling - -class CallbackHandler: - def __init__(self, payload: ResponsePayload): - self.payload = payload - - def fail(self, body): - raise NotImplementedError() - - def success(self, body): - raise NotImplementedError() - - -class SilentCallback(CallbackHandler): - def fail(self, body): - logger.info(f"[SilentCallback.fail] Silent callback failure") - return - - def success(self, body): - logger.info(f"[SilentCallback.success] Silent callback success") - return - - -class WebHookCallback(CallbackHandler): - def __init__(self, url: HttpUrl, payload: ResponsePayload): - super().__init__(payload) - self.url = url - logger.info( - f"[WebHookCallback.init] Initialized webhook callback | {{'url': '{url}'}}" - ) - - def __call__(self, response: APIResponse, status: str): - time = ResponsePayload.now() - payload = replace(self.payload, status=status, time=time) - response.metadata = asdict(payload) - logger.info( - f"[WebHookCallback.call] Posting 
callback | {{'url': '{self.url}', 'status': '{status}'}}"
-        )
-        post_callback(self.url, response)
-
-    def fail(self, body):
-        logger.warning(f"[WebHookCallback.fail] Callback failed | {{'body': '{body}'}}")
-        self(APIResponse.failure_response(body), "incomplete")
-
-    def success(self, body):
-        logger.info(f"[WebHookCallback.success] Callback succeeded")
-        self(APIResponse.success_response(body), "complete")
-
-
-def _backout(crud: OpenAIAssistantCrud, assistant_id: str):
-    """Best-effort cleanup: attempt to delete the assistant by ID"""
+def _backout(crud, llm_service_id: str):
+    """Best-effort cleanup: attempt to delete the remote resource by ID."""
     try:
-        crud.delete(assistant_id)
+        crud.delete(llm_service_id)
     except OpenAIError as err:
         logger.error(
-            f"[backout] Failed to delete assistant | {{'assistant_id': '{assistant_id}', 'error': '{str(err)}'}}",
+            f"[backout] Failed to delete resource | {{'llm_service_id': '{llm_service_id}', 'error': '{str(err)}'}}",
             exc_info=True,
         )
+
+
+# Even though this function is used in the documents router, it's kept here for now: the assistant
+# creation logic will eventually be removed from Kaapi, and once that happens this function can be
+# safely deleted.
+def pick_service_for_documennt(session, doc_id: UUID, a_crud, v_crud):
+    """
+    Return the correct remote CRUD (v_crud or a_crud) for this document
+    by inspecting the llm_service_name of an active linked Collection.
+    Defaults to a_crud when the service is not a vector store.
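The routing rule used here (and in `delete_collection.execute_job`) boils down to normalising `llm_service_name` and comparing it against the vector-store sentinel. A minimal sketch, with `pick_crud` as a hypothetical standalone version of that rule (the crud arguments are opaque stand-ins):

```python
# Sentinel value matching the one defined above.
OPENAI_VECTOR_STORE = "openai vector store"


def pick_crud(llm_service_name, assistant_crud, vector_store_crud):
    # None, empty, or oddly-cased names all fall back to the assistant crud;
    # only an exact (case/whitespace-insensitive) match routes to the vector store.
    service = (llm_service_name or "").strip().lower()
    return vector_store_crud if service == OPENAI_VECTOR_STORE else assistant_crud
```

The `or ""` guard matters: a collection row with a NULL `llm_service_name` should route to the assistant crud rather than raise on `.strip()`.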
+ """ + coll = session.exec( + select(Collection) + .join(DocumentCollection, DocumentCollection.collection_id == Collection.id) + .where( + DocumentCollection.document_id == doc_id, + Collection.deleted_at.is_(None), + ) + .limit(1) + ).first() + + service = ( + (getattr(coll, "llm_service_name", "") or "").strip().lower() if coll else "" + ) + return v_crud if service == OPENAI_VECTOR_STORE else a_crud diff --git a/backend/app/tests/api/routes/collections/test_collection_delete.py b/backend/app/tests/api/routes/collections/test_collection_delete.py new file mode 100644 index 000000000..99e61d671 --- /dev/null +++ b/backend/app/tests/api/routes/collections/test_collection_delete.py @@ -0,0 +1,124 @@ +from uuid import UUID, uuid4 +from unittest.mock import patch + +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from app.core.config import settings +from app.models import CollectionJobStatus +from app.tests.utils.utils import get_project +from app.tests.utils.collection import get_collection + + +@patch("app.api.routes.collections.delete_service.start_job") +def test_delete_collection_calls_start_job_and_returns_job( + mock_start_job, + db: Session, + client: TestClient, + user_api_key_header: dict[str, str], + user_api_key, +): + """ + Happy path: + - Existing collection for the current project + - No callback request body + - Creates a DELETE CollectionJob with PENDING status + - Calls delete_service.start_job with correct arguments + """ + project = get_project(db, "Dalgo") + collection = get_collection(db, project) + + resp = client.request( + "DELETE", + f"{settings.API_V1_STR}/collections/{collection.id}", + headers=user_api_key_header, + ) + + assert resp.status_code == 200 + body = resp.json() + + data = body["data"] + assert data["status"] == CollectionJobStatus.PENDING + assert data["job_inserted_at"] + assert data["job_updated_at"] + + mock_start_job.assert_called_once() + kwargs = mock_start_job.call_args.kwargs + + 
assert "db" in kwargs + assert kwargs["project_id"] == user_api_key.project_id + assert kwargs["organization_id"] == user_api_key.organization_id + + returned_job_id = UUID(data["job_id"]) + assert kwargs["collection_job_id"] == returned_job_id + + deletion_request = kwargs["request"] + assert deletion_request.collection_id == collection.id + assert deletion_request.callback_url is None + + +@patch("app.api.routes.collections.delete_service.start_job") +def test_delete_collection_with_callback_url_passes_it_to_start_job( + mock_start_job, + db: Session, + client: TestClient, + user_api_key_header: dict[str, str], + user_api_key, +): + """ + When a callback_url is provided in the request body, ensure it is passed + into the DeletionRequest and then into delete_service.start_job. + """ + project = get_project(db, "Dalgo") + collection = get_collection(db, project) + + payload = { + "callback_url": "https://example.com/collections/delete-callback", + } + + resp = client.request( + "DELETE", + f"{settings.API_V1_STR}/collections/{collection.id}", + json=payload, + headers=user_api_key_header, + ) + + assert resp.status_code == 200 + body = resp.json() + + data = body["data"] + assert data["status"] == CollectionJobStatus.PENDING + + mock_start_job.assert_called_once() + kwargs = mock_start_job.call_args.kwargs + + assert kwargs["project_id"] == user_api_key.project_id + assert kwargs["organization_id"] == user_api_key.organization_id + assert kwargs["collection_job_id"] == UUID(data["job_id"]) + + deletion_request = kwargs["request"] + assert deletion_request.collection_id == collection.id + assert str(deletion_request.callback_url) == payload["callback_url"] + + +@patch("app.api.routes.collections.delete_service.start_job") +def test_delete_collection_not_found_returns_404_and_does_not_start_job( + mock_start_job, + client: TestClient, + user_api_key_header: dict[str, str], +): + """ + For a random UUID that doesn't correspond to any collection, we expect: + - 404 
response + - delete_service.start_job is NOT called + """ + random_id = uuid4() + + resp = client.request( + "DELETE", + f"{settings.API_V1_STR}/collections/{random_id}", + headers=user_api_key_header, + ) + + assert resp.status_code == 404 + mock_start_job.assert_not_called() diff --git a/backend/app/tests/api/routes/collections/test_collection_info.py b/backend/app/tests/api/routes/collections/test_collection_info.py index 2317ef241..90f8b80c1 100644 --- a/backend/app/tests/api/routes/collections/test_collection_info.py +++ b/backend/app/tests/api/routes/collections/test_collection_info.py @@ -1,138 +1,187 @@ -from uuid import uuid4, UUID +import uuid from typing import Optional from fastapi.testclient import TestClient -from sqlmodel import Session +from sqlmodel import Session, select from app.core.config import settings -from app.core.util import now -from app.models import ( - Collection, - CollectionJobCreate, - CollectionActionType, - CollectionJobStatus, - CollectionJobUpdate, -) -from app.crud import CollectionJobCrud, CollectionCrud - - -def create_collection( +from app.tests.utils.utils import get_project, get_document +from app.tests.utils.collection import get_collection, get_vector_store_collection +from app.crud import DocumentCollectionCrud +from app.models import Collection, Document + + +def link_document_to_collection( db: Session, - user, - with_llm: bool = False, + collection: Collection, + document: Optional[Document] = None, ): - """Create a Collection row (optionally prefilled with LLM service fields).""" - llm_service_id = None - llm_service_name = None - if with_llm: - llm_service_id = f"asst_{uuid4()}" - llm_service_name = "gpt-4o" - - collection = Collection( - id=uuid4(), - organization_id=user.organization_id, - project_id=user.project_id, - llm_service_id=llm_service_id, - llm_service_name=llm_service_name, - ) + """ + Utility used in tests to associate a Document with a Collection so that + DocumentCollectionCrud.read(...) 
will return something. + + If you have not given documents to this function then this uses your `get_document` helper + to provide documents to the DocumentCollectionCrud.create. + """ - return CollectionCrud(db, user.project_id).create(collection) + if document is None: + document = get_document(db) + crud = DocumentCollectionCrud(db) + crud.create(collection, [document]) -def create_collection_job( + return document + + +def test_collection_info_returns_assistant_collection_with_docs( + client: TestClient, db: Session, - user, - collection_id: Optional[UUID] = None, - action_type: CollectionActionType = CollectionActionType.CREATE, - status: CollectionJobStatus = CollectionJobStatus.PENDING, + user_api_key_header, ): - """Create a CollectionJob row (uses create schema for clarity).""" - job_in = CollectionJobCreate( - collection_id=collection_id, - project_id=user.project_id, - action_type=action_type, - status=status, + """ + Happy path: + - Assistant-style collection (get_collection) + - include_docs = True (default) + - At least one document linked + """ + + project = get_project(db, "Dalgo") + collection = get_collection(db, project) + + document = link_document_to_collection(db, collection) + + response = client.get( + f"{settings.API_V1_STR}/collections/{collection.id}", + headers=user_api_key_header, ) - collection_job = CollectionJobCrud(db, user.project_id).create(job_in) - if collection_job.status == CollectionJobStatus.FAILED: - job_in = CollectionJobUpdate( - error_message="Something went wrong during the collection job process." 
- ) - collection_job = CollectionJobCrud(db, user.project_id).update( - collection_job.id, job_in - ) + assert response.status_code == 200 + + data = response.json() - return collection_job + assert data["success"] is True + payload = data["data"] + assert str(collection.id) == payload["id"] + assert payload["project_id"] == project.id -def test_collection_info_processing( - db: Session, client: "TestClient", user_api_key_header, user_api_key + docs = payload.get("documents", []) + assert isinstance(docs, list) + assert len(docs) >= 1 + + doc_ids = {d["id"] for d in docs} + assert str(document.id) in doc_ids + + +def test_collection_info_include_docs_false_returns_no_docs( + client: TestClient, + db: Session, + user_api_key_header, ): - headers = user_api_key_header + """ + When include_docs=false, the endpoint should not populate the documents list. + """ + project = get_project(db, "Dalgo") + collection = get_collection(db, project) - collection_job = create_collection_job(db, user_api_key) + link_document_to_collection(db, collection) response = client.get( - f"{settings.API_V1_STR}/collections/info/jobs/{collection_job.id}", - headers=headers, + f"{settings.API_V1_STR}/collections/{collection.id}", + headers=user_api_key_header, + params={"include_docs": "false"}, ) assert response.status_code == 200 - data = response.json()["data"] - assert data["status"] == CollectionJobStatus.PENDING - assert data["inserted_at"] is not None - assert data["collection_id"] == collection_job.collection_id - assert data["updated_at"] is not None + data = response.json() + payload = data["data"] + assert payload["id"] == str(collection.id) + assert payload["llm_service_name"] == "gpt-4o" + assert payload["llm_service_id"] == collection.llm_service_id + assert payload["documents"] is None -def test_collection_info_successful( - db: Session, client: "TestClient", user_api_key_header, user_api_key + +def test_collection_info_pagination_skip_and_limit( + client: TestClient, + db: 
Session, + user_api_key_header, ): - headers = user_api_key_header + """ + Verify skip & limit are passed through to DocumentCollectionCrud.read. + We create multiple document links and then request a paginated slice. + """ + project = get_project(db, "Dalgo") + collection = get_collection(db, project) - collection = create_collection(db, user_api_key, with_llm=True) - collection_job = create_collection_job( - db, user_api_key, collection.id, status=CollectionJobStatus.SUCCESSFUL - ) + documents = db.exec( + select(Document).where(Document.deleted_at.is_(None)).limit(2) + ).all() + + assert len(documents) >= 2, "Test requires at least two documents in the DB" + + DocumentCollectionCrud(db).create(collection, documents) response = client.get( - f"{settings.API_V1_STR}/collections/info/jobs/{collection_job.id}", - headers=headers, + f"{settings.API_V1_STR}/collections/{collection.id}", + headers=user_api_key_header, + params={"skip": 1, "limit": 1}, ) assert response.status_code == 200 - data = response.json()["data"] - assert data["id"] == str(collection_job.id) - assert data["status"] == CollectionJobStatus.SUCCESSFUL - assert data["action_type"] == CollectionActionType.CREATE - assert data["collection_id"] == str(collection.id) + data = response.json() + payload = data["data"] + docs_resp = payload.get("documents", []) - assert data["collection"] is not None - col = data["collection"] - assert col["id"] == str(collection.id) - assert col["llm_service_id"] == collection.llm_service_id - assert col["llm_service_name"] == "gpt-4o" + assert len(docs_resp) == 1 -def test_collection_info_failed( - db: Session, client: "TestClient", user_api_key_header, user_api_key +def test_collection_info_vector_store_collection( + client: TestClient, + db: Session, + user_api_key_header, ): - headers = user_api_key_header + """ + Ensure the endpoint also works for vector-store-style collections created + via get_vector_store_collection. 
+ """ + project = get_project(db, "Dalgo") + collection = get_vector_store_collection(db, project) - collection_job = create_collection_job( - db, user_api_key, status=CollectionJobStatus.FAILED - ) + link_document_to_collection(db, collection) response = client.get( - f"{settings.API_V1_STR}/collections/info/jobs/{collection_job.id}", - headers=headers, + f"{settings.API_V1_STR}/collections/{collection.id}", + headers=user_api_key_header, ) assert response.status_code == 200 - data = response.json()["data"] - assert data["status"] == CollectionJobStatus.FAILED - assert data["error_message"] is not None + data = response.json() + payload = data["data"] + + assert payload["id"] == str(collection.id) + assert payload["llm_service_name"] == "openai vector store" + assert payload["llm_service_id"] == collection.llm_service_id + + docs = payload.get("documents", []) + assert len(docs) >= 1 + + +def test_collection_info_not_found_returns_404( + client: TestClient, + user_api_key_header, +): + """ + For a random UUID that doesn't correspond to any collection, we expect 404. 
+ """ + random_id = uuid.uuid4() + + response = client.get( + f"{settings.API_V1_STR}/collections/{random_id}", + headers=user_api_key_header, + ) + + assert response.status_code == 404 diff --git a/backend/app/tests/api/routes/collections/test_collection_job_info.py b/backend/app/tests/api/routes/collections/test_collection_job_info.py new file mode 100644 index 000000000..8735b51d1 --- /dev/null +++ b/backend/app/tests/api/routes/collections/test_collection_job_info.py @@ -0,0 +1,158 @@ +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.core.config import settings +from app.tests.utils.utils import get_project +from app.tests.utils.collection import get_collection, get_collection_job +from app.models import ( + CollectionActionType, + CollectionJobStatus, +) + + +def test_collection_info_processing( + db: Session, client: "TestClient", user_api_key_header +): + headers = user_api_key_header + project = get_project(db, "Dalgo") + + collection_job = get_collection_job( + db, project, status=CollectionJobStatus.PROCESSING + ) + + resp = client.get( + f"{settings.API_V1_STR}/collections/jobs/{collection_job.id}", + headers=headers, + ) + assert resp.status_code == 200 + data = resp.json()["data"] + + assert data["job_id"] == str(collection_job.id) + assert data["status"] == CollectionJobStatus.PROCESSING + + assert data.get("collection") is None + + +def test_collection_info_create_successful( + db: Session, client: "TestClient", user_api_key_header +): + headers = user_api_key_header + project = get_project(db, "Dalgo") + + collection = get_collection(db, project) + + collection_job = get_collection_job( + db, project, collection_id=collection.id, status=CollectionJobStatus.SUCCESSFUL + ) + + resp = client.get( + f"{settings.API_V1_STR}/collections/jobs/{collection_job.id}", + headers=headers, + ) + assert resp.status_code == 200 + data = resp.json()["data"] + + assert data["job_id"] == str(collection_job.id) + assert data["status"] 
== CollectionJobStatus.SUCCESSFUL + assert data["action_type"] == CollectionActionType.CREATE + + assert data["collection"] is not None + col = data["collection"] + assert col["id"] == str(collection.id) + assert col["llm_service_id"] == collection.llm_service_id + assert col["llm_service_name"] == "gpt-4o" + + +def test_collection_info_create_failed( + db: Session, client: "TestClient", user_api_key_header +): + headers = user_api_key_header + project = get_project(db, "Dalgo") + + collection_job = get_collection_job( + db, + project, + status=CollectionJobStatus.FAILED, + error_message="something went wrong", + ) + + resp = client.get( + f"{settings.API_V1_STR}/collections/jobs/{collection_job.id}", + headers=headers, + ) + body = resp.json() + assert body["success"] is True + + data = body["data"] + + assert data["job_id"] == str(collection_job.id) + assert data["status"] == CollectionJobStatus.FAILED + assert data["action_type"] == CollectionActionType.CREATE + assert data["error_message"] == "something went wrong" + + assert data["collection"] is None + + +def test_collection_info_delete_successful( + db: Session, client: "TestClient", user_api_key_header +): + headers = user_api_key_header + project = get_project(db, "Dalgo") + + collection = get_collection(db, project) + + collection_job = get_collection_job( + db, + project, + collection_id=collection.id, + action_type=CollectionActionType.DELETE, + status=CollectionJobStatus.SUCCESSFUL, + ) + + resp = client.get( + f"{settings.API_V1_STR}/collections/jobs/{collection_job.id}", + headers=headers, + ) + assert resp.status_code == 200 + data = resp.json()["data"] + + assert data["job_id"] == str(collection_job.id) + assert data["status"] == CollectionJobStatus.SUCCESSFUL + assert data["action_type"] == CollectionActionType.DELETE + + assert data["collection"] is not None + col = data["collection"] + assert col["id"] == str(collection.id) + + +def test_collection_info_delete_failed( + db: Session, client: 
"TestClient", user_api_key_header +): + headers = user_api_key_header + project = get_project(db, "Dalgo") + + collection = get_collection(db, project) + + collection_job = get_collection_job( + db, + project, + collection_id=collection.id, + action_type=CollectionActionType.DELETE, + status=CollectionJobStatus.FAILED, + error_message="something went wrong", + ) + + resp = client.get( + f"{settings.API_V1_STR}/collections/jobs/{collection_job.id}", + headers=headers, + ) + body = resp.json() + assert body["success"] is True + + data = body["data"] + assert data["job_id"] == str(collection_job.id) + assert data["status"] == CollectionJobStatus.FAILED + assert data["action_type"] == CollectionActionType.DELETE + assert data["error_message"] == "something went wrong" + + assert data["collection"] is not None diff --git a/backend/app/tests/api/routes/collections/test_collection_list.py b/backend/app/tests/api/routes/collections/test_collection_list.py new file mode 100644 index 000000000..f7507c129 --- /dev/null +++ b/backend/app/tests/api/routes/collections/test_collection_list.py @@ -0,0 +1,127 @@ +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from app.core.config import settings +from app.tests.utils.utils import get_project +from app.tests.utils.collection import ( + get_collection, + get_vector_store_collection, +) + + +def test_list_collections_returns_api_response( + client: TestClient, + user_api_key_header, +): + """ + Basic sanity check: + - Endpoint returns 200 + - Response is wrapped in APIResponse + - `data` is a list + """ + + response = client.get( + f"{settings.API_V1_STR}/collections/", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + + data = response.json() + assert "success" in data + assert "data" in data + assert isinstance(data["data"], list) + + +def test_list_collections_includes_assistant_collection( + db: Session, + client: TestClient, + user_api_key_header, +): + """ + Ensure that 
a newly created assistant-style collection (get_collection) + appears in the list for the current project. + """ + + project = get_project(db, "Dalgo") + + response_before = client.get( + f"{settings.API_V1_STR}/collections/", + headers=user_api_key_header, + ) + assert response_before.status_code == 200 + + collection = get_collection(db, project) + + response_after = client.get( + f"{settings.API_V1_STR}/collections/", + headers=user_api_key_header, + ) + assert response_after.status_code == 200 + + after_data = response_after.json() + assert after_data["success"] is True + after_payload = after_data["data"] + + assert isinstance(after_payload, list) + + after_ids = {c["id"] for c in after_payload} + assert str(collection.id) in after_ids + + for row in after_payload: + assert row["project_id"] == project.id + + +def test_list_collections_includes_vector_store_collection_with_fields( + db: Session, + client: TestClient, + user_api_key_header, +): + """ + Ensure that vector-store-style collections created via get_vector_store_collection + appear in the list and expose the expected LLM fields. + """ + project = get_project(db, "Dalgo") + collection = get_vector_store_collection(db, project) + + response = client.get( + f"{settings.API_V1_STR}/collections/", + headers=user_api_key_header, + ) + assert response.status_code == 200 + + data = response.json() + assert data["success"] is True + + rows = data["data"] + assert isinstance(rows, list) + + matching = [c for c in rows if c["id"] == str(collection.id)] + assert matching + + row = matching[0] + assert row["project_id"] == project.id + assert row["llm_service_name"] == "openai vector store" + assert row["llm_service_id"] == collection.llm_service_id + + +def test_list_collections_does_not_error_with_no_collections( + db: Session, + client: TestClient, + user_api_key_header, +): + """ + If the project has no collections yet, the endpoint should still return + 200 and an empty list (or at least a list). 
+ This assumes a clean DB or that there may be zero collections initially. + """ + response = client.get( + f"{settings.API_V1_STR}/collections/", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + + data = response.json() + assert data["success"] is True + assert isinstance(data["data"], list) diff --git a/backend/app/tests/api/routes/collections/test_create_collections.py b/backend/app/tests/api/routes/collections/test_create_collections.py index 2b5d786bb..0e6e1c5ba 100644 --- a/backend/app/tests/api/routes/collections/test_create_collections.py +++ b/backend/app/tests/api/routes/collections/test_create_collections.py @@ -1,49 +1,128 @@ -from uuid import UUID +from uuid import UUID, uuid4 from unittest.mock import patch from fastapi.testclient import TestClient -from unittest.mock import patch -from app.models.collection import Collection, CreationRequest +from app.core.config import settings +from app.models import CollectionJobStatus +from app.models.collection import CreationRequest + + +def _extract_metadata(body: dict) -> dict | None: + return body.get("metadata") or body.get("meta") -def test_collection_creation_success( - client: TestClient, user_api_key_header: dict[str, str], user_api_key +@patch("app.api.routes.collections.create_service.start_job") +def test_collection_creation_with_assistant_calls_start_job_and_returns_job( + mock_start_job, + client: TestClient, + user_api_key_header: dict[str, str], + user_api_key, ): - with patch("app.api.routes.collections.create_service.start_job") as mock_job_start: - creation_data = CreationRequest( - model="gpt-4o", - instructions="string", - temperature=0.000001, - documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], - batch_size=1, - callback_url=None, - ) - - resp = client.post( - "/api/v1/collections/create", - json=creation_data.model_dump(mode="json"), - headers=user_api_key_header, - ) - - assert resp.status_code == 200 - body = resp.json() - - data = body["data"] - assert 
isinstance(data, dict) - assert data["action_type"] == "CREATE" - assert data["status"] == "PENDING" - assert data["project_id"] == user_api_key.project_id - assert data["collection_id"] is None - assert data["task_id"] is None - assert "trace_id" in data - assert data["inserted_at"] - assert data["updated_at"] - - job_key = data["id"] - - mock_job_start.assert_called_once() - kwargs = mock_job_start.call_args.kwargs - assert "db" in kwargs - assert kwargs["request"] == creation_data - assert kwargs["collection_job_id"] == UUID(job_key) + creation_data = CreationRequest( + model="gpt-4o", + instructions="string", + temperature=0.000001, + documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], + batch_size=1, + callback_url=None, + ) + + resp = client.post( + f"{settings.API_V1_STR}/collections", + json=creation_data.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert resp.status_code == 200 + body = resp.json() + + data = body["data"] + assert data["status"] == CollectionJobStatus.PENDING + assert data["job_inserted_at"] + assert data["job_updated_at"] + + assert _extract_metadata(body) in (None, {}) + + mock_start_job.assert_called_once() + kwargs = mock_start_job.call_args.kwargs + assert "db" in kwargs + assert kwargs["project_id"] == user_api_key.project_id + assert kwargs["organization_id"] == user_api_key.organization_id + assert kwargs["with_assistant"] is True + + returned_job_id = UUID(data["job_id"]) + assert kwargs["collection_job_id"] == returned_job_id + + assert kwargs["request"].model_dump(mode="json") == creation_data.model_dump( + mode="json" + ) + + +@patch("app.api.routes.collections.create_service.start_job") +def test_collection_creation_vector_only_adds_metadata_and_sets_with_assistant_false( + mock_start_job, + client: TestClient, + user_api_key_header: dict[str, str], + user_api_key, +): + creation_data = CreationRequest( + temperature=0.000001, + documents=[str(uuid4())], + batch_size=1, + callback_url=None, + ) + + 
resp = client.post( + f"{settings.API_V1_STR}/collections", + json=creation_data.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert resp.status_code == 200 + body = resp.json() + + data = body["data"] + assert data["status"] == CollectionJobStatus.PENDING + + meta = _extract_metadata(body) + assert isinstance(meta, dict) + assert "vector store only" in meta.get("note", "").lower() + + mock_start_job.assert_called_once() + kwargs = mock_start_job.call_args.kwargs + assert kwargs["project_id"] == user_api_key.project_id + assert kwargs["organization_id"] == user_api_key.organization_id + assert kwargs["with_assistant"] is False + assert kwargs["collection_job_id"] == UUID(data["job_id"]) + assert kwargs["request"].model_dump(mode="json") == creation_data.model_dump( + mode="json" + ) + + +def test_collection_creation_vector_only_request_validation_error( + client: TestClient, user_api_key_header: dict[str, str] +): + payload = { + "model": "gpt-4o", + "temperature": 0.000001, + "documents": [str(uuid4())], + "batch_size": 1, + "callback_url": None, + } + + resp = client.post( + f"{settings.API_V1_STR}/collections", + json=payload, + headers=user_api_key_header, + ) + + assert resp.status_code == 422 + body = resp.json() + assert body["success"] is False + assert body["data"] is None + assert body["metadata"] is None + assert ( + "To create an Assistant, provide BOTH 'model' and 'instructions'" + in body["error"] + ) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py index 925f595e8..fc52cd086 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_create.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_create.py @@ -1,11 +1,12 @@ +from uuid import uuid4 + import openai_responses from sqlmodel import Session, select from app.crud import CollectionCrud -from app.models import 
DocumentCollection +from app.models import DocumentCollection, Collection from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection class TestCollectionCreate: @@ -14,7 +15,14 @@ class TestCollectionCreate: @openai_responses.mock() def test_create_associates_documents(self, db: Session): project = get_project(db) - collection = get_collection(db, project_id=project.id) + collection = Collection( + id=uuid4(), + project_id=project.id, + organization_id=project.organization_id, + llm_service_id="asst_dummy", + llm_service_name="gpt-4o", + ) + store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(self._n_documents) diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py index e151a1c6a..a2668b19c 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_delete.py @@ -1,15 +1,37 @@ import pytest + import openai_responses from openai import OpenAI from sqlmodel import Session, select from app.core.config import settings from app.crud import CollectionCrud -from app.models import APIKey +from app.models import APIKey, Collection from app.crud.rag import OpenAIAssistantCrud from app.tests.utils.utils import get_project from app.tests.utils.document import DocumentStore -from app.tests.utils.collection import get_collection, uuid_increment + + +def get_collection_for_delete( + db: Session, client=None, project_id: int = None +) -> Collection: + project = get_project(db) + if client is None: + client = OpenAI(api_key="test_api_key") + + vector_store = client.vector_stores.create() + assistant = client.beta.assistants.create( + model="gpt-4o", + tools=[{"type": "file_search"}], + tool_resources={"file_search": {"vector_store_ids": 
[vector_store.id]}}, + ) + + return Collection( + organization_id=project.organization_id, + project_id=project_id, + llm_service_id=assistant.id, + llm_service_name="gpt-4o", + ) class TestCollectionDelete: @@ -21,7 +43,7 @@ def test_delete_marks_deleted(self, db: Session): client = OpenAI(api_key="sk-test-key") assistant = OpenAIAssistantCrud(client) - collection = get_collection(db, client, project_id=project.id) + collection = get_collection_for_delete(db, client, project_id=project.id) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) @@ -34,26 +56,13 @@ def test_delete_follows_insert(self, db: Session): assistant = OpenAIAssistantCrud(client) project = get_project(db) - collection = get_collection(db, project_id=project.id) + collection = get_collection_for_delete(db, project_id=project.id) crud = CollectionCrud(db, collection.project_id) collection_ = crud.delete(collection, assistant) assert collection_.inserted_at <= collection_.deleted_at - @openai_responses.mock() - def test_cannot_delete_others_collections(self, db: Session): - client = OpenAI(api_key="sk-test-key") - - assistant = OpenAIAssistantCrud(client) - project = get_project(db) - collection = get_collection(db, project_id=project.id) - c_id = uuid_increment(collection.id) - - crud = CollectionCrud(db, c_id) - with pytest.raises(PermissionError): - crud.delete(collection, assistant) - @openai_responses.mock() def test_delete_document_deletes_collections(self, db: Session): project = get_project(db) @@ -68,7 +77,7 @@ def test_delete_document_deletes_collections(self, db: Session): client = OpenAI(api_key="sk-test-key") resources = [] for _ in range(self._n_collections): - coll = get_collection(db, client, project_id=project.id) + coll = get_collection_for_delete(db, client, project_id=project.id) crud = CollectionCrud(db, project_id=project.id) collection = crud.create(coll, documents) resources.append((crud, collection)) diff --git 
a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py index d1f329a2a..a9da35235 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_all.py @@ -1,4 +1,5 @@ import pytest + from openai_responses import OpenAIMock from openai import OpenAI from sqlmodel import Session @@ -17,7 +18,7 @@ def create_collections(db: Session, n: int): with openai_mock.router: client = OpenAI(api_key="sk-test-key") for _ in range(n): - collection = get_collection(db, client, project_id=project.id) + collection = get_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) if crud is None: diff --git a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py index acf7d39ad..ceb46c1a3 100644 --- a/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py +++ b/backend/app/tests/crud/collections/collection/test_crud_collection_read_one.py @@ -6,10 +6,9 @@ from sqlmodel import Session from app.crud import CollectionCrud -from app.core.config import settings from app.tests.utils.document import DocumentStore from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection, uuid_increment +from app.tests.utils.collection import get_collection def mk_collection(db: Session): @@ -17,7 +16,7 @@ def mk_collection(db: Session): project = get_project(db) with openai_mock.router: client = OpenAI(api_key="sk-test-key") - collection = get_collection(db, client, project_id=project.id) + collection = get_collection(db, project=project) store = DocumentStore(db, project_id=collection.project_id) documents = store.fill(1) crud = CollectionCrud(db, 
collection.project_id) diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index 430e7b4be..9d5e7e97d 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -1,8 +1,9 @@ import os import pytest from pathlib import Path -from unittest.mock import patch +from unittest.mock import patch, MagicMock from urllib.parse import urlparse +import uuid from uuid import UUID, uuid4 from moto import mock_aws @@ -12,10 +13,11 @@ from app.core.config import settings from app.crud import CollectionCrud, CollectionJobCrud, DocumentCollectionCrud from app.models import CollectionJobStatus, CollectionJob, CollectionActionType -from app.models.collection import CreationRequest, ResponsePayload +from app.models.collection import CreationRequest from app.services.collections.create_collection import start_job, execute_job from app.tests.utils.openai import get_mock_openai_client_with_vector_store from app.tests.utils.utils import get_project +from app.tests.utils.collection import get_collection_job, get_collection from app.tests.utils.document import DocumentStore @@ -61,11 +63,16 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): batch_size=1, callback_url=None, ) - route = "/collections/create" - payload = ResponsePayload(status="processing", route=route) job_id = uuid4() - _ = create_collection_job_for_create(db, project, job_id) + _ = get_collection_job( + db, + project, + job_id=job_id, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, + ) with patch( "app.services.collections.create_collection.start_low_priority_job" @@ -76,8 +83,8 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): db=db, request=request, project_id=project.id, - payload=payload, collection_job_id=job_id, + 
with_assistant=True, organization_id=project.organization_id, ) @@ -102,10 +109,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session): assert kwargs["project_id"] == project.id assert kwargs["organization_id"] == project.organization_id assert kwargs["job_id"] == str(job_id) - assert kwargs["request"] == request.model_dump() - - passed_payload = kwargs.get("payload", kwargs.get("payload_data")) - assert passed_payload == payload.model_dump() + assert kwargs["request"] == request.model_dump(mode="json") @pytest.mark.usefixtures("aws_credentials") @@ -141,20 +145,18 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( batch_size=1, callback_url=None, ) - sample_payload = ResponsePayload(status="pending", route="/test/route") mock_client = get_mock_openai_client_with_vector_store() mock_get_openai_client.return_value = mock_client job_id = uuid4() - job_crud = CollectionJobCrud(db, project.id) - job_crud.create( - CollectionJob( - id=job_id, - project_id=project.id, - status=CollectionJobStatus.PENDING, - action_type=CollectionActionType.CREATE.value, - ) + _ = get_collection_job( + db, + project, + job_id=job_id, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, ) task_id = uuid4() @@ -165,10 +167,10 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( execute_job( request=sample_request.model_dump(), - payload=sample_payload.model_dump(), project_id=project.id, organization_id=project.organization_id, task_id=str(task_id), + with_assistant=True, job_id=str(job_id), task_instance=None, ) @@ -188,3 +190,306 @@ def test_execute_job_success_flow_updates_job_and_creates_collection( docs = DocumentCollectionCrud(db).read(created_collection, skip=0, limit=10) assert len(docs) == 1 assert docs[0].fname == document.fname + + +@pytest.mark.usefixtures("aws_credentials") +@mock_aws +@patch("app.services.collections.create_collection.get_openai_client") +def 
test_execute_job_assistant_create_failure_marks_failed_and_deletes_vector( + mock_get_openai_client, db +): + project = get_project(db) + + job = get_collection_job( + db, + project, + job_id=uuid4(), + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, + ) + + req = CreationRequest( + model="gpt-4o", + instructions="string", + temperature=0.0, + documents=[], + batch_size=1, + callback_url=None, + ) + + _ = mock_get_openai_client.return_value + + with patch( + "app.services.collections.create_collection.Session" + ) as SessionCtor, patch( + "app.services.collections.create_collection.OpenAIVectorStoreCrud" + ) as MockVS, patch( + "app.services.collections.create_collection.OpenAIAssistantCrud" + ) as MockAsst: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + MockVS.return_value.create.return_value = type( + "Vector store", (), {"id": "vs_123"} + )() + MockVS.return_value.update.return_value = [] + + MockAsst.return_value.create.side_effect = RuntimeError("assistant boom") + + task_id = str(uuid4()) + execute_job( + request=req.model_dump(), + project_id=project.id, + organization_id=project.organization_id, + task_id=task_id, + with_assistant=True, + job_id=str(job.id), + task_instance=None, + ) + + failed = CollectionJobCrud(db, project.id).read_one(job.id) + assert failed.task_id == task_id + assert failed.status == CollectionJobStatus.FAILED + assert "assistant boom" in (failed.error_message or "") + + MockVS.return_value.delete.assert_called_once_with("vs_123") + + +@pytest.mark.usefixtures("aws_credentials") +@mock_aws +@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.send_callback") +def test_execute_job_success_flow_callback_job_and_creates_collection( + mock_send_callback, + mock_get_openai_client, + db, +): + """ + execute_job should: + - set task_id on the 
CollectionJob + - ingest documents into a vector store + - create an OpenAI assistant + - create a Collection with llm fields filled + - link the CollectionJob -> collection_id, set status=successful + - create DocumentCollection links + """ + project = get_project(db) + + aws = AmazonCloudStorageClient() + aws.create() + + store = DocumentStore(db=db, project_id=project.id) + document = store.put() + s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") + aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + + callback_url = "https://example.com/collections/create-success" + + sample_request = CreationRequest( + model="gpt-4o", + instructions="string", + temperature=0.000001, + documents=[document.id], + batch_size=1, + callback_url=callback_url, + ) + + mock_client = get_mock_openai_client_with_vector_store() + mock_get_openai_client.return_value = mock_client + + job_id = uuid.uuid4() + _ = get_collection_job( + db, + project, + job_id=job_id, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, + ) + + task_id = uuid.uuid4() + + with patch("app.services.collections.create_collection.Session") as SessionCtor: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + mock_send_callback.return_value = MagicMock(status_code=403) + + execute_job( + request=sample_request.model_dump(), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(task_id), + with_assistant=True, + job_id=str(job_id), + task_instance=None, + ) + + updated_job = CollectionJobCrud(db, project.id).read_one(job_id) + collection = CollectionCrud(db, project.id).read_one(updated_job.collection_id) + + mock_send_callback.assert_called_once() + cb_url_arg, payload_arg = mock_send_callback.call_args.args + assert str(cb_url_arg) == callback_url + assert payload_arg["success"] is True + assert 
payload_arg["data"]["status"] == CollectionJobStatus.SUCCESSFUL + assert payload_arg["data"]["collection"]["id"] == str(collection.id) + assert uuid.UUID(payload_arg["data"]["job_id"]) == job_id + + +@pytest.mark.usefixtures("aws_credentials") +@mock_aws +@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.send_callback") +def test_execute_job_success_creates_collection_with_callback( + mock_send_callback, + mock_get_openai_client, + db, +): + """ + execute_job should: + - set task_id on the CollectionJob + - ingest documents into a vector store + - create an OpenAI assistant + - create a Collection with llm fields filled + - link the CollectionJob -> collection_id, set status=successful + - create DocumentCollection links + """ + project = get_project(db) + + aws = AmazonCloudStorageClient() + aws.create() + + store = DocumentStore(db=db, project_id=project.id) + document = store.put() + s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") + aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + + callback_url = "https://example.com/collections/create-success" + + sample_request = CreationRequest( + model="gpt-4o", + instructions="string", + temperature=0.000001, + documents=[document.id], + batch_size=1, + callback_url=callback_url, + ) + + mock_client = get_mock_openai_client_with_vector_store() + mock_get_openai_client.return_value = mock_client + + job_id = uuid.uuid4() + _ = get_collection_job( + db, + project, + job_id=job_id, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, + ) + + task_id = uuid.uuid4() + + with patch("app.services.collections.create_collection.Session") as SessionCtor: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + mock_send_callback.return_value = MagicMock(status_code=403) + + execute_job( + 
request=sample_request.model_dump(), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(task_id), + with_assistant=True, + job_id=str(job_id), + task_instance=None, + ) + + updated_job = CollectionJobCrud(db, project.id).read_one(job_id) + collection = CollectionCrud(db, project.id).read_one(updated_job.collection_id) + + mock_send_callback.assert_called_once() + cb_url_arg, payload_arg = mock_send_callback.call_args.args + assert str(cb_url_arg) == callback_url + assert payload_arg["success"] is True + assert payload_arg["data"]["status"] == CollectionJobStatus.SUCCESSFUL + assert payload_arg["data"]["collection"]["id"] == str(collection.id) + assert uuid.UUID(payload_arg["data"]["job_id"]) == job_id + + +@pytest.mark.usefixtures("aws_credentials") +@mock_aws +@patch("app.services.collections.create_collection.get_openai_client") +@patch("app.services.collections.create_collection.send_callback") +@patch("app.services.collections.create_collection.CollectionCrud") +def test_execute_job_failure_flow_callback_job_and_marks_failed( + MockCollectionCrud, + mock_send_callback, + mock_get_openai_client, + db: Session, +): + """ + When creation fails, the job should be marked as FAILED, an error should be logged, + and a failure callback with the error message should be triggered. 
+    """
+    project = get_project(db)
+
+    collection = get_collection(db, project, assistant_id="asst_123")
+    job = get_collection_job(
+        db,
+        project,
+        action_type=CollectionActionType.CREATE,
+        status=CollectionJobStatus.PENDING,
+        collection_id=None,
+    )
+
+    mock_get_openai_client.return_value = MagicMock()
+
+    callback_url = "https://example.com/collections/create-failure"
+
+    collection_crud_instance = MockCollectionCrud.return_value
+    collection_crud_instance.read_one.return_value = collection
+
+    sample_request = CreationRequest(
+        model="gpt-4o",
+        instructions="string",
+        temperature=0.000001,
+        documents=[uuid.uuid4()],
+        batch_size=1,
+        callback_url=callback_url,
+    )
+
+    task_id = uuid.uuid4()
+
+    with patch("app.services.collections.create_collection.Session") as SessionCtor:
+        SessionCtor.return_value.__enter__.return_value = db
+        SessionCtor.return_value.__exit__.return_value = False
+
+        execute_job(
+            request=sample_request.model_dump(),
+            project_id=project.id,
+            organization_id=project.organization_id,
+            task_id=str(task_id),
+            with_assistant=True,
+            job_id=str(job.id),
+            task_instance=None,
+        )
+
+    updated_job = CollectionJobCrud(db, project.id).read_one(job.id)
+
+    assert updated_job.status == CollectionJobStatus.FAILED
+    assert "Requested atleast 1 document retrieved 0" in (
+        updated_job.error_message or ""
+    )
+
+    mock_send_callback.assert_called_once()
+    cb_url_arg, payload_arg = mock_send_callback.call_args.args
+    assert str(cb_url_arg) == callback_url
+    assert payload_arg["success"] is False
+    assert "Requested atleast 1 document retrieved 0" in (payload_arg["error"] or "")
+    assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED
+    assert payload_arg["data"]["collection"] is None
+    assert uuid.UUID(payload_arg["data"]["job_id"]) == job.id
diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py
index f6f55c6ad..07e2af085 100644
--- a/backend/app/tests/services/collections/test_delete_collection.py
+++ b/backend/app/tests/services/collections/test_delete_collection.py
@@ -1,63 +1,28 @@
 from unittest.mock import patch, MagicMock
 from uuid import uuid4, UUID
-from sqlmodel import Session
 from sqlalchemy.exc import SQLAlchemyError
 
 from app.models.collection import (
     DeletionRequest,
-    Collection,
-    ResponsePayload,
 )
 from app.tests.utils.utils import get_project
-from app.crud import CollectionCrud, CollectionJobCrud
-from app.models import CollectionJobStatus, CollectionJob, CollectionActionType
+from app.crud import CollectionJobCrud
+from app.models import CollectionJobStatus, CollectionActionType
+from app.tests.utils.collection import get_collection, get_collection_job
 from app.services.collections.delete_collection import start_job, execute_job
 
 
-def create_collection(db: Session, project):
-    collection = Collection(
-        id=uuid4(),
-        project_id=project.id,
-        organization_id=project.organization_id,
-        llm_service_id="asst-nasjnl",
-        llm_service_name="gpt-4o",
-    )
-    return CollectionCrud(db, project.id).create(collection)
-
-
-def create_collection_job(
-    db: Session,
-    project,
-    collection,
-    job_id: UUID | None = None,
-):
-    if job_id is None:
-        job_id = uuid4()
-    job_crud = CollectionJobCrud(db, project.id)
-    return job_crud.create(
-        CollectionJob(
-            id=job_id,
-            action_type=CollectionActionType.DELETE,
-            project_id=project.id,
-            collection_id=collection.id,
-            status=CollectionJobStatus.PENDING,
-        )
-    )
-
-
-def test_start_job_creates_collection_job_and_schedules_task(db: Session):
+def test_start_job_creates_collection_job_and_schedules_task(db):
     """
-    - start_job should update an existing CollectionJob (status=processing, action=delete)
+    - start_job should update an existing CollectionJob (status=PENDING, action=DELETE)
     - schedule the task with the provided job_id and collection_id
-    - return the same job_id (string)
+    - return the same job_id (UUID)
     """
     project = get_project(db)
-    created_collection = create_collection(db, project)
+    created_collection = get_collection(db, project)
 
     req = DeletionRequest(collection_id=created_collection.id)
-    route = "/collections/delete"
-    payload = ResponsePayload(status="processing", route=route)
 
     with patch(
         "app.services.collections.delete_collection.start_low_priority_job"
@@ -65,20 +30,20 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session):
         mock_schedule.return_value = "fake-task-id"
 
         collection_job_id = uuid4()
-        precreated = create_collection_job(
-            db=db,
-            project=project,
-            collection=created_collection,
+        _ = get_collection_job(
+            db,
+            project,
             job_id=collection_job_id,
+            action_type=CollectionActionType.DELETE,
+            status=CollectionJobStatus.PENDING,
+            collection_id=created_collection.id,
         )
 
         returned = start_job(
             db=db,
             request=req,
-            collection=created_collection,
             project_id=project.id,
             collection_job_id=collection_job_id,
-            payload=payload,
             organization_id=project.organization_id,
         )
@@ -103,25 +68,30 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session):
         assert kwargs["organization_id"] == project.organization_id
         assert kwargs["job_id"] == str(job.id)
         assert kwargs["collection_id"] == str(created_collection.id)
-        assert kwargs["request"] == req.model_dump()
-        assert kwargs["payload"] == payload.model_dump()
+        assert kwargs["request"] == req.model_dump(mode="json")
         assert "trace_id" in kwargs
 
 
 @patch("app.services.collections.delete_collection.get_openai_client")
 def test_execute_job_delete_success_updates_job_and_calls_delete(
-    mock_get_openai_client, db: Session
+    mock_get_openai_client, db
 ):
     """
     - execute_job should set task_id on the CollectionJob
-    - call CollectionCrud.delete(collection, assistant_crud)
+    - call remote delete via OpenAIAssistantCrud.delete(...)
+    - delete local record via CollectionCrud.delete_by_id(...)
     - mark job successful and clear error_message
     """
     project = get_project(db)
-    collection = create_collection(db, project)
-
-    job = create_collection_job(db, project, collection)
+    collection = get_collection(db, project, assistant_id="asst_123")
+    job = get_collection_job(
+        db,
+        project,
+        action_type=CollectionActionType.DELETE,
+        status=CollectionJobStatus.PENDING,
+        collection_id=collection.id,
+    )
 
     mock_get_openai_client.return_value = MagicMock()
@@ -138,25 +108,18 @@ def test_execute_job_delete_success_updates_job_and_calls_delete(
         collection_crud_instance = MockCollectionCrud.return_value
         collection_crud_instance.read_one.return_value = collection
 
-        deletion_result = MagicMock()
-        deletion_result.model_dump.return_value = {
-            "id": str(collection.id),
-            "deleted": True,
-        }
-        collection_crud_instance.delete.return_value = deletion_result
+        MockAssistantCrud.return_value.delete.return_value = None
 
         task_id = uuid4()
         req = DeletionRequest(collection_id=collection.id)
-        payload = ResponsePayload(status="processing", route="/test/delete")
 
         execute_job(
-            request=req.model_dump(),
-            payload=payload.model_dump(),
+            request=req.model_dump(mode="json"),
             project_id=project.id,
             organization_id=project.organization_id,
             task_id=str(task_id),
             job_id=str(job.id),
-            collection_id=collection.id,
+            collection_id=str(collection.id),
             task_instance=None,
         )
@@ -167,26 +130,30 @@ def test_execute_job_delete_success_updates_job_and_calls_delete(
     MockCollectionCrud.assert_called_with(db, project.id)
     collection_crud_instance.read_one.assert_called_once_with(collection.id)
-    collection_crud_instance.delete.assert_called_once()
-    args, kwargs = collection_crud_instance.delete.call_args
-    assert isinstance(args[0], Collection)
+    MockAssistantCrud.assert_called_once()
+    MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123")
+
+    collection_crud_instance.delete_by_id.assert_called_once_with(collection.id)
 
     mock_get_openai_client.assert_called_once()
 
 
 @patch("app.services.collections.delete_collection.get_openai_client")
-def test_execute_job_delete_failure_marks_job_failed(
-    mock_get_openai_client, db: Session
-):
+def test_execute_job_delete_failure_marks_job_failed(mock_get_openai_client, db):
     """
-    When CollectionCrud.delete raises (e.g., SQLAlchemyError),
-    the job should be marked failed and error_message set.
+    When the remote delete (OpenAIAssistantCrud.delete) raises,
+    the job should be marked FAILED and error_message set.
     """
     project = get_project(db)
-    collection = create_collection(db, project)
-
-    job = create_collection_job(db, project, collection)
+    collection = get_collection(db, project, assistant_id="asst_123")
+    job = get_collection_job(
+        db,
+        project,
+        action_type=CollectionActionType.DELETE,
+        status=CollectionJobStatus.PENDING,
+        collection_id=collection.id,
+    )
 
     mock_get_openai_client.return_value = MagicMock()
@@ -202,15 +169,16 @@ def test_execute_job_delete_failure_marks_job_failed(
         collection_crud_instance = MockCollectionCrud.return_value
         collection_crud_instance.read_one.return_value = collection
-        collection_crud_instance.delete.side_effect = SQLAlchemyError("boom")
+
+        MockAssistantCrud.return_value.delete.side_effect = SQLAlchemyError(
+            "something went wrong"
+        )
 
         task_id = uuid4()
         req = DeletionRequest(collection_id=collection.id)
-        payload = ResponsePayload(status="processing", route="/test/delete")
 
         execute_job(
-            request=req.model_dump(),
-            payload=payload.model_dump(),
+            request=req.model_dump(mode="json"),
             project_id=project.id,
             organization_id=project.organization_id,
             task_id=str(task_id),
@@ -222,7 +190,183 @@ def test_execute_job_delete_failure_marks_job_failed(
     failed_job = CollectionJobCrud(db, project.id).read_one(job.id)
     assert failed_job.task_id == str(task_id)
     assert failed_job.status == CollectionJobStatus.FAILED
-    assert failed_job.error_message and "boom" in failed_job.error_message
+    assert (
+        failed_job.error_message
+        and "something went wrong" in failed_job.error_message
+    )
+
+    MockCollectionCrud.assert_called_with(db, project.id)
+    collection_crud_instance.read_one.assert_called_once_with(collection.id)
 
     MockAssistantCrud.assert_called_once()
+    MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123")
+
+    collection_crud_instance.delete_by_id.assert_not_called()
+    mock_get_openai_client.assert_called_once()
+
+
+@patch("app.services.collections.delete_collection.get_openai_client")
+def test_execute_job_delete_success_with_callback_sends_success_payload(
+    mock_get_openai_client,
+    db,
+):
+    """
+    When deletion succeeds and a callback_url is provided:
+    - job is marked SUCCESSFUL
+    - send_callback is called once
+    - success payload has success=True, status=SUCCESSFUL, and correct collection id
+    """
+    project = get_project(db)
+
+    collection = get_collection(db, project, assistant_id="asst_123")
+    job = get_collection_job(
+        db,
+        project,
+        action_type=CollectionActionType.DELETE,
+        status=CollectionJobStatus.PENDING,
+        collection_id=collection.id,
+    )
+
+    mock_get_openai_client.return_value = MagicMock()
+
+    callback_url = "https://example.com/collections/delete-success"
+
+    with patch(
+        "app.services.collections.delete_collection.Session"
+    ) as SessionCtor, patch(
+        "app.services.collections.delete_collection.OpenAIAssistantCrud"
+    ) as MockAssistantCrud, patch(
+        "app.services.collections.delete_collection.CollectionCrud"
+    ) as MockCollectionCrud, patch(
+        "app.services.collections.delete_collection.send_callback"
+    ) as mock_send_callback:
+        SessionCtor.return_value.__enter__.return_value = db
+        SessionCtor.return_value.__exit__.return_value = False
+
+        collection_crud_instance = MockCollectionCrud.return_value
+        collection_crud_instance.read_one.return_value = collection
+
+        MockAssistantCrud.return_value.delete.return_value = None
+
+        task_id = uuid4()
+        req = DeletionRequest(collection_id=collection.id, callback_url=callback_url)
+
+        from app.services.collections.delete_collection import execute_job
+
+        execute_job(
+            request=req.model_dump(mode="json"),
+            project_id=project.id,
+            organization_id=project.organization_id,
+            task_id=str(task_id),
+            job_id=str(job.id),
+            collection_id=str(collection.id),
+            task_instance=None,
+        )
+
+    updated_job = CollectionJobCrud(db, project.id).read_one(job.id)
+    assert updated_job.task_id == str(task_id)
+    assert updated_job.status == CollectionJobStatus.SUCCESSFUL
+    assert updated_job.error_message in (None, "")
+    MockCollectionCrud.assert_called_with(db, project.id)
+    collection_crud_instance.read_one.assert_called_once_with(collection.id)
+    MockAssistantCrud.assert_called_once()
+    MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123")
+    collection_crud_instance.delete_by_id.assert_called_once_with(collection.id)
+    mock_get_openai_client.assert_called_once()
+
+    mock_send_callback.assert_called_once()
+    cb_url_arg, payload_arg = mock_send_callback.call_args.args
+
+    assert str(cb_url_arg) == callback_url
+    assert payload_arg["success"] is True
+    assert payload_arg["data"]["status"] == CollectionJobStatus.SUCCESSFUL
+    assert payload_arg["data"]["collection"]["id"] == str(collection.id)
+    assert UUID(payload_arg["data"]["job_id"]) == job.id
+
+
+@patch("app.services.collections.delete_collection.get_openai_client")
+def test_execute_job_delete_remote_failure_with_callback_sends_failure_payload(
+    mock_get_openai_client,
+    db,
+):
+    """
+    When the remote delete raises AND a callback_url is provided:
+    - job is marked FAILED with error_message set
+    - send_callback is called once
+    - failure payload has success=False, status=FAILED, correct collection id, and error message
+    """
+    project = get_project(db)
+
+    collection = get_collection(db, project, assistant_id="asst_123")
+    job = get_collection_job(
+        db,
+        project,
+        action_type=CollectionActionType.DELETE,
+        status=CollectionJobStatus.PENDING,
+        collection_id=collection.id,
+    )
+
+    mock_get_openai_client.return_value = MagicMock()
+    callback_url = "https://example.com/collections/delete-failed"
+
+    with patch(
+        "app.services.collections.delete_collection.Session"
+    ) as SessionCtor, patch(
+        "app.services.collections.delete_collection.OpenAIAssistantCrud"
+    ) as MockAssistantCrud, patch(
+        "app.services.collections.delete_collection.CollectionCrud"
+    ) as MockCollectionCrud, patch(
+        "app.services.collections.delete_collection.send_callback"
+    ) as mock_send_callback:
+        SessionCtor.return_value.__enter__.return_value = db
+        SessionCtor.return_value.__exit__.return_value = False
+
+        collection_crud_instance = MockCollectionCrud.return_value
+        collection_crud_instance.read_one.return_value = collection
+
+        MockAssistantCrud.return_value.delete.side_effect = SQLAlchemyError(
+            "something went wrong"
+        )
+
+        task_id = uuid4()
+        req = DeletionRequest(collection_id=collection.id, callback_url=callback_url)
+
+        from app.services.collections.delete_collection import execute_job
+
+        execute_job(
+            request=req.model_dump(mode="json"),
+            project_id=project.id,
+            organization_id=project.organization_id,
+            task_id=str(task_id),
+            job_id=str(job.id),
+            collection_id=str(collection.id),
+            task_instance=None,
+        )
+
+    failed_job = CollectionJobCrud(db, project.id).read_one(job.id)
+    assert failed_job.task_id == str(task_id)
+    assert failed_job.status == CollectionJobStatus.FAILED
+    assert (
+        failed_job.error_message
+        and "something went wrong" in failed_job.error_message
+    )
+
+    MockCollectionCrud.assert_called_with(db, project.id)
+    collection_crud_instance.read_one.assert_called_once_with(collection.id)
+
+    MockAssistantCrud.assert_called_once()
+    MockAssistantCrud.return_value.delete.assert_called_once_with("asst_123")
+
+    collection_crud_instance.delete_by_id.assert_not_called()
+    mock_get_openai_client.assert_called_once()
+
+    mock_send_callback.assert_called_once()
+    cb_url_arg, payload_arg = mock_send_callback.call_args.args
+
+    assert str(cb_url_arg) == callback_url
+    assert payload_arg["success"] is False
+    assert "something went wrong" in (payload_arg["error"] or "")
+    assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED
+    assert payload_arg["data"]["collection"]["id"] == str(collection.id)
+    assert UUID(payload_arg["data"]["job_id"]) == job.id
diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py
new file mode 100644
index 000000000..8bb04f02b
--- /dev/null
+++ b/backend/app/tests/services/collections/test_helpers.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+import json
+from types import SimpleNamespace
+from uuid import uuid4
+
+from app.services.collections import helpers
+
+
+def test_extract_error_message_parses_json_and_strips_prefix():
+    payload = {"error": {"message": "Inner JSON message"}}
+    err = Exception(f"Error code: 400 - {json.dumps(payload)}")
+    msg = helpers.extract_error_message(err)
+    assert msg == "Inner JSON message"
+
+
+def test_extract_error_message_parses_python_dict_repr():
+    payload = {"error": {"message": "Dict-repr message"}}
+    err = Exception(str(payload))
+    msg = helpers.extract_error_message(err)
+    assert msg == "Dict-repr message"
+
+
+def test_extract_error_message_falls_back_to_clean_text_and_truncates():
+    long_text = "x" * 1500
+    err = Exception(long_text)
+    msg = helpers.extract_error_message(err)
+    assert len(msg) == 1000
+    assert msg == long_text[:1000]
+
+
+def test_extract_error_message_handles_non_matching_bodies():
+    err = Exception("some random error without structure")
+    msg = helpers.extract_error_message(err)
+    assert msg == "some random error without structure"
+
+
+# batch documents
+
+
+class FakeDocumentCrud:
+    def __init__(self):
+        self.calls = []
+
+    def read_each(self, ids):
+        self.calls.append(list(ids))
+        return [
+            SimpleNamespace(
+                id=i, fname=f"{i}.txt", object_store_url=f"s3://bucket/{i}.txt"
+            )
+            for i in ids
+        ]
+
+
+def test_batch_documents_even_chunks():
+    crud = FakeDocumentCrud()
+    ids = [uuid4() for _ in range(6)]
+    batches = helpers.batch_documents(crud, ids, batch_size=3)
+
+    # read_each called with chunks [0:3], [3:6]
+    assert crud.calls == [ids[0:3], ids[3:6]]
+    # output mirrors what read_each returned
+    assert len(batches) == 2
+    assert [d.id for d in batches[0]] == ids[0:3]
+    assert [d.id for d in batches[1]] == ids[3:6]
+
+
+def test_batch_documents_ragged_last_chunk():
+    crud = FakeDocumentCrud()
+    ids = [uuid4() for _ in range(5)]
+    batches = helpers.batch_documents(crud, ids, batch_size=2)
+
+    assert crud.calls == [ids[0:2], ids[2:4], ids[4:5]]
+    assert [d.id for d in batches[0]] == ids[0:2]
+    assert [d.id for d in batches[1]] == ids[2:4]
+    assert [d.id for d in batches[2]] == ids[4:5]
+
+
+def test_batch_documents_empty_input():
+    crud = FakeDocumentCrud()
+    batches = helpers.batch_documents(crud, [], batch_size=3)
+    assert batches == []
+    assert crud.calls == []
+
+
+# _backout
+
+
+def test_backout_calls_delete_and_swallows_openai_error(monkeypatch):
+    class Crud:
+        def __init__(self):
+            self.calls = 0
+
+        def delete(self, resource_id: str):
+            self.calls += 1
+
+    crud = Crud()
+    helpers._backout(crud, "rsrc_1")
+    assert crud.calls == 1
+
+    class DummyOpenAIError(Exception):
+        pass
+
+    monkeypatch.setattr(helpers, "OpenAIError", DummyOpenAIError)
+
+    class FailingCrud:
+        def delete(self, resource_id: str):
+            raise DummyOpenAIError("nope")
+
+    helpers._backout(FailingCrud(), "rsrc_2")
diff --git a/backend/app/tests/utils/collection.py b/backend/app/tests/utils/collection.py
index 36e1ecf07..429bfc8b3 100644
--- a/backend/app/tests/utils/collection.py
+++ b/backend/app/tests/utils/collection.py
@@ -1,14 +1,15 @@
-from uuid import UUID
-from uuid import uuid4
+from uuid import UUID, uuid4
+from typing import Optional
 
-from openai import OpenAI
 from sqlmodel import Session
 
-from app.core.config import settings
-from app.models import Collection, Organization, Project
-from app.tests.utils.utils import get_user_id_by_email, get_project
-from app.tests.utils.test_data import create_test_project
-from app.tests.utils.test_data import create_test_api_key
+from app.models import (
+    Collection,
+    CollectionActionType,
+    CollectionJob,
+    CollectionJobStatus,
+)
+from app.crud import CollectionCrud, CollectionJobCrud
 
 
 class constants:
@@ -17,25 +18,78 @@ class constants:
 
 
 def uuid_increment(value: UUID):
-    inc = int(value) + 1  # hopefully doesn't overflow!
+    inc = int(value) + 1
    return UUID(int=inc)
 
 
-def get_collection(db: Session, client=None, project_id: int = None) -> Collection:
-    project = get_project(db)
-    if client is None:
-        client = OpenAI(api_key="test_api_key")
+def get_collection(
+    db: Session,
+    project,
+    *,
+    assistant_id: Optional[str] = None,
+    model: str = "gpt-4o",
+    collection_id: Optional[UUID] = None,
+) -> Collection:
+    """
+    Create a Collection configured for the Assistant path.
+    execute_job will treat this as `is_vector = False` and use assistant id.
+    """
+    if assistant_id is None:
+        assistant_id = f"asst_{uuid4().hex}"
 
-    vector_store = client.vector_stores.create()
-    assistant = client.beta.assistants.create(
-        model=constants.openai_model,
-        tools=[{"type": "file_search"}],
-        tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
+    collection = Collection(
+        id=collection_id or uuid4(),
+        project_id=project.id,
+        organization_id=project.organization_id,
+        llm_service_name=model,
+        llm_service_id=assistant_id,
     )
+    return CollectionCrud(db, project.id).create(collection)
+
+
+def get_vector_store_collection(
+    db: Session,
+    project,
+    *,
+    vector_store_id: Optional[str] = None,
+    collection_id: Optional[UUID] = None,
+) -> Collection:
+    """
+    Create a Collection configured for the Vector Store path.
+    execute_job will treat this as `is_vector = True` and use vector store id.
+    """
+    if vector_store_id is None:
+        vector_store_id = f"vs_{uuid4().hex}"
 
-    return Collection(
+    collection = Collection(
+        id=collection_id or uuid4(),
+        project_id=project.id,
         organization_id=project.organization_id,
-        project_id=project_id,
-        llm_service_id=assistant.id,
-        llm_service_name=constants.llm_service_name,
+        llm_service_name="openai vector store",
+        llm_service_id=vector_store_id,
+    )
+    return CollectionCrud(db, project.id).create(collection)
+
+
+def get_collection_job(
+    db: Session,
+    project,
+    *,
+    action_type: CollectionActionType = CollectionActionType.CREATE,
+    status: CollectionJobStatus = CollectionJobStatus.PENDING,
+    collection_id: Optional[UUID] = None,
+    error_message: Optional[str] = None,
+    job_id: Optional[UUID] = None,
+) -> CollectionJob:
+    """
+    Generic seed for a CollectionJob row.
+    """
+    job = CollectionJob(
+        id=job_id or uuid4(),
+        project_id=project.id,
+        action_type=action_type.value if hasattr(action_type, "value") else action_type,
+        status=status.value if hasattr(status, "value") else status,
+        error_message=error_message,
+        collection_id=collection_id,
    )
+    return CollectionJobCrud(db, project.id).create(job)