diff --git a/backend/app/alembic/versions/058_add_batch_tracking_to_collections_jobs.py b/backend/app/alembic/versions/058_add_batch_tracking_to_collections_jobs.py new file mode 100644 index 000000000..6bf4b97bf --- /dev/null +++ b/backend/app/alembic/versions/058_add_batch_tracking_to_collections_jobs.py @@ -0,0 +1,62 @@ +"""add batch tracking to collection_jobs + +Revision ID: 058 +Revises: 057 +Create Date: 2026-04-13 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "058" +down_revision = "057" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column( + "collection_jobs", + sa.Column( + "total_batches", + sa.Integer(), + nullable=True, + comment="Total number of batches the documents are split into", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "current_batch_number", + sa.Integer(), + nullable=True, + comment="Which batch is currently being processed (1-indexed)", + ), + ) + op.add_column( + "collection_jobs", + sa.Column( + "documents_uploaded", + sa.JSON(), + nullable=True, + comment="List of document IDs successfully uploaded so far", + ), + ) + op.add_column( + "document", + sa.Column( + "openai_file_id", + sa.String(), + nullable=True, + comment="File ID assigned by the LLM provider (e.g. OpenAI file ID) to avoid re-uploading", + ), + ) + + +def downgrade(): + op.drop_column("collection_jobs", "total_batches") + op.drop_column("collection_jobs", "current_batch_number") + op.drop_column("collection_jobs", "documents_uploaded") + op.drop_column("document", "openai_file_id") diff --git a/backend/app/api/docs/documents/upload.md b/backend/app/api/docs/documents/upload.md index e667015f5..438dc3e9b 100644 --- a/backend/app/api/docs/documents/upload.md +++ b/backend/app/api/docs/documents/upload.md @@ -1,6 +1,6 @@ Upload a document to Kaapi. -- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. +- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. The maximum file size allowed for upload is 25 MB. - If a target format is specified, a transformation job will also be created to transform document into target format in the background. The response will include both the uploaded document details and information about the transformation job. - If a callback URL is provided, you will receive a notification at that URL once the document transformation job is completed. diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index adadf1c9c..9a13fddcf 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -4,9 +4,6 @@ import celery from asgi_correlation_id import correlation_id from celery import current_task -from opentelemetry import context as otel_context -from opentelemetry import trace -from opentelemetry.propagate import extract from app.celery.celery_app import celery_app from app.celery.utils import gevent_timeout @@ -20,61 +17,18 @@ def _set_trace(trace_id: str) -> None: logger.info(f"[_set_trace] Set correlation ID: {trace_id}") -def _extract_parent_context(task_instance) -> otel_context.Context: - """Extract OTel parent context from Celery headers if available.""" - headers = getattr(task_instance.request, "headers", None) or {} - carrier: dict[str, str] = {} - - if isinstance(headers, dict): - for key, value in headers.items(): - if isinstance(value, str): - carrier[str(key)] = value - - nested = headers.get("otel", {}) - if isinstance(nested, dict): - for key, value in nested.items(): - if isinstance(value, str): - carrier[str(key)] = value - - return extract(carrier) - - -def _run_with_otel_parent(task_instance, fn): - """Attach extracted parent context and execute function. - - When Celery auto-instrumentation is active, there is already a current - `run/...` span. Re-attaching extracted parent context here would make - service spans become siblings of `run/...` instead of children. - - We only attach extracted context as a fallback when no active span exists. - """ - current_ctx = trace.get_current_span().get_span_context() - if current_ctx and current_ctx.is_valid: - return fn() - - parent_ctx = _extract_parent_context(task_instance) - token = otel_context.attach(parent_ctx) - try: - return fn() - finally: - otel_context.detach(token) - - @celery_app.task(bind=True, queue="high_priority", priority=9) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_job") def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.llm.jobs import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -84,15 +38,12 @@ def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwarg from app.services.llm.jobs import execute_chain_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_chain_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_chain_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -102,15 +53,12 @@ def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs from app.services.response.jobs import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -120,15 +68,12 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw from app.services.doctransform.job import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -137,18 +82,32 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw def run_create_collection_job( self, project_id: int, job_id: str, trace_id: str, **kwargs ): - from app.services.collections.create_collection import execute_job + from app.services.collections.create_collection import execute_setup_job + + _set_trace(trace_id) + return execute_setup_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ) + + +@celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job") +def run_collection_batch_job( + self, project_id: int, job_id: str, trace_id: str, **kwargs +): + from app.services.collections.create_collection import execute_batch_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -160,15 +119,12 @@ def run_delete_collection_job( from app.services.collections.delete_collection import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -180,15 +136,12 @@ def run_stt_batch_submission( from app.services.stt_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -200,15 +153,12 @@ def run_stt_metric_computation( from app.services.stt_evaluations.metric_job import execute_metric_computation _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_metric_computation( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_metric_computation( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -220,15 +170,12 @@ def run_tts_batch_submission( from app.services.tts_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @@ -242,13 +189,10 @@ def run_tts_result_processing( ) _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_tts_result_processing( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_tts_result_processing( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 288cba7c4..f39fb9d5d 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -18,24 +18,14 @@ F = TypeVar("F", bound=Callable[..., Any]) -def _enqueue_with_trace_context(task, **kwargs) -> str: - """Publish Celery task with explicit trace context headers.""" - otel_headers: dict[str, str] = {} - inject(otel_headers) - celery_headers = dict(otel_headers) - celery_headers["otel"] = otel_headers - async_result = task.apply_async(kwargs=kwargs, headers=celery_headers) - return async_result.id - - def start_llm_job(project_id: int, job_id: str, trace_id: str = "N/A", **kwargs) -> str: from app.celery.tasks.job_execution import run_llm_job - task_id = _enqueue_with_trace_context( - run_llm_job, project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + task = run_llm_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) - logger.info(f"[start_llm_job] Started job {job_id} with Celery task {task_id}") - return task_id + logger.info(f"[start_llm_job] Started job {job_id} with Celery task {task.id}") + return task.id def start_llm_chain_job( @@ -43,17 +33,13 @@ def start_llm_chain_job( ) -> str: from app.celery.tasks.job_execution import run_llm_chain_job - task_id = _enqueue_with_trace_context( - run_llm_chain_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_llm_chain_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_llm_chain_job] Started job {job_id} with Celery task {task_id}" + f"[start_llm_chain_job] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_response_job( @@ -61,15 +47,11 @@ def start_response_job( ) -> str: from app.celery.tasks.job_execution import run_response_job - task_id = _enqueue_with_trace_context( - run_response_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_response_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) - logger.info(f"[start_response_job] Started job {job_id} with Celery task {task_id}") - return task_id + logger.info(f"[start_response_job] Started job {job_id} with Celery task {task.id}") + return task.id def start_doctransform_job( @@ -77,17 +59,13 @@ def start_doctransform_job( ) -> str: from app.celery.tasks.job_execution import run_doctransform_job - task_id = _enqueue_with_trace_context( - run_doctransform_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_doctransform_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_doctransform_job] Started job {job_id} with Celery task {task_id}" + f"[start_doctransform_job] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_create_collection_job( @@ -95,17 +73,27 @@ def start_create_collection_job( ) -> str: from app.celery.tasks.job_execution import run_create_collection_job - task_id = _enqueue_with_trace_context( - run_create_collection_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_create_collection_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs + ) + logger.info( + f"[start_create_collection_job] Started job {job_id} with Celery task {task.id}" + ) + return task.id + + +def start_collection_batch_job( + project_id: int, job_id: str, trace_id: str = "N/A", **kwargs +) -> str: + from app.celery.tasks.job_execution import run_collection_batch_job + + task = run_collection_batch_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_create_collection_job] Started job {job_id} with Celery task {task_id}" + f"[start_collection_batch_job] Started batch job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_delete_collection_job( @@ -113,17 +101,13 @@ def start_delete_collection_job( ) -> str: from app.celery.tasks.job_execution import run_delete_collection_job - task_id = _enqueue_with_trace_context( - run_delete_collection_job, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_delete_collection_job.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_delete_collection_job] Started job {job_id} with Celery task {task_id}" + f"[start_delete_collection_job] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_stt_batch_submission( @@ -131,17 +115,13 @@ def start_stt_batch_submission( ) -> str: from app.celery.tasks.job_execution import run_stt_batch_submission - task_id = _enqueue_with_trace_context( - run_stt_batch_submission, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_stt_batch_submission.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_stt_batch_submission] Started job {job_id} with Celery task {task_id}" + f"[start_stt_batch_submission] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_stt_metric_computation( @@ -149,17 +129,13 @@ def start_stt_metric_computation( ) -> str: from app.celery.tasks.job_execution import run_stt_metric_computation - task_id = _enqueue_with_trace_context( - run_stt_metric_computation, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_stt_metric_computation.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_stt_metric_computation] Started job {job_id} with Celery task {task_id}" + f"[start_stt_metric_computation] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_tts_batch_submission( @@ -167,17 +143,13 @@ def start_tts_batch_submission( ) -> str: from app.celery.tasks.job_execution import run_tts_batch_submission - task_id = _enqueue_with_trace_context( - run_tts_batch_submission, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_tts_batch_submission.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_tts_batch_submission] Started job {job_id} with Celery task {task_id}" + f"[start_tts_batch_submission] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def start_tts_result_processing( @@ -185,17 +157,13 @@ def start_tts_result_processing( ) -> str: from app.celery.tasks.job_execution import run_tts_result_processing - task_id = _enqueue_with_trace_context( - run_tts_result_processing, - project_id=project_id, - job_id=job_id, - trace_id=trace_id, - **kwargs, + task = run_tts_result_processing.delay( + project_id=project_id, job_id=job_id, trace_id=trace_id, **kwargs ) logger.info( - f"[start_tts_result_processing] Started job {job_id} with Celery task {task_id}" + f"[start_tts_result_processing] Started job {job_id} with Celery task {task.id}" ) - return task_id + return task.id def get_task_status(task_id: str) -> Dict[str, Any]: diff --git a/backend/app/crud/rag/open_ai.py b/backend/app/crud/rag/open_ai.py index cdae82440..07a9e671c 100644 --- a/backend/app/crud/rag/open_ai.py +++ b/backend/app/crud/rag/open_ai.py @@ -1,13 +1,11 @@ import json import logging import functools as ft -from io import BytesIO -from typing import Iterable +import time from openai import OpenAI, OpenAIError from pydantic import BaseModel -from app.core.cloud import CloudStorage from app.models import Document logger = logging.getLogger(__name__) @@ -78,11 +76,6 @@ def clean(self, resource): class VectorStoreCleaner(ResourceCleaner): def clean(self, resource): - logger.info( - f"[VectorStoreCleaner.clean] Starting vector store cleanup | {{'vector_store_id': '{resource}'}}" - ) - for i in vs_ls(self.client, resource): - self.client.files.delete(i.id) logger.info( f"[VectorStoreCleaner.clean] Deleting vector store | {{'vector_store_id': '{resource}'}}" ) @@ -118,36 +111,33 @@ def read(self, vector_store_id: str): def update( self, vector_store_id: str, - storage: CloudStorage, - documents: Iterable[Document], - ): - for docs in documents: - files = [] - for d in docs: - # Get file bytes and wrap in BytesIO for OpenAI API - content = storage.get(d.object_store_url) - f_obj = BytesIO(content) - f_obj.name = d.fname - files.append(f_obj) + docs: list[Document], + ) -> None: + if not docs: + return - logger.info( - f"[OpenAIVectorStoreCrud.update] Uploading files to vector store | {{'vector_store_id': '{vector_store_id}', 'file_count': {len(files)}}}" - ) - req = self.client.vector_stores.file_batches.upload_and_poll( + try: + batch = self.client.vector_stores.file_batches.upload_and_poll( vector_store_id=vector_store_id, - files=files, + files=[], + file_ids=[doc.openai_file_id for doc in docs], ) logger.info( - f"[OpenAIVectorStoreCrud.update] File upload completed | {{'vector_store_id': '{vector_store_id}', 'completed_files': {req.file_counts.completed}, 'total_files': {req.file_counts.total}}}" + f"[OpenAIVectorStoreCrud.update] Batch complete | " + f"{{'vector_store_id': '{vector_store_id}', " + f"'completed': {batch.file_counts.completed}, 'failed': {batch.file_counts.failed}}}" ) - if req.file_counts.completed != req.file_counts.total: - error_msg = f"OpenAI document processing error: {req.file_counts.completed}/{req.file_counts.total} files completed" - logger.error( - f"[OpenAIVectorStoreCrud.update] Document processing error | {{'vector_store_id': '{vector_store_id}', 'completed_files': {req.file_counts.completed}, 'total_files': {req.file_counts.total}}}" + if batch.file_counts.failed > 0: + logger.warning( + f"[OpenAIVectorStoreCrud.update] Batch had failures | " + f"{{'vector_store_id': '{vector_store_id}', 'failed_count': {batch.file_counts.failed}}}" ) - raise InterruptedError(error_msg) - - yield from docs + except OpenAIError as err: + logger.error( + f"[OpenAIVectorStoreCrud.update] Batch attach failed | " + f"{{'vector_store_id': '{vector_store_id}', 'error': '{str(err)}'}}", + exc_info=True, + ) def delete(self, vector_store_id: str, retries: int = 3): if retries < 1: diff --git a/backend/app/models/collection_job.py b/backend/app/models/collection_job.py index 333ebfd14..6b628ad7e 100644 --- a/backend/app/models/collection_job.py +++ b/backend/app/models/collection_job.py @@ -77,7 +77,29 @@ class CollectionJob(SQLModel, table=True): documents: list[str] | None = Field( default=None, sa_column=Column( - JSON, nullable=True, comment="List of documents given to make collection" + JSON, nullable=True, comment="List of document IDs given to make collection" + ), + ) + total_batches: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Total number of batches the documents are split into" + }, + ) + current_batch_number: int | None = Field( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Which batch is currently being processed (1-indexed)" + }, + ) + documents_uploaded: list[str] | None = Field( + default=None, + sa_column=Column( + JSON, + nullable=True, + comment="List of document IDs successfully uploaded so far", ), ) @@ -139,6 +161,9 @@ class CollectionJobUpdate(SQLModel): collection_id: UUID | None = None total_size_mb: float | None = None trace_id: str | None = None + total_batches: int | None = None + current_batch_number: int | None = None + documents_uploaded: list[str] | None = None ##Response models diff --git a/backend/app/models/document.py b/backend/app/models/document.py index 12843e72a..5bbcddc77 100644 --- a/backend/app/models/document.py +++ b/backend/app/models/document.py @@ -46,6 +46,11 @@ class Document(DocumentBase, table=True): description="The size of the document in kilobytes", sa_column_kwargs={"comment": "Size of the document in kilobytes (KB)"}, ) + openai_file_id: str | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "File ID assigned by OpenAI (avoid re-uploading)"}, + ) # Foreign keys source_document_id: UUID | None = Field( diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index a9b787f6b..ca3197200 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -2,10 +2,10 @@ import time from uuid import UUID, uuid4 -from opentelemetry import trace from sqlmodel import Session -from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded +from opentelemetry import trace from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage @@ -26,12 +26,13 @@ CreationRequest, ) from app.services.collections.helpers import ( + batch_documents, extract_error_message, to_collection_public, ) from app.services.collections.providers.registry import get_llm_provider -from app.celery.utils import start_create_collection_job -from app.utils import send_callback, get_webhook_secret, APIResponse +from app.celery.utils import start_create_collection_job, start_collection_batch_job +from app.utils import send_callback, APIResponse, get_webhook_secret logger = logging.getLogger(__name__) @@ -46,49 +47,31 @@ def start_job( with_assistant: bool, organization_id: int, ) -> str: - with log_context( - tag="collection", - lifecycle="collection.create.start_job", - action="create", - collection_job_id=collection_job_id, - project_id=project_id, - organization_id=organization_id, - ): - trace_id = correlation_id.get() or "N/A" + trace_id = correlation_id.get() or "N/A" - job_crud = CollectionJobCrud(db, project_id) - collection_job = job_crud.update( - collection_job_id, CollectionJobUpdate(trace_id=trace_id) - ) + job_crud = CollectionJobCrud(db, project_id) + job_crud.update(collection_job_id, CollectionJobUpdate(trace_id=trace_id)) - task_id = start_create_collection_job( - project_id=project_id, - job_id=str(collection_job_id), - trace_id=trace_id, - request=request.model_dump(mode="json"), - with_assistant=with_assistant, - organization_id=organization_id, - ) + task_id = start_create_collection_job( + project_id=project_id, + job_id=str(collection_job_id), + trace_id=trace_id, + request=request.model_dump(mode="json"), + with_assistant=with_assistant, + organization_id=organization_id, + ) - logger.info( - "[create_collection.start_job] Job scheduled to create collection | " - f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" - ) + logger.info( + "[create_collection.start_job] Job scheduled to create collection | " + f"collection_job_id={collection_job_id}, project_id={project_id}, task_id={task_id}" + ) - return collection_job_id + return collection_job_id def build_success_payload( collection_job: CollectionJob, collection: Collection ) -> dict: - """ - { - "success": true, - "data": { job fields + full collection }, - "error": null, - "metadata": null - } - """ collection_public = to_collection_public(collection) collection_dict = collection_public.model_dump(mode="json", exclude_none=True) @@ -102,15 +85,6 @@ def build_success_payload( def build_failure_payload(collection_job: CollectionJob, error_message: str) -> dict: - """ - { - "success": false, - "data": { job fields, collection: null }, - "error": "something went wrong", - "metadata": null - } - """ - # ensure `collection` is explicitly null in the payload job_public = CollectionJobPublic.model_validate( collection_job, update={"collection": None}, @@ -144,7 +118,7 @@ def _mark_job_failed( ) return collection_job except Exception: - logger.warning("[create_collection.execute_job] Failed to mark job as FAILED") + logger.warning("[create_collection] Failed to mark job as FAILED") return None @@ -186,7 +160,7 @@ def _handle_job_failure( ) -def execute_job( +def execute_setup_job( request: dict, with_assistant: bool, project_id: int, @@ -196,26 +170,21 @@ def execute_job( task_instance, ) -> None: """ - Worker entrypoint scheduled by start_job. - Orchestrates: job state, provider init, collection creation, - optional assistant creation, collection persistence, linking, callbacks, and cleanup. + Phase 1: Fetch documents, create the vector store, split into batches, + update job state to PROCESSING, then queue the first batch task. """ - start_time = time.time() - collection_job = None - result = None creation_request = None - provider = None with log_context( tag="collection", - lifecycle="collection.create.execute_job", + lifecycle="collection.create.execute_setup_job", action="create", collection_job_id=job_id, task_id=task_id, project_id=project_id, organization_id=organization_id, - ), tracer.start_as_current_span("collections.create.execute_job") as span: + ), tracer.start_as_current_span("collections.create.execute_setup_job") as span: span.set_attribute("collection.job_id", str(job_id)) span.set_attribute("kaapi.project_id", project_id) span.set_attribute("kaapi.organization_id", organization_id) @@ -228,32 +197,179 @@ def execute_job( span.set_attribute("collection.provider", str(creation_request.provider)) job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" with Session(engine) as session: document_crud = DocumentCrud(session, project_id) flat_docs = document_crud.read_each(creation_request.documents) + storage = get_cloud_storage(session=session, project_id=project_id) + + provider = get_llm_provider( + session=session, + provider=creation_request.provider, + project_id=project_id, + organization_id=organization_id, + ) + + for doc in flat_docs: + session.expunge(doc) + + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + task_id=task_id, + status=CollectionJobStatus.PROCESSING, + ), + ) - file_exts = { - doc.fname.split(".")[-1] for doc in flat_docs if "." in doc.fname - } - total_size_kb = sum(doc.file_size_kb or 0 for doc in flat_docs) - total_size_mb = round(total_size_kb / 1024, 2) - span.set_attribute("collection.documents.count", len(flat_docs)) - span.set_attribute("collection.documents.total_size_mb", total_size_mb) + provider.upload_files(storage, flat_docs, project_id) + + logger.info( + "[create_collection.execute_setup_job] All file uploads complete | " + "job_id=%s, total=%d", + job_id, + len(flat_docs), + ) + + total_size_kb = sum(doc.file_size_kb for doc in flat_docs) + total_size_mb = total_size_kb / 1024 + + docs_batches = batch_documents(flat_docs) + total_batches = len(docs_batches) + batch_doc_ids = [[str(doc.id) for doc in batch] for batch in docs_batches] with Session(engine) as session: collection_job_crud = CollectionJobCrud(session, project_id) - collection_job = collection_job_crud.read_one(job_uuid) collection_job = collection_job_crud.update( job_uuid, CollectionJobUpdate( task_id=task_id, status=CollectionJobStatus.PROCESSING, total_size_mb=total_size_mb, + current_batch_number=0, + total_batches=total_batches, + documents_uploaded=[], ), ) - storage = get_cloud_storage(session=session, project_id=project_id) + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + batch_number=1, + batch_doc_ids=batch_doc_ids[0], + remaining_batches=batch_doc_ids[1:], + request=request, + vector_store_id=None, + with_assistant=with_assistant, + organization_id=organization_id, + ) + + logger.info( + "[create_collection.execute_setup_job] Setup complete, first batch queued | " + f"job_id={job_id}, total_batches={total_batches}" + ) + + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError("Task exceeded soft time limit") + logger.warning( + "[create_collection.execute_setup_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", + job_id, + str(timeout_err), + ) + _handle_job_failure( + span, + project_id, + organization_id, + job_id, + timeout_err, + collection_job, + creation_request, + ) + raise + + except Exception as err: + logger.error( + "[create_collection.execute_setup_job] Setup failed | job_id=%s, error=%s", + job_id, + str(err), + exc_info=True, + ) + _handle_job_failure( + span, + project_id, + organization_id, + job_id, + err, + collection_job, + creation_request, + ) + raise + + +def execute_batch_job( + request: dict, + with_assistant: bool, + project_id: int, + organization_id: int, + task_id: str, + job_id: str, + task_instance, + vector_store_id: str | None, + batch_number: int, + batch_doc_ids: list[str], + remaining_batches: list[list[str]], +) -> None: + """ + Phase 2: Upload one batch of documents to the vector store. + - Uploads the batch via provider.create(); raises immediately on failure + - Checkpoints progress to the DB + - If more batches remain, queues the next batch task + - If this is the last batch, finalizes: creates Collection, links docs, marks job SUCCESSFUL + """ + collection_job = None + result = None + creation_request = None + provider = None + + with log_context( + tag="collection", + lifecycle="collection.create.execute_batch_job", + action="create", + collection_job_id=job_id, + task_id=task_id, + project_id=project_id, + organization_id=organization_id, + ), tracer.start_as_current_span("collections.create.execute_batch_job") as span: + span.set_attribute("collection.job_id", str(job_id)) + span.set_attribute("kaapi.project_id", project_id) + span.set_attribute("kaapi.organization_id", organization_id) + + try: + batch_start_time = time.time() + creation_request = CreationRequest(**request) + if with_assistant: + creation_request.provider = "openai" + + span.set_attribute("collection.provider", str(creation_request.provider)) + + job_uuid = UUID(job_id) + trace_id = correlation_id.get() or "N/A" + + logger.info( + "[create_collection.execute_batch_job] Starting batch | " + "job_id=%s, batch_number=%d, doc_count=%d, remaining_batches=%d", + job_id, + batch_number, + len(batch_doc_ids), + len(remaining_batches), + ) + + all_doc_ids_this_batch = [UUID(d) for d in batch_doc_ids] + is_final = not remaining_batches + + with Session(engine) as session: provider = get_llm_provider( session=session, provider=creation_request.provider, @@ -261,38 +377,107 @@ def execute_job( organization_id=organization_id, ) - with tracer.start_as_current_span("collections.create.provider"): - result = provider.create( - collection_request=creation_request, - storage=storage, - documents=flat_docs, + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + batch_docs = ( + document_crud.read_each(all_doc_ids_this_batch) + if all_doc_ids_this_batch + else [] ) + for doc in batch_docs: + session.expunge(doc) - llm_service_id = result.llm_service_id - llm_service_name = result.llm_service_name + collection_result = provider.create( + creation_request, + batch_docs, + vector_store_id=vector_store_id, + is_final=is_final, + ) + result = collection_result + resolved_vector_store_id = collection_result.llm_service_id with Session(engine) as session: - collection_crud = CollectionCrud(session, project_id) - collection_id = uuid4() + collection_job_crud = CollectionJobCrud(session, project_id) + collection_job = collection_job_crud.read_one(job_uuid) + already_uploaded = collection_job.documents_uploaded or [] + now_uploaded = already_uploaded + [ + str(d) for d in all_doc_ids_this_batch + ] + collection_job = collection_job_crud.update( + job_uuid, + CollectionJobUpdate( + current_batch_number=batch_number, + documents_uploaded=now_uploaded, + ), + ) + + logger.info( + "[create_collection.execute_batch_job] Batch %d complete | " + "doc_count=%d, job_id=%s", + batch_number, + len(all_doc_ids_this_batch), + job_id, + ) + + if remaining_batches: + start_collection_batch_job( + project_id=project_id, + job_id=job_id, + trace_id=trace_id, + vector_store_id=resolved_vector_store_id, + batch_number=batch_number + 1, + batch_doc_ids=remaining_batches[0], + remaining_batches=remaining_batches[1:], + request=request, + with_assistant=with_assistant, + organization_id=organization_id, + ) + logger.info( + "[create_collection.execute_batch_job] Batch %d/%d done, next batch queued | " + "job_id=%s, elapsed=%.2fs", + batch_number, + batch_number + len(remaining_batches), + job_id, + time.time() - batch_start_time, + ) + return + + # Final batch: collection_result already has assistant/vector_store finalized + finalize_start_time = time.time() + + with Session(engine) as session: + all_uploaded_ids = [UUID(d) for d in now_uploaded] + document_crud = DocumentCrud(session, project_id) + all_docs = ( + document_crud.read_each(all_uploaded_ids) + if all_uploaded_ids + else [] + ) + for doc in all_docs: + session.expunge(doc) + + with Session(engine) as session: + collection_id = uuid4() collection = Collection( id=collection_id, project_id=project_id, - llm_service_id=llm_service_id, - llm_service_name=llm_service_name, + llm_service_id=collection_result.llm_service_id, + llm_service_name=collection_result.llm_service_name, provider=creation_request.provider, name=creation_request.name, description=creation_request.description, ) + collection_crud = CollectionCrud(session, project_id) collection_crud.create(collection) collection = collection_crud.read_one(collection.id) - if flat_docs: - DocumentCollectionCrud(session).create(collection, flat_docs) + if all_docs: + DocumentCollectionCrud(session).create(collection, all_docs) collection_job_crud = CollectionJobCrud(session, project_id) collection_job = collection_job_crud.update( - collection_job.id, + job_uuid, CollectionJobUpdate( status=CollectionJobStatus.SUCCESSFUL, collection_id=collection.id, @@ -303,14 +488,13 @@ def execute_job( span.set_attribute("collection.id", str(collection_id)) - elapsed = time.time() - start_time logger.info( - "[create_collection.execute_job] Collection created: %s | Time: %.2fs | Files: %d | Total Size: %s MB | Types: %s", + "[create_collection.execute_batch_job] All batches done, collection created: %s | " + "finalize_time=%.2fs, total_time=%.2fs, total_docs=%d", collection_id, - elapsed, - len(flat_docs), - collection_job.total_size_mb, - list(file_exts), + time.time() - finalize_start_time, + time.time() - batch_start_time, + len(all_docs), ) if creation_request.callback_url: @@ -324,7 +508,7 @@ def execute_job( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") logger.warning( - "[create_collection.execute_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", + "[create_collection.execute_batch_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", job_id, str(timeout_err), ) @@ -343,7 +527,7 @@ def execute_job( except Exception as err: logger.error( - "[create_collection.execute_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", + "[create_collection.execute_batch_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", job_id, str(err), exc_info=True, diff --git a/backend/app/services/collections/helpers.py b/backend/app/services/collections/helpers.py index 6985ac78e..3f0a0cefd 100644 --- a/backend/app/services/collections/helpers.py +++ b/backend/app/services/collections/helpers.py @@ -19,7 +19,6 @@ MAX_DOC_SIZE_MB = 25 # 25 MB maximum per document # Maximum batch size for uploading documents to vector store -# Derived from MAX_DOC_SIZE + buffer to ensure single docs always fit MAX_BATCH_SIZE_KB = (MAX_DOC_SIZE_MB + 5) * 1024 # 30 MB in KB (25 + 5 MB buffer) MAX_BATCH_COUNT = 200 # Maximum documents per batch @@ -83,7 +82,7 @@ def batch_documents(documents: list[Document]) -> list[list[Document]]: current_batch_size_kb = 0 for doc in documents: - doc_size_kb = doc.file_size_kb or 0 + doc_size_kb = doc.file_size_kb would_exceed_size = (current_batch_size_kb + doc_size_kb) > MAX_BATCH_SIZE_KB would_exceed_count = len(current_batch) >= MAX_BATCH_COUNT diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py index 36283d1fa..6649a0725 100644 --- a/backend/app/services/collections/providers/base.py +++ b/backend/app/services/collections/providers/base.py @@ -19,48 +19,46 @@ class BaseProvider(ABC): """ def __init__(self, client: Any) -> None: - """Initialize provider with client. + self.client = client + + @abstractmethod + def upload_files( + self, + storage: CloudStorage, + docs: list[Document], + project_id: int, + ) -> None: + """Upload all documents to the provider's file storage and persist their file IDs. Args: - client: Provider-specific client instance + storage: Cloud storage instance to fetch raw file bytes from + docs: Documents to upload + project_id: Project ID used to persist the provider file IDs to the DB """ - self.client = client + raise NotImplementedError("Providers must implement upload_files method") @abstractmethod def create( self, collection_request: CreationRequest, - storage: CloudStorage, - documents: list[Document], + docs: list[Document], + vector_store_id: str | None = None, + is_final: bool = False, ) -> Collection: - """Create collection with documents and optionally an assistant. - - Args: - collection_request: Collection parameters (name, description, document list, etc.) - storage: Cloud storage instance for file access - documents: Pre-fetched list of Document objects to add to the collection - - Returns: - Collection object with llm_service_id and llm_service_name populated - """ - raise NotImplementedError("Providers must implement execute method") + """Upload docs batch to vector store (creating it if vector_store_id is None). + Creates assistant only when is_final=True and model/instructions are set. + Returns Collection with llm_service_id set to vector_store_id on intermediate batches, + or to assistant/vector_store id on the final batch.""" + raise NotImplementedError("Providers must implement create method") @abstractmethod def delete(self, collection: Collection) -> None: - """Delete remote resources associated with a collection. - - Called when a collection is being deleted and remote resources need to be cleaned up. - - Args: - llm_service_id: ID of the resource to delete - llm_service_name: Name of the service (determines resource type) - """ + """Delete remote resources associated with a collection.""" raise NotImplementedError("Providers must implement delete method") - def get_provider_name(self) -> str: - """Get the name of the provider. + def get_existing_file_id(self, _doc: Document) -> str | None: + """Return the already-uploaded file ID for this provider, or None to trigger upload.""" + return None - Returns: - Provider name (e.g., "openai", "bedrock", "pinecone") - """ + def get_provider_name(self) -> str: return self.__class__.__name__.replace("Provider", "").lower() diff --git a/backend/app/services/collections/providers/openai.py b/backend/app/services/collections/providers/openai.py index f52e83394..61e7c6374 100644 --- a/backend/app/services/collections/providers/openai.py +++ b/backend/app/services/collections/providers/openai.py @@ -1,12 +1,16 @@ import logging +from io import BytesIO from typing import List from openai import OpenAI +from sqlmodel import Session from app.services.collections.providers import BaseProvider from app.core.cloud.storage import CloudStorage +from app.core.db import engine +from app.crud import DocumentCrud from app.crud.rag import OpenAIVectorStoreCrud, OpenAIAssistantCrud -from app.services.collections.helpers import get_service_name, batch_documents +from app.services.collections.helpers import get_service_name from app.models import CreationRequest, Collection, Document @@ -20,29 +24,73 @@ def __init__(self, client: OpenAI): super().__init__(client) self.client = client + def get_existing_file_id(self, doc: Document) -> str | None: + return doc.openai_file_id + + def upload_files( + self, + storage: CloudStorage, + docs: list[Document], + project_id: int, + ) -> None: + for doc in docs: + if self.get_existing_file_id(doc): + continue + try: + content = storage.get(doc.object_store_url) + if doc.file_size_kb is None: + doc.file_size_kb = round(len(content) / 1024, 2) + f_obj = BytesIO(content) + f_obj.name = doc.fname + uploaded = self.client.files.create(file=f_obj, purpose="assistants") + doc.openai_file_id = uploaded.id + with Session(engine) as session: + document_crud = DocumentCrud(session, project_id) + db_doc = document_crud.read_one(doc.id) + db_doc.openai_file_id = uploaded.id + db_doc.file_size_kb = doc.file_size_kb + document_crud.update(db_doc) + except Exception as err: + logger.error( + "[OpenAIProvider.upload_files] Failed to upload file | doc_id=%s, error=%s", + doc.id, + str(err), + exc_info=True, + ) + raise + def create( self, collection_request: CreationRequest, - storage: CloudStorage, - documents: List[Document], + docs: List[Document], + vector_store_id: str | None = None, + is_final: bool = False, ) -> Collection: - """ - Create OpenAI vector store with documents and optionally an assistant. - docs_batches must be pre-fetched inside a DB session before this call. - """ try: - docs_batches = batch_documents(documents) vector_store_crud = OpenAIVectorStoreCrud(self.client) - vector_store = vector_store_crud.create() - list(vector_store_crud.update(vector_store.id, storage, docs_batches)) + if vector_store_id is None: + vector_store = vector_store_crud.create() + vector_store_id = vector_store.id + logger.info( + "[OpenAIProvider.create] Vector store created | vector_store_id=%s", + vector_store_id, + ) - logger.info( - "[OpenAIProvider.create] Vector store created | " - f"vector_store_id={vector_store.id}, batches={len(docs_batches)}" - ) + if docs: + vector_store_crud.update(vector_store_id, docs) + logger.info( + "[OpenAIProvider.create] Batch uploaded | vector_store_id=%s, doc_count=%d", + vector_store_id, + len(docs), + ) - # Check if we need to create an assistant (based on assistant options in request) + if not is_final: + return Collection( + llm_service_id=vector_store_id, + llm_service_name=get_service_name("openai"), + ) + # if "is_final" is true then only will assistant creation happen - with_assistant = ( collection_request.model is not None and collection_request.instructions is not None @@ -59,11 +107,12 @@ def create( k: v for k, v in assistant_options.items() if v is not None } - assistant = assistant_crud.create(vector_store.id, **filtered_options) + assistant = assistant_crud.create(vector_store_id, **filtered_options) logger.info( - "[OpenAIProvider.create] Assistant created | " - f"assistant_id={assistant.id}, vector_store_id={vector_store.id}" + "[OpenAIProvider.create] Assistant created | assistant_id=%s, vector_store_id=%s", + assistant.id, + vector_store_id, ) return Collection( @@ -76,7 +125,7 @@ def create( ) return Collection( - llm_service_id=vector_store.id, + llm_service_id=vector_store_id, llm_service_name=get_service_name("openai"), ) diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py index b21577d49..f48d04d63 100644 --- a/backend/app/tests/services/collections/providers/test_openai_provider.py +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -1,5 +1,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest @@ -23,7 +24,6 @@ def test_create_openai_vector_store_only() -> None: temperature=None, ) - storage = MagicMock() documents = [ SimpleNamespace(file_size_kb=10), SimpleNamespace(file_size_kb=20), @@ -35,11 +35,10 @@ def test_create_openai_vector_store_only() -> None: ) as vector_store_crud_cls: vector_store_crud = vector_store_crud_cls.return_value vector_store_crud.create.return_value = MagicMock(id=vector_store_id) - vector_store_crud.update.return_value = iter([None]) + vector_store_crud.update.return_value = None collection = provider.create( collection_request, - storage, documents, ) @@ -59,7 +58,6 @@ def test_create_openai_with_assistant() -> None: temperature=0.7, ) - storage = MagicMock() documents = [SimpleNamespace(file_size_kb=10)] vector_store_id = generate_openai_id("vs_") assistant_id = generate_openai_id("asst_") @@ -71,15 +69,15 @@ def test_create_openai_with_assistant() -> None: ) as assistant_crud_cls: vector_store_crud = vector_store_crud_cls.return_value vector_store_crud.create.return_value = MagicMock(id=vector_store_id) - vector_store_crud.update.return_value = iter([None]) + vector_store_crud.update.return_value = None assistant_crud = assistant_crud_cls.return_value assistant_crud.create.return_value = MagicMock(id=assistant_id) collection = provider.create( collection_request, - storage, documents, + is_final=True, ) assert collection.llm_service_id == assistant_id @@ -124,6 +122,204 @@ def test_delete_openai_vector_store() -> None: vector_store_crud.delete.assert_called_once_with(collection.llm_service_id) +# --------------------------------------------------------------------------- +# upload_files +# --------------------------------------------------------------------------- + + +def _make_doc(*, openai_file_id=None, file_size_kb=None): + return SimpleNamespace( + id=uuid4(), + fname="test.md", + object_store_url="s3://bucket/test.md", + openai_file_id=openai_file_id, + file_size_kb=file_size_kb, + ) + + +def _patch_session_and_crud(): + """Patches Session and DocumentCrud used inside upload_files.""" + session_patcher = patch("app.services.collections.providers.openai.Session") + crud_patcher = patch("app.services.collections.providers.openai.DocumentCrud") + return session_patcher, crud_patcher + + +def test_upload_files_skips_doc_with_existing_openai_file_id() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + storage = MagicMock() + doc = _make_doc(openai_file_id="file-already-exists", file_size_kb=10.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + provider.upload_files(storage, [doc], project_id=1) + + storage.get.assert_not_called() + client.files.create.assert_not_called() + + +def test_upload_files_uploads_doc_and_sets_openai_file_id() -> None: + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-new-abc") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"file content" + + doc = _make_doc(file_size_kb=10.0) + + mock_crud = MagicMock() + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + provider.upload_files(storage, [doc], project_id=1) + + assert doc.openai_file_id == "file-new-abc" + client.files.create.assert_called_once() + _, kwargs = client.files.create.call_args + assert kwargs.get("purpose") == "assistants" + mock_crud.update.assert_called_once() + + +def test_upload_files_sets_file_size_kb_when_none() -> None: + """file_size_kb should be computed from content length if not already set.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-xyz") + provider = OpenAIProvider(client=client) + + content = b"x" * 2048 # 2 KB + storage = MagicMock() + storage.get.return_value = content + + doc = _make_doc(file_size_kb=None) + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + + provider.upload_files(storage, [doc], project_id=1) + + assert doc.file_size_kb == round(len(content) / 1024, 2) + + +def test_upload_files_preserves_existing_file_size_kb() -> None: + """file_size_kb should not be overwritten if already set.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-xyz") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"x" * 4096 + + doc = _make_doc(file_size_kb=99.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + + provider.upload_files(storage, [doc], project_id=1) + + assert doc.file_size_kb == 99.0 + + +def test_upload_files_updates_db_with_file_id_and_size() -> None: + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-db-check") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc(file_size_kb=5.0) + mock_db_doc = MagicMock() + mock_crud = MagicMock() + mock_crud.read_one.return_value = mock_db_doc + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p as MockDocCrud: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + MockDocCrud.return_value = mock_crud + + provider.upload_files(storage, [doc], project_id=42) + + MockDocCrud.assert_called_once_with( + MockSession.return_value.__enter__.return_value, 42 + ) + mock_crud.read_one.assert_called_once_with(doc.id) + assert mock_db_doc.openai_file_id == "file-db-check" + assert mock_db_doc.file_size_kb == 5.0 + mock_crud.update.assert_called_once_with(mock_db_doc) + + +def test_upload_files_raises_on_storage_failure() -> None: + client = MagicMock() + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.side_effect = RuntimeError("S3 error") + + doc = _make_doc() + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + with pytest.raises(RuntimeError, match="S3 error"): + provider.upload_files(storage, [doc], project_id=1) + + client.files.create.assert_not_called() + + +def test_upload_files_raises_on_openai_failure() -> None: + client = MagicMock() + client.files.create.side_effect = RuntimeError("OpenAI error") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + doc = _make_doc() + + session_p, crud_p = _patch_session_and_crud() + with session_p, crud_p: + with pytest.raises(RuntimeError, match="OpenAI error"): + provider.upload_files(storage, [doc], project_id=1) + + +def test_upload_files_mixed_skips_uploaded_uploads_new() -> None: + """Docs with openai_file_id are skipped; others are uploaded.""" + client = MagicMock() + client.files.create.return_value = MagicMock(id="file-new") + provider = OpenAIProvider(client=client) + + storage = MagicMock() + storage.get.return_value = b"content" + + already_uploaded = _make_doc(openai_file_id="file-exists", file_size_kb=5.0) + new_doc = _make_doc(file_size_kb=5.0) + + session_p, crud_p = _patch_session_and_crud() + with session_p as MockSession, crud_p: + MockSession.return_value.__enter__.return_value = MagicMock() + MockSession.return_value.__exit__.return_value = False + + provider.upload_files(storage, [already_uploaded, new_doc], project_id=1) + + assert already_uploaded.openai_file_id == "file-exists" + assert new_doc.openai_file_id == "file-new" + client.files.create.assert_called_once() + storage.get.assert_called_once_with(new_doc.object_store_url) + + +# --------------------------------------------------------------------------- +# create (existing tests below) +# --------------------------------------------------------------------------- + + def test_create_propagates_exception() -> None: provider = OpenAIProvider(client=MagicMock()) @@ -142,6 +338,5 @@ def test_create_propagates_exception() -> None: with pytest.raises(RuntimeError): provider.create( collection_request, - MagicMock(), [SimpleNamespace(file_size_kb=10)], ) diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index d8ca2829b..05213b879 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -1,26 +1,27 @@ from typing import Any import os -from pathlib import Path from unittest.mock import patch, MagicMock -from urllib.parse import urlparse import uuid from uuid import UUID, uuid4 +from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout import pytest -from moto import mock_aws from sqlmodel import Session -from app.core.cloud import AmazonCloudStorageClient from app.core.config import settings from app.crud import CollectionCrud, CollectionJobCrud, DocumentCollectionCrud from app.models import CollectionJobStatus, CollectionJob, CollectionActionType, Project from app.models.collection import CreationRequest -from app.services.collections.create_collection import start_job, execute_job +from app.services.collections.create_collection import ( + start_job, + execute_setup_job, + execute_batch_job, +) from app.tests.utils.llm_provider import get_mock_provider from app.tests.utils.utils import get_project -from app.tests.utils.collection import get_collection_job, get_assistant_collection +from app.tests.utils.collection import get_collection_job from app.tests.utils.document import DocumentStore @@ -33,30 +34,33 @@ def aws_credentials() -> Any: os.environ["AWS_DEFAULT_REGION"] = settings.AWS_DEFAULT_REGION -def create_collection_job_for_create( - db: Session, - project: Project, - job_id: UUID, -) -> CollectionJob: - """Pre-create a CREATE job with the given id so start_job can update it.""" - return CollectionJobCrud(db, project.id).create( - CollectionJob( - id=job_id, - action_type=CollectionActionType.CREATE, - project_id=project.id, - collection_id=None, - status=CollectionJobStatus.PENDING, - ) - ) +def _mock_provider_with_size(llm_service_id: str, llm_service_name: str): + """Returns a mock provider whose upload_files sets file_size_kb=10.0 on each doc.""" + mock_provider = get_mock_provider(llm_service_id, llm_service_name) + + def _set_file_size(storage, docs, project_id): + for doc in docs: + doc.file_size_kb = 10.0 + + mock_provider.upload_files.side_effect = _set_file_size + return mock_provider + + +def _patch_session(db: Session): + """Context manager that routes all Session(engine) calls to the test db.""" + patcher = patch("app.services.collections.create_collection.Session") + mock_ctor = patcher.start() + mock_ctor.return_value.__enter__.return_value = db + mock_ctor.return_value.__exit__.return_value = False + return patcher + + +# --------------------------------------------------------------------------- +# start_job +# --------------------------------------------------------------------------- def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> None: - """ - start_job should: - - update an existing CollectionJob (status=PENDING, action=CREATE) - - call start_create_collection_job with the correct kwargs - - return the job UUID (same one that was passed in) - """ project = get_project(db) request = CreationRequest( documents=[UUID("f3e86a17-1e6f-41ec-b020-5b08eebef928")], @@ -65,7 +69,7 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non ) job_id = uuid4() - _ = get_collection_job( + get_collection_job( db, project, job_id=job_id, @@ -88,472 +92,569 @@ def test_start_job_creates_collection_job_and_schedules_task(db: Session) -> Non organization_id=project.organization_id, ) - assert returned_job_id == job_id + assert returned_job_id == job_id + mock_schedule.assert_called_once() + kwargs = mock_schedule.call_args.kwargs + assert kwargs["project_id"] == project.id + assert kwargs["organization_id"] == project.organization_id + assert kwargs["job_id"] == str(job_id) + assert kwargs["request"] == request.model_dump(mode="json") - job = CollectionJobCrud(db, project.id).read_one(job_id) - assert job.id == job_id - assert job.project_id == project.id - assert job.status == CollectionJobStatus.PENDING - assert job.action_type in ( - CollectionActionType.CREATE, - CollectionActionType.CREATE.value, - ) - assert job.collection_id is None - mock_schedule.assert_called_once() - kwargs = mock_schedule.call_args.kwargs - assert kwargs["project_id"] == project.id - assert kwargs["organization_id"] == project.organization_id - assert kwargs["job_id"] == str(job_id) - assert kwargs["request"] == request.model_dump(mode="json") +# --------------------------------------------------------------------------- +# execute_setup_job +# --------------------------------------------------------------------------- -@pytest.mark.usefixtures("aws_credentials") -@mock_aws +@patch("app.services.collections.create_collection.get_cloud_storage") @patch("app.services.collections.create_collection.get_llm_provider") -def test_execute_job_success_flow_updates_job_and_creates_collection( - mock_get_llm_provider: MagicMock, db: Session +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_marks_processing_and_queues_first_batch( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, ) -> None: - """ - execute_job should: - - set task_id on the CollectionJob - - ingest documents into a vector store - - create an OpenAI assistant - - create a Collection with llm fields filled - - link the CollectionJob -> collection_id, set status=successful - - create DocumentCollection links - """ project = get_project(db) - - aws = AmazonCloudStorageClient() - aws.create() - store = DocumentStore(db=db, project_id=project.id) - document = store.put() - s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") - aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + doc = store.put() - sample_request = CreationRequest( - documents=[document.id], callback_url=None, provider="openai" + mock_get_provider.return_value = _mock_provider_with_size( + "vs_123", "openai vector store" ) - mock_get_llm_provider.return_value = get_mock_provider( - llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" - ) - - job_id = uuid4() - _ = get_collection_job( + job = get_collection_job( db, project, - job_id=job_id, action_type=CollectionActionType.CREATE, status=CollectionJobStatus.PENDING, - collection_id=None, ) - - task_id = uuid4() - - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - execute_job( - request=sample_request.model_dump(), + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) + task_id = str(uuid4()) + + patcher = _patch_session(db) + try: + execute_setup_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, - job_id=str(job_id), + task_id=task_id, + job_id=str(job.id), task_instance=None, ) + finally: + patcher.stop() - updated_job = CollectionJobCrud(db, project.id).read_one(job_id) - assert updated_job.task_id == str(task_id) - assert updated_job.status == CollectionJobStatus.SUCCESSFUL - assert updated_job.collection_id is not None - - created_collection = CollectionCrud(db, project.id).read_one( - updated_job.collection_id - ) - assert created_collection.llm_service_id == "mock_vector_store_id" - assert created_collection.llm_service_name == "openai vector store" - assert created_collection.updated_at is not None + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.PROCESSING + assert updated_job.task_id == task_id - docs = DocumentCollectionCrud(db).read(created_collection, skip=0, limit=10) - assert len(docs) == 1 - assert docs[0].fname == document.fname + mock_queue_batch.assert_called_once() + kw = mock_queue_batch.call_args.kwargs + assert kw["batch_number"] == 1 + assert kw["vector_store_id"] is None + assert str(doc.id) in kw["batch_doc_ids"] + assert kw["remaining_batches"] == [] -@pytest.mark.usefixtures("aws_credentials") -@mock_aws +@patch("app.services.collections.create_collection.get_cloud_storage") @patch("app.services.collections.create_collection.get_llm_provider") -def test_execute_job_assistant_create_failure_marks_failed_and_deletes_collection( - mock_get_llm_provider: MagicMock, db +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_failure_marks_job_failed_and_raises( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, ) -> None: project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = RuntimeError("S3 upload failed") + mock_get_provider.return_value = mock_provider job = get_collection_job( db, project, - job_id=uuid4(), action_type=CollectionActionType.CREATE, status=CollectionJobStatus.PENDING, - collection_id=None, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - req = CreationRequest(documents=[], callback_url=None, provider="openai") - - mock_provider = get_mock_provider( - llm_service_id="vs_123", llm_service_name="openai vector store" - ) - mock_get_llm_provider.return_value = mock_provider - - with patch( - "app.services.collections.create_collection.Session" - ) as SessionCtor, patch( - "app.services.collections.create_collection.CollectionCrud" - ) as MockCrud: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - MockCrud.return_value.create.side_effect = Exception("DB constraint violation") - - task_id = str(uuid4()) - with pytest.raises(Exception, match="DB constraint violation"): - execute_job( - request=req.model_dump(), + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError, match="S3 upload failed"): + execute_setup_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=task_id, - with_assistant=True, + task_id=str(uuid4()), job_id=str(job.id), task_instance=None, ) + finally: + patcher.stop() - mock_provider.delete.assert_called_once() + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "S3 upload failed" in (updated_job.error_message or "") + mock_queue_batch.assert_not_called() -@pytest.mark.usefixtures("aws_credentials") -@mock_aws -@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") -def test_execute_job_success_flow_callback_job_and_creates_collection( +@patch("app.services.collections.create_collection.get_cloud_storage") +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_failure_sends_callback( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, - db, + db: Session, ) -> None: - """ - execute_job should: - - set task_id on the CollectionJob - - ingest documents into a vector store - - create an OpenAI assistant - - create a Collection with llm fields filled - - link the CollectionJob -> collection_id, set status=successful - - create DocumentCollection links - """ project = get_project(db) - - aws = AmazonCloudStorageClient() - aws.create() - store = DocumentStore(db=db, project_id=project.id) - document = store.put() - s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") - aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + doc = store.put() - callback_url = "https://example.com/collections/create-success" + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = RuntimeError("upload error") + mock_get_provider.return_value = mock_provider - sample_request = CreationRequest( - documents=[document.id], - callback_url=callback_url, - provider="openai", + callback_url = "https://example.com/callback" + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, ) - - mock_get_llm_provider.return_value = get_mock_provider( - llm_service_id="mock_vector_store_id", llm_service_name="openai vector store" + request = CreationRequest( + documents=[doc.id], provider="openai", callback_url=callback_url ) - job_id = uuid.uuid4() - _ = get_collection_job( + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError): + execute_setup_job( + request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + ) + finally: + patcher.stop() + + mock_send_callback.assert_called_once() + cb_url, payload = mock_send_callback.call_args.args + assert str(cb_url) == callback_url + assert payload["success"] is False + assert payload["data"]["status"] == CollectionJobStatus.FAILED + + +@patch("app.services.collections.create_collection.get_cloud_storage") +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_setup_job_timeout_marks_failed_and_reraises( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + mock_get_storage: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = _mock_provider_with_size("vs_123", "openai vector store") + mock_provider.upload_files.side_effect = Timeout(300) + mock_get_provider.return_value = mock_provider + + job = get_collection_job( db, project, - job_id=job_id, action_type=CollectionActionType.CREATE, status=CollectionJobStatus.PENDING, - collection_id=None, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - task_id = uuid.uuid4() + patcher = _patch_session(db) + try: + with pytest.raises(Timeout): + execute_setup_job( + request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + ) + finally: + patcher.stop() + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "soft time limit" in (updated_job.error_message or "") + + +# --------------------------------------------------------------------------- +# execute_batch_job +# --------------------------------------------------------------------------- + + +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_batch_job_non_final_queues_next_batch( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc1 = store.put() + doc2 = store.put() - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False + mock_get_provider.return_value = get_mock_provider("vs_123", "openai vector store") - mock_send_callback.return_value = MagicMock(status_code=403) + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest( + documents=[doc1.id, doc2.id], provider="openai", callback_url=None + ) + task_id = str(uuid4()) - execute_job( - request=sample_request.model_dump(), + patcher = _patch_session(db) + try: + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, - job_id=str(job_id), + task_id=task_id, + job_id=str(job.id), task_instance=None, + vector_store_id="vs_123", + batch_number=1, + batch_doc_ids=[str(doc1.id)], + remaining_batches=[[str(doc2.id)]], ) + finally: + patcher.stop() - updated_job = CollectionJobCrud(db, project.id).read_one(job_id) - collection = CollectionCrud(db, project.id).read_one(updated_job.collection_id) + mock_queue_batch.assert_called_once() + kw = mock_queue_batch.call_args.kwargs + assert kw["batch_number"] == 2 + assert kw["batch_doc_ids"] == [str(doc2.id)] + assert kw["remaining_batches"] == [] + assert kw["vector_store_id"] == "vs_123" - mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is True - assert payload_arg["data"]["status"] == CollectionJobStatus.SUCCESSFUL - assert payload_arg["data"]["collection"]["id"] == str(collection.id) - assert uuid.UUID(payload_arg["data"]["job_id"]) == job_id + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.current_batch_number == 1 + assert str(doc1.id) in (updated_job.documents_uploaded or []) -@pytest.mark.usefixtures("aws_credentials") -@mock_aws @patch("app.services.collections.create_collection.get_llm_provider") -@patch("app.services.collections.create_collection.send_callback") -def test_execute_job_success_creates_collection_with_callback( - mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, - db, +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_batch_job_final_batch_creates_collection_and_marks_successful( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, + db: Session, ) -> None: - """ - execute_job should: - - set task_id on the CollectionJob - - ingest documents into a vector store - - create an OpenAI assistant - - create a Collection with llm fields filled - - link the CollectionJob -> collection_id, set status=successful - - create DocumentCollection links - """ project = get_project(db) - - aws = AmazonCloudStorageClient() - aws.create() - store = DocumentStore(db=db, project_id=project.id) - document = store.put() - s3_key = Path(urlparse(document.object_store_url).path).relative_to("/") - aws.client.put_object(Bucket=settings.AWS_S3_BUCKET, Key=str(s3_key), Body=b"test") + doc = store.put() - callback_url = "https://example.com/collections/create-success" - - sample_request = CreationRequest( - documents=[document.id], - callback_url=callback_url, - provider="openai", - ) - - mock_get_llm_provider.return_value = get_mock_provider( - llm_service_id="mock_vector_store_id", llm_service_name="gpt-4o" + mock_get_provider.return_value = get_mock_provider( + "vs_final", "openai vector store" ) - job_id = uuid.uuid4() - _ = get_collection_job( + job = get_collection_job( db, project, - job_id=job_id, action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - task_id = uuid.uuid4() - - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - mock_send_callback.return_value = MagicMock(status_code=403) - - execute_job( - request=sample_request.model_dump(), + patcher = _patch_session(db) + try: + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, - job_id=str(job_id), + task_id=str(uuid4()), + job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.SUCCESSFUL + assert updated_job.collection_id is not None - updated_job = CollectionJobCrud(db, project.id).read_one(job_id) collection = CollectionCrud(db, project.id).read_one(updated_job.collection_id) + assert collection.llm_service_id == "vs_final" - mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is True - assert payload_arg["data"]["status"] == CollectionJobStatus.SUCCESSFUL - assert payload_arg["data"]["collection"]["id"] == str(collection.id) - assert uuid.UUID(payload_arg["data"]["job_id"]) == job_id + linked_docs = DocumentCollectionCrud(db).read(collection, skip=0, limit=10) + assert len(linked_docs) == 1 + assert linked_docs[0].id == doc.id + + mock_queue_batch.assert_not_called() -@pytest.mark.usefixtures("aws_credentials") -@mock_aws -@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") -@patch("app.services.collections.create_collection.CollectionCrud") -def test_execute_job_failure_flow_callback_job_and_marks_failed( - MockCollectionCrud, +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.start_collection_batch_job") +def test_execute_batch_job_final_batch_sends_success_callback( + mock_queue_batch: MagicMock, + mock_get_provider: MagicMock, mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, db: Session, ) -> None: - """ - When creation fails, the job should be marked as FAILED, an error should be logged, - and a failure callback with the error message should be triggered. - """ project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_get_provider.return_value = get_mock_provider( + "vs_final", "openai vector store" + ) - collection = get_assistant_collection(db, project, assistant_id="asst_123") + callback_url = "https://example.com/success" job = get_collection_job( db, project, action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest( + documents=[doc.id], provider="openai", callback_url=callback_url ) - mock_get_llm_provider.return_value = MagicMock() + patcher = _patch_session(db) + try: + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], + ) + finally: + patcher.stop() - callback_url = "https://example.com/collections/create-failure" + mock_send_callback.assert_called_once() + cb_url, payload = mock_send_callback.call_args.args + assert str(cb_url) == callback_url + assert payload["success"] is True + assert payload["data"]["status"] == CollectionJobStatus.SUCCESSFUL + assert payload["data"]["collection"] is not None - collection_crud_instance = MockCollectionCrud.return_value - collection_crud_instance.read_one.return_value = collection - sample_request = CreationRequest( - documents=[uuid.uuid4()], - callback_url=callback_url, - provider="openai", - ) +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_provider_failure_marks_failed_and_raises( + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() - task_id = uuid.uuid4() + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = RuntimeError("vector store error") + mock_get_provider.return_value = mock_provider - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - with pytest.raises( - ValueError, match="Requested atleast 1 document retrieved 0" - ): - execute_job( - request=sample_request.model_dump(), + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError, match="vector store error"): + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, - task_id=str(task_id), - with_assistant=True, + task_id=str(uuid4()), job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() updated_job = CollectionJobCrud(db, project.id).read_one(job.id) - assert updated_job.status == CollectionJobStatus.FAILED - assert "Requested atleast 1 document retrieved 0" in ( - updated_job.error_message or "" - ) - - mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is False - assert "Requested atleast 1 document retrieved 0" in (payload_arg["error"] or "") - assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED - assert payload_arg["data"]["collection"] is None - assert uuid.UUID(payload_arg["data"]["job_id"]) == job.id + assert "vector store error" in (updated_job.error_message or "") @patch("app.services.collections.create_collection.get_llm_provider") -def test_execute_job_timeout_marks_job_failed( - mock_get_llm_provider: MagicMock, db: Session +@patch("app.services.collections.create_collection.CollectionCrud") +def test_execute_batch_job_cleanup_called_when_provider_create_succeeds_but_db_fails( + MockCollectionCrud: MagicMock, + mock_get_provider: MagicMock, + db: Session, ) -> None: + """provider.delete should be called if create() succeeded but finalization fails.""" project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_get_provider.return_value = mock_provider + + MockCollectionCrud.return_value.create.side_effect = Exception("DB write failed") job = get_collection_job( db, project, - job_id=uuid4(), action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) - mock_provider = MagicMock() - mock_provider.create.side_effect = Timeout(300) - mock_get_llm_provider.return_value = mock_provider + patcher = _patch_session(db) + try: + with pytest.raises(Exception, match="DB write failed"): + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], + ) + finally: + patcher.stop() + + mock_provider.delete.assert_called_once() + + +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_timeout_marks_failed_and_reraises( + mock_get_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() - req = CreationRequest(documents=[], callback_url=None, provider="openai") + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = Timeout(300) + mock_get_provider.return_value = mock_provider - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False + job = get_collection_job( + db, + project, + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest(documents=[doc.id], provider="openai", callback_url=None) + patcher = _patch_session(db) + try: with pytest.raises(Timeout): - execute_job( - request=req.model_dump(), + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), - with_assistant=False, job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() updated_job = CollectionJobCrud(db, project.id).read_one(job.id) assert updated_job.status == CollectionJobStatus.FAILED assert "soft time limit" in (updated_job.error_message or "") -@patch("app.services.collections.create_collection.get_llm_provider") @patch("app.services.collections.create_collection.send_callback") -def test_execute_job_timeout_sends_failure_callback( +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_batch_job_failure_sends_callback( + mock_get_provider: MagicMock, mock_send_callback: MagicMock, - mock_get_llm_provider: MagicMock, db: Session, ) -> None: project = get_project(db) - callback_url = "https://example.com/collections/timeout" + store = DocumentStore(db=db, project_id=project.id) + doc = store.put() + + mock_provider = get_mock_provider("vs_123", "openai vector store") + mock_provider.create.side_effect = RuntimeError("batch failed") + mock_get_provider.return_value = mock_provider + callback_url = "https://example.com/failure" job = get_collection_job( db, project, - job_id=uuid4(), action_type=CollectionActionType.CREATE, - status=CollectionJobStatus.PENDING, - collection_id=None, + status=CollectionJobStatus.PROCESSING, + ) + request = CreationRequest( + documents=[doc.id], provider="openai", callback_url=callback_url ) - mock_provider = MagicMock() - mock_provider.create.side_effect = Timeout(300) - mock_get_llm_provider.return_value = mock_provider - - req = CreationRequest(documents=[], callback_url=callback_url, provider="openai") - - with patch("app.services.collections.create_collection.Session") as SessionCtor: - SessionCtor.return_value.__enter__.return_value = db - SessionCtor.return_value.__exit__.return_value = False - - with pytest.raises(Timeout): - execute_job( - request=req.model_dump(), + patcher = _patch_session(db) + try: + with pytest.raises(RuntimeError): + execute_batch_job( + request=request.model_dump(mode="json"), + with_assistant=False, project_id=project.id, organization_id=project.organization_id, task_id=str(uuid4()), - with_assistant=False, job_id=str(job.id), task_instance=None, + vector_store_id=None, + batch_number=1, + batch_doc_ids=[str(doc.id)], + remaining_batches=[], ) + finally: + patcher.stop() mock_send_callback.assert_called_once() - cb_url_arg, payload_arg = mock_send_callback.call_args.args - assert str(cb_url_arg) == callback_url - assert payload_arg["success"] is False - assert "soft time limit" in (payload_arg["error"] or "") - assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED - assert payload_arg["data"]["collection"] is None - assert uuid.UUID(payload_arg["data"]["job_id"]) == job.id + cb_url, payload = mock_send_callback.call_args.args + assert str(cb_url) == callback_url + assert payload["success"] is False + assert "batch failed" in (payload["error"] or "") + assert payload["data"]["status"] == CollectionJobStatus.FAILED diff --git a/backend/app/tests/services/collections/test_helpers.py b/backend/app/tests/services/collections/test_helpers.py index 7cddaf305..8b43946a1 100644 --- a/backend/app/tests/services/collections/test_helpers.py +++ b/backend/app/tests/services/collections/test_helpers.py @@ -122,14 +122,12 @@ def test_batch_documents_mixed_size_batching() -> None: assert len(batches[2]) == 1 # 15 MB total -def test_batch_documents_with_none_file_size() -> None: - """Test that documents with None file_size are treated as 0 bytes.""" +def test_batch_documents_with_none_file_size_raises() -> None: + """Test that documents with None file_size raise TypeError — sizes must be backfilled before batching.""" docs = create_fake_documents(10, file_size_kb=None) - batches = helpers.batch_documents(docs) - # All files with None/0 size should fit in one batch (under both limits) - assert len(batches) == 1 - assert len(batches[0]) == 10 + with pytest.raises(TypeError): + helpers.batch_documents(docs) def test_batch_documents_empty_input() -> None: