diff --git a/backend/app/services/collections/providers/base.py b/backend/app/services/collections/providers/base.py index d76fb6189..b5d37bf9b 100644 --- a/backend/app/services/collections/providers/base.py +++ b/backend/app/services/collections/providers/base.py @@ -1,9 +1,8 @@ from abc import ABC, abstractmethod from typing import Any -from app.crud import DocumentCrud from app.core.cloud.storage import CloudStorage -from app.models import CreationRequest, Collection +from app.models import CreationRequest, Collection, Document class BaseProvider(ABC): @@ -32,17 +31,14 @@ def create( self, collection_request: CreationRequest, storage: CloudStorage, - document_crud: DocumentCrud, + docs_batches: list[list[Document]], ) -> Collection: """Create collection with documents and optionally an assistant. Args: collection_request: Collection parameters (name, description, document list, etc.) storage: Cloud storage instance for file access - document_crud: DocumentCrud instance for fetching documents - batch_size: Number of documents to process per batch - with_assistant: Whether to create an assistant/agent - assistant_options: Options for assistant creation (provider-specific) + docs_batches: Pre-fetched document batches (DB reads must happen before this call) Returns: llm_service_id: ID of the resource to delete diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py index 27ab8de86..29f4fc349 100644 --- a/backend/app/services/llm/chain/executor.py +++ b/backend/app/services/llm/chain/executor.py @@ -63,21 +63,21 @@ def _setup(self) -> None: def _teardown(self, result: BlockResult) -> dict: """Finalize chain record, send callback, and update job status.""" - with Session(engine) as session: - if result.success: - final = LLMChainResponse( - response=result.response.response, - usage=result.usage, - provider_raw_response=result.response.provider_raw_response, - ) - callback_response = APIResponse.success_response( - data=final, metadata=self._request.request_metadata + if result.success: + final = LLMChainResponse( + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_response = APIResponse.success_response( + data=final, metadata=self._request.request_metadata + ) + if self._request.callback_url: + send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), ) - if self._request.callback_url: - send_callback( - callback_url=str(self._request.callback_url), - data=callback_response.model_dump(), - ) + with Session(engine) as session: JobCrud(session).update( job_id=self._context.job_id, job_update=JobUpdate(status=JobStatus.SUCCESS), @@ -89,9 +89,9 @@ def _teardown(self, result: BlockResult) -> dict: output=result.response.response.output.model_dump(), total_usage=self._context.aggregated_usage.model_dump(), ) - return callback_response.model_dump() - else: - return self._handle_error(result.error) + return callback_response.model_dump() + else: + return self._handle_error(result.error) def _handle_error(self, error: str) -> dict: callback_response = APIResponse.failure_response( @@ -103,13 +103,13 @@ def _handle_error(self, error: str) -> dict: f"chain_id={self._context.chain_id}, job_id={self._context.job_id}, error={error}" ) - with Session(engine) as session: - if self._request.callback_url: - send_callback( - callback_url=str(self._request.callback_url), - data=callback_response.model_dump(), - ) + if self._request.callback_url: + send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), + ) + with Session(engine) as session: update_llm_chain_status( session, chain_id=self._context.chain_id, diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 0e2a983b0..d022a1edb 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -48,10 +48,6 @@ def start_job( job_crud = JobCrud(session=db) job = job_crud.create(job_type=JobType.LLM_API, trace_id=trace_id) - # Explicitly flush to ensure job is persisted before Celery task starts - db.flush() - db.commit() - logger.info( f"[start_job] Created job | job_id={job.id}, status={job.status}, project_id={project_id}" ) @@ -89,10 +85,6 @@ def start_chain_job( job_crud = JobCrud(session=db) job = job_crud.create(job_type=JobType.LLM_CHAIN, trace_id=trace_id) - # Explicitly flush to ensure job is persisted before Celery task starts - db.flush() - db.commit() - logger.info( f"[start_chain_job] Created job | job_id={job.id}, status={job.status}, project_id={project_id}" ) @@ -129,16 +121,14 @@ def handle_job_error( callback_response: APIResponse, ) -> dict: """Handle job failure uniformly — send callback and update DB.""" - with Session(engine) as session: - job_crud = JobCrud(session=session) - - if callback_url: - send_callback( - callback_url=callback_url, - data=callback_response.model_dump(), - ) + if callback_url: + send_callback( + callback_url=callback_url, + data=callback_response.model_dump(), + ) - job_crud.update( + with Session(engine) as session: + JobCrud(session=session).update( job_id=job_id, job_update=JobUpdate( status=JobStatus.FAILED,