Skip to content
14 changes: 14 additions & 0 deletions backend/app/celery/tasks/job_execution.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import logging
from typing import Any

import celery
from asgi_correlation_id import correlation_id
from celery import current_task
from opentelemetry import context as otel_context
from opentelemetry import trace
from opentelemetry.propagate import extract

from app.celery.celery_app import celery_app
from app.celery.utils import gevent_timeout
from app.core.config import settings

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,6 +61,7 @@ def _run_with_otel_parent(task_instance, fn):


@celery_app.task(bind=True, queue="high_priority", priority=9)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_job")
def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
from app.services.llm.jobs import execute_job

Expand All @@ -74,6 +79,7 @@ def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):


@celery_app.task(bind=True, queue="high_priority", priority=9)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_chain_job")
def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
from app.services.llm.jobs import execute_chain_job

Expand All @@ -91,6 +97,7 @@ def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwarg


@celery_app.task(bind=True, queue="high_priority", priority=9)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_response_job")
def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
from app.services.response.jobs import execute_job

Expand All @@ -108,6 +115,7 @@ def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs


@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_doctransform_job")
def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kwargs):
from app.services.doctransform.job import execute_job

Expand All @@ -125,6 +133,7 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw


@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_create_collection_job")
def run_create_collection_job(
self, project_id: int, job_id: str, trace_id: str, **kwargs
):
Expand All @@ -144,6 +153,7 @@ def run_create_collection_job(


@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_delete_collection_job")
def run_delete_collection_job(
self, project_id: int, job_id: str, trace_id: str, **kwargs
):
Expand All @@ -163,6 +173,7 @@ def run_delete_collection_job(


@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_stt_batch_submission")
def run_stt_batch_submission(
self, project_id: int, job_id: str, trace_id: str, **kwargs
):
Expand All @@ -182,6 +193,7 @@ def run_stt_batch_submission(


@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_stt_metric_computation")
def run_stt_metric_computation(
self, project_id: int, job_id: str, trace_id: str, **kwargs
):
Expand All @@ -201,6 +213,7 @@ def run_stt_metric_computation(


@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_tts_batch_submission")
def run_tts_batch_submission(
self, project_id: int, job_id: str, trace_id: str, **kwargs
):
Expand All @@ -220,6 +233,7 @@ def run_tts_batch_submission(


@celery_app.task(bind=True, queue="low_priority", priority=1)
@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_tts_result_processing")
def run_tts_result_processing(
self, project_id: int, job_id: str, trace_id: str, **kwargs
):
Expand Down
37 changes: 36 additions & 1 deletion backend/app/celery/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
Business logic modules can use these functions without knowing Celery internals.
"""
import logging
from typing import Any, Dict
import functools
from typing import Any, Dict, TypeVar
from collections.abc import Callable

from celery.result import AsyncResult
from gevent import Timeout
from opentelemetry.propagate import inject

from app.celery.celery_app import celery_app

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])


def _enqueue_with_trace_context(task, **kwargs) -> str:
"""Publish Celery task with explicit trace context headers."""
Expand Down Expand Up @@ -211,3 +216,33 @@ def revoke_task(task_id: str, terminate: bool = False) -> bool:
except Exception as e:
logger.error(f"[revoke_task] Failed to revoke task {task_id}: {e}")
return False


def gevent_timeout(
    seconds: float | None, task_name: str | None = None
) -> Callable[[F], F]:
    """Bound a function's runtime with a gevent ``Timeout``.

    Args:
        seconds: Maximum runtime in seconds; ``None`` disables the limit.
        task_name: Label used in the timeout log line. Defaults to the
            wrapped function's ``__name__``.

    Returns:
        A decorator that arms a ``Timeout`` around each call of the wrapped
        function, logging and re-raising when the limit is hit.
    """

    def decorator(func: F) -> F:
        # Resolve the log label once, at decoration time.
        label = task_name if task_name else func.__name__

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            guard = Timeout(seconds)
            guard.start()
            try:
                return func(*args, **kwargs)
            except Timeout as caught:
                # Only report the timeout we armed here; a Timeout raised by
                # nested code propagates untouched.
                if caught is guard:
                    logger.error(f"[{label}] Timed out after {seconds}s")
                raise
            finally:
                guard.cancel()

        return wrapper  # type: ignore[return-value]

    return decorator


# In gevent mode, Celery's soft and hard time limits fire during task cleanup,
# producing a misleading "Hard time limit exceeded" log. The task has already
# completed at this point (Pool POST fires first). This is a known gevent/Celery
# interaction and is harmless.
95 changes: 70 additions & 25 deletions backend/app/services/collections/create_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from opentelemetry import trace
from sqlmodel import Session
from celery.exceptions import SoftTimeLimitExceeded
from gevent import Timeout
from asgi_correlation_id import correlation_id

from app.core.cloud import get_cloud_storage
Expand Down Expand Up @@ -146,6 +148,44 @@ def _mark_job_failed(
return None


def _handle_job_failure(
    span,
    project_id: int,
    organization_id: int,
    job_id: str,
    err: Exception,
    collection_job: CollectionJob | None,
    creation_request: CreationRequest | None,
    provider=None,
    result=None,
) -> None:
    """Record failure on span, clean up provider, mark job failed, and send failure callback.

    Any exception raised while sending the failure webhook is logged and
    swallowed so it cannot mask the original job failure (``err``), which the
    caller is expected to re-raise after this helper returns.
    """
    span.record_exception(err)
    span.set_status(trace.Status(trace.StatusCode.ERROR, str(err)))

    # Best-effort rollback of any partially-created provider resources.
    if provider is not None and result is not None:
        try:
            provider.delete(result)
        except Exception:
            logger.warning("[create_collection.execute_job] Provider cleanup failed")

    collection_job = _mark_job_failed(
        project_id=project_id,
        job_id=job_id,
        err=err,
        collection_job=collection_job,
    )

    if creation_request and creation_request.callback_url and collection_job:
        # Guard the webhook send: a callback failure must not replace the
        # original job failure as the raised exception.
        try:
            failure_payload = build_failure_payload(collection_job, str(err))
            webhook_secret = get_webhook_secret(project_id, organization_id)
            send_callback(
                str(creation_request.callback_url),
                failure_payload,
                webhook_secret=webhook_secret,
            )
        except Exception as cb_error:
            logger.error(
                "[create_collection.execute_job] Failure callback failed | "
                "{'collection_job_id': '%s', 'error': '%s'}",
                job_id,
                str(cb_error),
                exc_info=True,
            )
Comment on lines +179 to +186
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't let failure-webhook errors replace the real job failure.

send_callback() is unguarded here, unlike the _handle_job_failure pattern in backend/app/services/doctransform/job.py:107-145. If the webhook send fails, this helper raises a second exception and masks the original timeout/provider failure.

Suggested fix
     if creation_request and creation_request.callback_url and collection_job:
-        failure_payload = build_failure_payload(collection_job, str(err))
-        webhook_secret = get_webhook_secret(project_id, organization_id)
-        send_callback(
-            str(creation_request.callback_url),
-            failure_payload,
-            webhook_secret=webhook_secret,
-        )
+        try:
+            failure_payload = build_failure_payload(collection_job, str(err))
+            webhook_secret = get_webhook_secret(project_id, organization_id)
+            send_callback(
+                str(creation_request.callback_url),
+                failure_payload,
+                webhook_secret=webhook_secret,
+            )
+        except Exception as cb_error:
+            logger.error(
+                "[create_collection.execute_job] Failure callback failed | "
+                "{'collection_job_id': '%s', 'error': '%s'}",
+                job_id,
+                str(cb_error),
+                exc_info=True,
+            )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if creation_request and creation_request.callback_url and collection_job:
failure_payload = build_failure_payload(collection_job, str(err))
webhook_secret = get_webhook_secret(project_id, organization_id)
send_callback(
str(creation_request.callback_url),
failure_payload,
webhook_secret=webhook_secret,
)
if creation_request and creation_request.callback_url and collection_job:
try:
failure_payload = build_failure_payload(collection_job, str(err))
webhook_secret = get_webhook_secret(project_id, organization_id)
send_callback(
str(creation_request.callback_url),
failure_payload,
webhook_secret=webhook_secret,
)
except Exception as cb_error:
logger.error(
"[create_collection.execute_job] Failure callback failed | "
"{'collection_job_id': '%s', 'error': '%s'}",
job_id,
str(cb_error),
exc_info=True,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@backend/app/services/collections/create_collection.py` around lines 178 -
185, The current failure-path calls send_callback(...) directly which can raise
and mask the original job failure; wrap the send_callback call in a try/except
(mirroring the _handle_job_failure pattern) so any exception from send_callback
is caught and logged but not re-raised, ensuring the original error remains the
primary failure. Specifically, in the block that builds failure_payload with
build_failure_payload(...) and fetches webhook_secret via
get_webhook_secret(...), call send_callback(...) inside a try/except, log the
send error (including exception details and context like
creation_request.callback_url and collection_job id) and swallow the exception
so it cannot override the original failure handling.



def execute_job(
request: dict,
with_assistant: bool,
Expand Down Expand Up @@ -281,37 +321,42 @@ def execute_job(
webhook_secret=webhook_secret,
)

except (Timeout, SoftTimeLimitExceeded) as err:
timeout_err = TimeoutError("Task exceeded soft time limit")
logger.warning(
"[create_collection.execute_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}",
job_id,
str(timeout_err),
)
_handle_job_failure(
span,
project_id,
organization_id,
job_id,
timeout_err,
collection_job,
creation_request,
provider,
result,
)
raise
Comment thread
coderabbitai[bot] marked this conversation as resolved.

except Exception as err:
span.record_exception(err)
span.set_status(trace.Status(trace.StatusCode.ERROR, str(err)))
logger.error(
"[create_collection.execute_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}",
job_id,
str(err),
exc_info=True,
)

if provider is not None and result is not None:
try:
provider.delete(result)
except Exception:
logger.warning(
"[create_collection.execute_job] Provider cleanup failed"
)

collection_job = _mark_job_failed(
project_id=project_id,
job_id=job_id,
err=err,
collection_job=collection_job,
_handle_job_failure(
span,
project_id,
organization_id,
job_id,
err,
collection_job,
creation_request,
provider,
result,
)

if creation_request and creation_request.callback_url and collection_job:
failure_payload = build_failure_payload(collection_job, str(err))
webhook_secret = get_webhook_secret(project_id, organization_id)
send_callback(
str(creation_request.callback_url),
failure_payload,
webhook_secret=webhook_secret,
)
raise
22 changes: 22 additions & 0 deletions backend/app/services/collections/delete_collection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from uuid import UUID

from gevent import Timeout
from celery.exceptions import SoftTimeLimitExceeded
from opentelemetry import trace
from sqlmodel import Session
from asgi_correlation_id import correlation_id
Expand Down Expand Up @@ -245,6 +247,26 @@ def execute_job(
webhook_secret=webhook_secret,
)

except (Timeout, SoftTimeLimitExceeded) as err:
timeout_err = TimeoutError("Task exceeded soft time limit")
logger.warning(
"[delete_collection.execute_job] Collection Deletion Timed Out | "
"{'collection_id': '%s', 'job_id': '%s'}",
str(collection_uuid),
str(job_uuid),
)
span.record_exception(timeout_err)
span.set_status(trace.Status(trace.StatusCode.ERROR, str(timeout_err)))
_mark_job_failed_and_callback(
organization_id=organization_id,
project_id=project_id,
collection_id=collection_uuid,
job_id=job_uuid,
err=timeout_err,
callback_url=callback_url,
)
raise
Comment thread
coderabbitai[bot] marked this conversation as resolved.

except Exception as err:
span.record_exception(err)
span.set_status(trace.Status(trace.StatusCode.ERROR, str(err)))
Expand Down
Loading
Loading