@@ -0,0 +1,62 @@
"""add batch tracking to collection_jobs

Revision ID: 058
Revises: 057
Create Date: 2026-04-13

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "058"
down_revision = "057"
branch_labels = None
depends_on = None


def upgrade():
    op.add_column(
        "collection_jobs",
        sa.Column(
            "total_batches",
            sa.Integer(),
            nullable=True,
            comment="Total number of batches the documents are split into",
        ),
    )
    op.add_column(
        "collection_jobs",
        sa.Column(
            "current_batch_number",
            sa.Integer(),
            nullable=True,
            comment="Which batch is currently being processed (1-indexed)",
        ),
    )
    op.add_column(
        "collection_jobs",
        sa.Column(
            "documents_uploaded",
            sa.JSON(),
            nullable=True,
            comment="List of document IDs successfully uploaded so far",
        ),
    )
    op.add_column(
        "document",
        sa.Column(
            "openai_file_id",
            sa.String(),
            nullable=True,
            comment="File ID assigned by the LLM provider (e.g. OpenAI file ID) to avoid re-uploading",
        ),
    )


def downgrade():
    op.drop_column("collection_jobs", "total_batches")
    op.drop_column("collection_jobs", "current_batch_number")
    op.drop_column("collection_jobs", "documents_uploaded")
    op.drop_column("document", "openai_file_id")
2 changes: 1 addition & 1 deletion backend/app/api/docs/documents/upload.md
@@ -1,6 +1,6 @@
Upload a document to Kaapi.

-- If only a file is provided, the document will be uploaded and stored, and its ID will be returned.
+- If only a file is provided, the document will be uploaded and stored, and its ID will be returned. The maximum file size allowed for upload is 25 MB.
- If a target format is specified, a transformation job will also be created to transform the document into the target format in the background. The response will include both the uploaded document details and information about the transformation job.
- If a callback URL is provided, you will receive a notification at that URL once the document transformation job is completed.
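
A minimal client-side sketch of the flow these bullets describe; the endpoint URL and form field names are assumptions for illustration, and only the 25 MB cap comes from the doc itself.

```python
# Hypothetical upload call (URL and field names assumed, not from this PR);
# only the 25 MB limit is taken from the documentation above.
import os

import httpx

MAX_UPLOAD_BYTES = 25 * 1024 * 1024  # documented 25 MB cap


def upload_document(path: str, target_format: str | None = None,
                    callback_url: str | None = None) -> dict:
    if os.path.getsize(path) > MAX_UPLOAD_BYTES:
        raise ValueError("file exceeds the 25 MB upload limit")

    data = {}
    if target_format:
        data["target_format"] = target_format  # also creates a transformation job
    if callback_url:
        data["callback_url"] = callback_url  # notified when the job completes

    with open(path, "rb") as f:
        resp = httpx.post(
            "https://kaapi.example.com/api/v1/documents/upload",  # assumed URL
            files={"file": f},
            data=data,
        )
    resp.raise_for_status()
    return resp.json()  # document details, plus job info if a format was given
```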

212 changes: 78 additions & 134 deletions backend/app/celery/tasks/job_execution.py
@@ -4,9 +4,6 @@
import celery
from asgi_correlation_id import correlation_id
from celery import current_task
-from opentelemetry import context as otel_context
-from opentelemetry import trace
-from opentelemetry.propagate import extract

from app.celery.celery_app import celery_app
from app.celery.utils import gevent_timeout
@@ -20,61 +17,18 @@ def _set_trace(trace_id: str) -> None:
    logger.info(f"[_set_trace] Set correlation ID: {trace_id}")


-def _extract_parent_context(task_instance) -> otel_context.Context:
-    """Extract OTel parent context from Celery headers if available."""
-    headers = getattr(task_instance.request, "headers", None) or {}
-    carrier: dict[str, str] = {}
-
-    if isinstance(headers, dict):
-        for key, value in headers.items():
-            if isinstance(value, str):
-                carrier[str(key)] = value
-
-        nested = headers.get("otel", {})
-        if isinstance(nested, dict):
-            for key, value in nested.items():
-                if isinstance(value, str):
-                    carrier[str(key)] = value
-
-    return extract(carrier)
-
-
-def _run_with_otel_parent(task_instance, fn):
-    """Attach extracted parent context and execute function.
-
-    When Celery auto-instrumentation is active, there is already a current
-    `run/...` span. Re-attaching extracted parent context here would make
-    service spans become siblings of `run/...` instead of children.
-
-    We only attach extracted context as a fallback when no active span exists.
-    """
-    current_ctx = trace.get_current_span().get_span_context()
-    if current_ctx and current_ctx.is_valid:
-        return fn()
-
-    parent_ctx = _extract_parent_context(task_instance)
-    token = otel_context.attach(parent_ctx)
-    try:
-        return fn()
-    finally:
-        otel_context.detach(token)


@celery_app.task(bind=True, queue="high_priority", priority=9)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_job")
def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
    from app.services.llm.jobs import execute_job

    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_job(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_job(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )


@@ -84,15 +38,12 @@ def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwarg
    from app.services.llm.jobs import execute_chain_job

    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_chain_job(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_chain_job(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )


@@ -102,15 +53,12 @@ def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs
    from app.services.response.jobs import execute_job

    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_job(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_job(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )


@@ -120,15 +68,12 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw
    from app.services.doctransform.job import execute_job

    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_job(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_job(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )


@@ -137,18 +82,32 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw
def run_create_collection_job(
    self, project_id: int, job_id: str, trace_id: str, **kwargs
):
-    from app.services.collections.create_collection import execute_job
+    from app.services.collections.create_collection import execute_setup_job

    _set_trace(trace_id)
+    return execute_setup_job(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
+    )
+
+
+@celery_app.task(bind=True, queue="low_priority", priority=1)
+@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job")
+def run_collection_batch_job(
+    self, project_id: int, job_id: str, trace_id: str, **kwargs
+):
+    from app.services.collections.create_collection import execute_batch_job
+
+    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_job(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_batch_job(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )
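
The execute_setup_job / execute_batch_job split, together with the columns from migration 058, suggests a setup-then-fan-out flow. Below is a hedged sketch of how a setup step might seed the tracking columns and enqueue the first batch; the batch size, model fields, session handling, and helper name are assumptions, and only the run_collection_batch_job task name comes from this file.

```python
# Hypothetical setup-side dispatch (not this PR's implementation): seeds the
# batch-tracking columns from migration 058, then enqueues the first batch.
import math

ASSUMED_BATCH_SIZE = 20  # illustrative; the real batch size is not shown here


def start_batched_collection(job, document_ids, project_id, trace_id, session):
    job.total_batches = math.ceil(len(document_ids) / ASSUMED_BATCH_SIZE)
    job.current_batch_number = 1   # 1-indexed, per the column comment
    job.documents_uploaded = []    # appended to as batches complete
    session.add(job)
    session.commit()

    # Each batch task could then upload its slice, record the new document
    # IDs in documents_uploaded, bump current_batch_number, and re-enqueue
    # itself until current_batch_number exceeds total_batches.
    run_collection_batch_job.delay(  # task defined earlier in this file
        project_id=project_id,
        job_id=job.id,
        trace_id=trace_id,
    )
```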


@@ -160,15 +119,12 @@ def run_delete_collection_job(
    from app.services.collections.delete_collection import execute_job

    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_job(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_job(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )


@@ -180,15 +136,12 @@ def run_stt_batch_submission(
    from app.services.stt_evaluations.batch_job import execute_batch_submission

    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_batch_submission(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_batch_submission(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )


@@ -200,15 +153,12 @@ def run_stt_metric_computation(
    from app.services.stt_evaluations.metric_job import execute_metric_computation

    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_metric_computation(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_metric_computation(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )


@@ -220,15 +170,12 @@ def run_tts_batch_submission(
    from app.services.tts_evaluations.batch_job import execute_batch_submission

    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_batch_submission(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_batch_submission(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )


@@ -242,13 +189,10 @@ def run_tts_result_processing(
    )

    _set_trace(trace_id)
-    return _run_with_otel_parent(
-        self,
-        lambda: execute_tts_result_processing(
-            project_id=project_id,
-            job_id=job_id,
-            task_id=current_task.request.id,
-            task_instance=self,
-            **kwargs,
-        ),
+    return execute_tts_result_processing(
+        project_id=project_id,
+        job_id=job_id,
+        task_id=current_task.request.id,
+        task_instance=self,
+        **kwargs,
    )