Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions backend/app/services/collections/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any

from app.crud import DocumentCrud
from app.core.cloud.storage import CloudStorage
from app.models import CreationRequest, Collection
from app.models import CreationRequest, Collection, Document


class BaseProvider(ABC):
Expand Down Expand Up @@ -32,17 +31,14 @@ def create(
self,
collection_request: CreationRequest,
storage: CloudStorage,
document_crud: DocumentCrud,
docs_batches: list[list[Document]],
) -> Collection:
"""Create collection with documents and optionally an assistant.

Args:
collection_request: Collection parameters (name, description, document list, etc.)
storage: Cloud storage instance for file access
document_crud: DocumentCrud instance for fetching documents
batch_size: Number of documents to process per batch
with_assistant: Whether to create an assistant/agent
assistant_options: Options for assistant creation (provider-specific)
docs_batches: Pre-fetched document batches (DB reads must happen before this call)

Returns:
llm_service_id: ID of the resource to delete
Expand Down
46 changes: 23 additions & 23 deletions backend/app/services/llm/chain/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,21 @@ def _setup(self) -> None:
def _teardown(self, result: BlockResult) -> dict:
"""Finalize chain record, send callback, and update job status."""

with Session(engine) as session:
if result.success:
final = LLMChainResponse(
response=result.response.response,
usage=result.usage,
provider_raw_response=result.response.provider_raw_response,
)
callback_response = APIResponse.success_response(
data=final, metadata=self._request.request_metadata
if result.success:
final = LLMChainResponse(
response=result.response.response,
usage=result.usage,
provider_raw_response=result.response.provider_raw_response,
)
callback_response = APIResponse.success_response(
data=final, metadata=self._request.request_metadata
)
if self._request.callback_url:
send_callback(
callback_url=str(self._request.callback_url),
data=callback_response.model_dump(),
)
if self._request.callback_url:
send_callback(
callback_url=str(self._request.callback_url),
data=callback_response.model_dump(),
)
with Session(engine) as session:
JobCrud(session).update(
job_id=self._context.job_id,
job_update=JobUpdate(status=JobStatus.SUCCESS),
Expand All @@ -89,9 +89,9 @@ def _teardown(self, result: BlockResult) -> dict:
output=result.response.response.output.model_dump(),
total_usage=self._context.aggregated_usage.model_dump(),
)
return callback_response.model_dump()
else:
return self._handle_error(result.error)
return callback_response.model_dump()
else:
return self._handle_error(result.error)

def _handle_error(self, error: str) -> dict:
callback_response = APIResponse.failure_response(
Expand All @@ -103,13 +103,13 @@ def _handle_error(self, error: str) -> dict:
f"chain_id={self._context.chain_id}, job_id={self._context.job_id}, error={error}"
)

with Session(engine) as session:
if self._request.callback_url:
send_callback(
callback_url=str(self._request.callback_url),
data=callback_response.model_dump(),
)
if self._request.callback_url:
send_callback(
callback_url=str(self._request.callback_url),
data=callback_response.model_dump(),
)

with Session(engine) as session:
update_llm_chain_status(
session,
chain_id=self._context.chain_id,
Expand Down
24 changes: 7 additions & 17 deletions backend/app/services/llm/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,6 @@ def start_job(
job_crud = JobCrud(session=db)
job = job_crud.create(job_type=JobType.LLM_API, trace_id=trace_id)

# Explicitly flush to ensure job is persisted before Celery task starts
db.flush()
db.commit()

logger.info(
f"[start_job] Created job | job_id={job.id}, status={job.status}, project_id={project_id}"
)
Expand Down Expand Up @@ -89,10 +85,6 @@ def start_chain_job(
job_crud = JobCrud(session=db)
job = job_crud.create(job_type=JobType.LLM_CHAIN, trace_id=trace_id)

# Explicitly flush to ensure job is persisted before Celery task starts
db.flush()
db.commit()

logger.info(
f"[start_chain_job] Created job | job_id={job.id}, status={job.status}, project_id={project_id}"
)
Expand Down Expand Up @@ -129,16 +121,14 @@ def handle_job_error(
callback_response: APIResponse,
) -> dict:
"""Handle job failure uniformly — send callback and update DB."""
with Session(engine) as session:
job_crud = JobCrud(session=session)

if callback_url:
send_callback(
callback_url=callback_url,
data=callback_response.model_dump(),
)
if callback_url:
send_callback(
callback_url=callback_url,
data=callback_response.model_dump(),
)

job_crud.update(
with Session(engine) as session:
JobCrud(session=session).update(
job_id=job_id,
job_update=JobUpdate(
status=JobStatus.FAILED,
Expand Down
Loading