From 9b24343e731aea227a2e97d5eaea4cd2ba4d5776 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Thu, 30 Apr 2026 01:50:16 +0530 Subject: [PATCH 1/9] gevent: implement soft time limit --- backend/app/celery/tasks/job_execution.py | 224 +++++++----------- backend/app/celery/utils.py | 24 ++ .../services/collections/create_collection.py | 96 ++++++-- .../services/collections/delete_collection.py | 23 ++ backend/app/services/doctransform/job.py | 97 +++++--- backend/app/services/llm/jobs.py | 69 ++++-- .../app/services/stt_evaluations/batch_job.py | 76 +++--- .../services/stt_evaluations/metric_job.py | 16 ++ .../app/services/tts_evaluations/batch_job.py | 114 +++++---- .../batch_result_processing.py | 16 ++ 10 files changed, 470 insertions(+), 285 deletions(-) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index 8dd20091a..32226677b 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -2,11 +2,10 @@ from asgi_correlation_id import correlation_id from celery import current_task -from opentelemetry import context as otel_context -from opentelemetry import trace -from opentelemetry.propagate import extract from app.celery.celery_app import celery_app +from app.celery.utils import gevent_timeout +from app.core.config import settings logger = logging.getLogger(__name__) @@ -16,210 +15,170 @@ def _set_trace(trace_id: str) -> None: logger.info(f"[_set_trace] Set correlation ID: {trace_id}") -def _extract_parent_context(task_instance) -> otel_context.Context: - """Extract OTel parent context from Celery headers if available.""" - headers = getattr(task_instance.request, "headers", None) or {} - carrier: dict[str, str] = {} - - if isinstance(headers, dict): - for key, value in headers.items(): - if isinstance(value, str): - carrier[str(key)] = value - - nested = headers.get("otel", {}) - if isinstance(nested, dict): - for key, value in nested.items(): - if isinstance(value, str): - carrier[str(key)] = value - - return extract(carrier) - - -def _run_with_otel_parent(task_instance, fn): - """Attach extracted parent context and execute function. - - When Celery auto-instrumentation is active, there is already a current - `run/...` span. Re-attaching extracted parent context here would make - service spans become siblings of `run/...` instead of children. - - We only attach extracted context as a fallback when no active span exists. - """ - current_ctx = trace.get_current_span().get_span_context() - if current_ctx and current_ctx.is_valid: - return fn() - - parent_ctx = _extract_parent_context(task_instance) - token = otel_context.attach(parent_ctx) - try: - return fn() - finally: - otel_context.detach(token) - - @celery_app.task(bind=True, queue="high_priority", priority=9) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_job") def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.llm.jobs import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="high_priority", priority=9) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_chain_job") def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.llm.jobs import execute_chain_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_chain_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_chain_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="high_priority", priority=9) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_response_job") def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.response.jobs import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_doctransform_job") def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.doctransform.job import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_create_collection_job") def run_create_collection_job( self, project_id: int, job_id: str, trace_id: str, **kwargs ): - from app.services.collections.create_collection import execute_job + from app.services.collections.create_collection import execute_setup_job + + _set_trace(trace_id) + return execute_setup_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ) + + +@celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job") +def run_collection_batch_job( + self, project_id: int, job_id: str, trace_id: str, **kwargs +): + from app.services.collections.create_collection import execute_batch_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_delete_collection_job") def run_delete_collection_job( self, project_id: int, job_id: str, trace_id: str, **kwargs ): from app.services.collections.delete_collection import execute_job _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_stt_batch_submission") def run_stt_batch_submission( self, project_id: int, job_id: str, trace_id: str, **kwargs ): from app.services.stt_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_stt_metric_computation") def run_stt_metric_computation( self, project_id: int, job_id: str, trace_id: str, **kwargs ): from app.services.stt_evaluations.metric_job import execute_metric_computation _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_metric_computation( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_metric_computation( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_tts_batch_submission") def run_tts_batch_submission( self, project_id: int, job_id: str, trace_id: str, **kwargs ): from app.services.tts_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) @celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_tts_result_processing") def run_tts_result_processing( self, project_id: int, job_id: str, trace_id: str, **kwargs ): @@ -228,13 +187,10 @@ def run_tts_result_processing( ) _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_tts_result_processing( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), + return execute_tts_result_processing( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, ) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 5ebbf624a..60af1905d 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -3,9 +3,11 @@ Business logic modules can use these functions without knowing Celery internals. """ import logging +import functools from typing import Any, Dict from celery.result import AsyncResult +from gevent import Timeout from opentelemetry.propagate import inject from app.celery.celery_app import celery_app @@ -211,3 +213,25 @@ def revoke_task(task_id: str, terminate: bool = False) -> bool: except Exception as e: logger.error(f"[revoke_task] Failed to revoke task {task_id}: {e}") return False + + +def gevent_timeout(seconds, task_name=None): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + name = task_name or func.__name__ + timeout = Timeout(seconds) + timeout.start() + try: + return func(*args, **kwargs) + except Timeout: + logger.error( + f"[{name}] Timed out after {seconds}s — args={args}, kwargs={kwargs}" + ) + raise + finally: + timeout.cancel() + + return wrapper + + return decorator diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 4acce89e1..2b52ac06a 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -4,6 +4,7 @@ from opentelemetry import trace from sqlmodel import Session +from gevent import Timeout from asgi_correlation_id import correlation_id from app.core.cloud import get_cloud_storage @@ -146,6 +147,44 @@ def _mark_job_failed( return None +def _handle_job_failure( + span, + project_id: int, + organization_id: int, + job_id: str, + err: Exception, + collection_job: CollectionJob | None, + creation_request: CreationRequest | None, + provider=None, + result=None, +) -> None: + """Record failure on span, clean up provider, mark job failed, and send failure callback.""" + span.record_exception(err) + span.set_status(trace.Status(trace.StatusCode.ERROR, str(err))) + + if provider is not None and result is not None: + try: + provider.delete(result) + except Exception: + logger.warning("[create_collection.execute_job] Provider cleanup failed") + + collection_job = _mark_job_failed( + project_id=project_id, + job_id=job_id, + err=err, + collection_job=collection_job, + ) + + if creation_request and creation_request.callback_url and collection_job: + failure_payload = build_failure_payload(collection_job, str(err)) + webhook_secret = get_webhook_secret(project_id, organization_id) + send_callback( + str(creation_request.callback_url), + failure_payload, + webhook_secret=webhook_secret, + ) + + def execute_job( request: dict, with_assistant: bool, @@ -281,37 +320,44 @@ def execute_job( webhook_secret=webhook_secret, ) + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_setup_job] Task exceeded soft time limit of {err.seconds}s" + ) + logger.error( + "[create_collection.execute_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", + job_id, + str(timeout_err), + ) + _handle_job_failure( + span, + project_id, + organization_id, + job_id, + timeout_err, + collection_job, + creation_request, + provider, + result, + ) + raise + except Exception as err: - span.record_exception(err) - span.set_status(trace.Status(trace.StatusCode.ERROR, str(err))) logger.error( "[create_collection.execute_job] Collection Creation Failed | {'collection_job_id': '%s', 'error': '%s'}", job_id, str(err), exc_info=True, ) - - if provider is not None and result is not None: - try: - provider.delete(result) - except Exception: - logger.warning( - "[create_collection.execute_job] Provider cleanup failed" - ) - - collection_job = _mark_job_failed( - project_id=project_id, - job_id=job_id, - err=err, - collection_job=collection_job, + _handle_job_failure( + span, + project_id, + organization_id, + job_id, + err, + collection_job, + creation_request, + provider, + result, ) - - if creation_request and creation_request.callback_url and collection_job: - failure_payload = build_failure_payload(collection_job, str(err)) - webhook_secret = get_webhook_secret(project_id, organization_id) - send_callback( - str(creation_request.callback_url), - failure_payload, - webhook_secret=webhook_secret, - ) raise diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index 1c8e8a497..1a5ed3353 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -1,6 +1,7 @@ import logging from uuid import UUID +from gevent import Timeout from opentelemetry import trace from sqlmodel import Session from asgi_correlation_id import correlation_id @@ -245,6 +246,28 @@ def execute_job( webhook_secret=webhook_secret, ) + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_job] Collection deletion exceeded soft time limit: {err}" + ) + logger.error( + "[delete_collection.execute_job] Collection Deletion Timed Out | " + "{'collection_id': '%s', 'job_id': '%s'}", + str(collection_uuid), + str(job_uuid), + ) + span.record_exception(timeout_err) + span.set_status(trace.Status(trace.StatusCode.ERROR, str(timeout_err))) + _mark_job_failed_and_callback( + organization_id=organization_id, + project_id=project_id, + collection_id=collection_uuid, + job_id=job_uuid, + err=timeout_err, + callback_url=callback_url, + ) + raise + except Exception as err: span.record_exception(err) span.set_status(trace.Status(trace.StatusCode.ERROR, str(err))) diff --git a/backend/app/services/doctransform/job.py b/backend/app/services/doctransform/job.py index 62ba9b240..dc64a1494 100644 --- a/backend/app/services/doctransform/job.py +++ b/backend/app/services/doctransform/job.py @@ -6,6 +6,7 @@ from uuid import uuid4, UUID from fastapi import UploadFile +from gevent import Timeout from tenacity import retry, wait_exponential, stop_after_attempt from sqlmodel import Session from asgi_correlation_id import correlation_id @@ -103,6 +104,46 @@ def build_failure_payload(job: DocTransformationJob, error_message: str): ) +def _handle_job_failure( + error_message: str, + job_uuid: UUID, + project_id: int, + callback_url: str | None, + webhook_secret: str | None, + job_for_payload: DocTransformationJob | None, + log_context: str = "", +) -> DocTransformationJob | None: + context_suffix = f" on {log_context}" if log_context else "" + try: + with Session(engine) as db: + job_crud = DocTransformationJobCrud(session=db, project_id=project_id) + job_for_payload = job_crud.update( + job_uuid, + DocTransformJobUpdate( + status=TransformationStatus.FAILED, error_message=error_message + ), + ) + except Exception as db_error: + logger.error( + "[doc_transform.execute_job] failed to persist FAILED status%s | job_id=%s | db_error=%s", + context_suffix, + job_uuid, + db_error, + ) + if callback_url and job_for_payload: + try: + failure_payload = build_failure_payload(job_for_payload, error_message) + send_callback(callback_url, failure_payload, webhook_secret=webhook_secret) + except Exception as cb_error: + logger.error( + "[doc_transform.execute_job] callback failed%s | job_id=%s | error=%s", + context_suffix, + job_uuid, + cb_error, + ) + return job_for_payload + + @retry(wait=wait_exponential(multiplier=5, min=5, max=10), stop=stop_after_attempt(3)) def execute_job( project_id: int, @@ -120,8 +161,9 @@ def execute_job( job_for_payload = None # keep latest job snapshot for payloads webhook_secret: str | None = None + job_uuid = UUID(job_id) + try: - job_uuid = UUID(job_id) source_uuid = UUID(source_document_id) if callback_url: @@ -232,6 +274,25 @@ def execute_job( if callback_url: send_callback(callback_url, success_payload, webhook_secret=webhook_secret) + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_job] Document transformation exceeded soft time limit: {err}" + ) + logger.error( + "[doc_transform.execute_job] Timed Out | job_id=%s", + job_uuid, + ) + _handle_job_failure( + str(timeout_err), + job_uuid, + project_id, + callback_url, + webhook_secret, + job_for_payload, + log_context="timeout", + ) + raise + except Exception as e: logger.error( "[doc_transform.execute_job] FAILED | job_id=%s | error=%s", @@ -239,37 +300,9 @@ def execute_job( e, exc_info=True, ) - - try: - with Session(engine) as db: - job_crud = DocTransformationJobCrud(session=db, project_id=project_id) - job_for_payload = job_crud.update( - job_uuid, - DocTransformJobUpdate( - status=TransformationStatus.FAILED, error_message=str(e) - ), - ) - except Exception as db_error: - logger.error( - "[doc_transform.execute_job] failed to persist FAILED status | job_id=%s | db_error=%s", - job_uuid, - db_error, - ) - - if callback_url and job_for_payload: - try: - failure_payload = build_failure_payload(job_for_payload, str(e)) - send_callback( - callback_url, failure_payload, webhook_secret=webhook_secret - ) - except Exception as cb_error: - logger.error( - "[doc_transform.execute_job] callback failed | job_id=%s | error=%s", - job_uuid, - cb_error, - ) - - # bubble up for caller/infra + _handle_job_failure( + str(e), job_uuid, project_id, callback_url, webhook_secret, job_for_payload + ) raise finally: if tmp_dir and tmp_dir.exists(): diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index e797040a2..ef31ce9a8 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -6,6 +6,7 @@ from asgi_correlation_id import correlation_id from fastapi import HTTPException +from gevent import Timeout from opentelemetry import trace from sqlmodel import Session @@ -186,6 +187,7 @@ def handle_job_error( callback_response: APIResponse, organization_id: int | None = None, project_id: int | None = None, + chain_id: UUID | None = None, ) -> dict: """Handle job failure uniformly — send callback and update DB.""" if callback_url: @@ -207,6 +209,20 @@ def handle_job_error( error_message=callback_response.error, ), ) + if chain_id: + try: + update_llm_chain_status( + session, + chain_id=chain_id, + status=ChainStatus.FAILED, + error=callback_response.error, + ) + except Exception as update_err: + logger.error( + f"[handle_job_error] Failed to update chain status: {update_err} | " + f"chain_id={chain_id}", + exc_info=True, + ) return callback_response.model_dump() @@ -820,6 +836,23 @@ def execute_job( project_id=project_id, ) + except Timeout: + logger.error( + f"[execute_job] LLM job timed out | job_id={job_uuid}, task_id={task_id}" + ) + callback_response = APIResponse.failure_response( + error="LLM job exceeded soft time limit", + metadata=request.request_metadata, + ) + handle_job_error( + job_uuid, + callback_url_str, + callback_response, + organization_id=organization_id, + project_id=project_id, + ) + raise + except Exception as e: callback_response = APIResponse.failure_response( error="Unexpected error occurred", @@ -933,28 +966,29 @@ def execute_chain_job( executor = ChainExecutor(chain=chain, context=context, request=request) return executor.run() + except Timeout as err: + logger.error( + f"[execute_chain_job] Chain job timed out | job_id={job_uuid}, task_id={task_id}" + ) + callback_response = APIResponse.failure_response( + error=f"[execute_chain_job] Chain job exceeded soft time limit: {err}", + metadata=request.request_metadata, + ) + handle_job_error( + job_uuid, + callback_url_str, + callback_response, + organization_id=organization_id, + project_id=project_id, + chain_id=chain_uuid, + ) + raise + except Exception as e: logger.error( f"[execute_chain_job] Failed: {e} | job_id={job_uuid}", exc_info=True, ) - - if chain_uuid: - try: - with Session(engine) as session: - update_llm_chain_status( - session, - chain_id=chain_uuid, - status=ChainStatus.FAILED, - error=str(e), - ) - except Exception as update_err: - logger.error( - f"[execute_chain_job] Failed to update chain status: {update_err} | " - f"chain_id={chain_uuid}", - exc_info=True, - ) - callback_response = APIResponse.failure_response( error="Unexpected error occurred", metadata=request.request_metadata, @@ -965,6 +999,7 @@ def execute_chain_job( callback_response, organization_id=organization_id, project_id=project_id, + chain_id=chain_uuid, ) finally: # Ensure task spans are pushed promptly so Sentry dashboards update faster. diff --git a/backend/app/services/stt_evaluations/batch_job.py b/backend/app/services/stt_evaluations/batch_job.py index 69648dc21..9586c5d33 100644 --- a/backend/app/services/stt_evaluations/batch_job.py +++ b/backend/app/services/stt_evaluations/batch_job.py @@ -2,6 +2,7 @@ import logging +from gevent import Timeout from sqlmodel import Session from app.core.db import engine @@ -46,39 +47,41 @@ def execute_batch_submission( ) with Session(engine) as session: - run = get_stt_run_by_id( - session=session, - run_id=run_id, - org_id=organization_id, - project_id=project_id, - ) - - if not run: - logger.error(f"[execute_batch_submission] Run not found | run_id: {run_id}") - return {"success": False, "error": "Run not found"} - - samples = get_samples_by_dataset_id( - session=session, - dataset_id=dataset_id, - org_id=organization_id, - project_id=project_id, - limit=run.total_items, - ) - - if not samples: - logger.error( - f"[execute_batch_submission] No samples found | " - f"run_id: {run_id}, dataset_id: {dataset_id}" - ) - update_stt_run( + try: + run = get_stt_run_by_id( session=session, run_id=run_id, - status="failed", - error_message="No samples found for dataset", + org_id=organization_id, + project_id=project_id, ) - return {"success": False, "error": "No samples found"} - try: + if not run: + logger.error( + f"[execute_batch_submission] Run not found | run_id: {run_id}" + ) + return {"success": False, "error": "Run not found"} + + samples = get_samples_by_dataset_id( + session=session, + dataset_id=dataset_id, + org_id=organization_id, + project_id=project_id, + limit=run.total_items, + ) + + if not samples: + logger.error( + f"[execute_batch_submission] No samples found | " + f"run_id: {run_id}, dataset_id: {dataset_id}" + ) + update_stt_run( + session=session, + run_id=run_id, + status="failed", + error_message="No samples found for dataset", + ) + return {"success": False, "error": "No samples found"} + batch_result = start_stt_evaluation_batch( session=session, run=run, @@ -95,6 +98,21 @@ def execute_batch_submission( return batch_result + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_batch_submission] STT batch submission exceeded soft time limit: {err}" + ) + logger.error( + f"[execute_batch_submission] STT batch submission timed out | run_id={run_id}" + ) + update_stt_run( + session=session, + run_id=run_id, + status="failed", + error_message=str(timeout_err), + ) + raise + except Exception as e: logger.error( f"[execute_batch_submission] Batch submission failed | " diff --git a/backend/app/services/stt_evaluations/metric_job.py b/backend/app/services/stt_evaluations/metric_job.py index 3d1e272a7..abec245a4 100644 --- a/backend/app/services/stt_evaluations/metric_job.py +++ b/backend/app/services/stt_evaluations/metric_job.py @@ -3,6 +3,7 @@ import logging from typing import Any +from gevent import Timeout from sqlalchemy import update from sqlmodel import Session, select @@ -153,6 +154,21 @@ def execute_metric_computation( "failed": failed_count, } + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_metric_computation] STT metric computation exceeded soft time limit: {err}" + ) + logger.error( + f"[execute_metric_computation] STT metric computation timed out | run_id={run_id}" + ) + update_stt_run( + session=session, + run_id=run_id, + status="failed", + error_message=str(timeout_err), + ) + raise + except Exception as e: logger.error( f"[execute_metric_computation] Failed | " diff --git a/backend/app/services/tts_evaluations/batch_job.py b/backend/app/services/tts_evaluations/batch_job.py index fc89500a9..40382b455 100644 --- a/backend/app/services/tts_evaluations/batch_job.py +++ b/backend/app/services/tts_evaluations/batch_job.py @@ -2,6 +2,7 @@ import logging +from gevent import Timeout from sqlmodel import Session from app.core.db import engine @@ -50,63 +51,65 @@ def execute_batch_submission( ) with Session(engine) as session: - run = get_tts_run_by_id( - session=session, - run_id=run_id, - org_id=organization_id, - project_id=project_id, - ) - - if not run: - logger.error(f"[execute_batch_submission] Run not found | run_id: {run_id}") - return {"success": False, "error": "Run not found"} - - dataset = get_tts_dataset_by_id( - session=session, - dataset_id=dataset_id, - org_id=organization_id, - project_id=project_id, - ) - - if not dataset: - logger.error( - f"[execute_batch_submission] Dataset not found | " - f"run_id: {run_id}, dataset_id: {dataset_id}" - ) - update_tts_run( + try: + run = get_tts_run_by_id( session=session, run_id=run_id, - status="failed", - error_message="Dataset not found", + org_id=organization_id, + project_id=project_id, ) - return {"success": False, "error": "Dataset not found"} - sample_texts = get_sample_texts_from_dataset(session, dataset, project_id) + if not run: + logger.error( + f"[execute_batch_submission] Run not found | run_id: {run_id}" + ) + return {"success": False, "error": "Run not found"} - if not sample_texts: - logger.error( - f"[execute_batch_submission] No samples found | " - f"run_id: {run_id}, dataset_id: {dataset_id}" + dataset = get_tts_dataset_by_id( + session=session, + dataset_id=dataset_id, + org_id=organization_id, + project_id=project_id, ) - update_tts_run( + + if not dataset: + logger.error( + f"[execute_batch_submission] Dataset not found | " + f"run_id: {run_id}, dataset_id: {dataset_id}" + ) + update_tts_run( + session=session, + run_id=run_id, + status="failed", + error_message="Dataset not found", + ) + return {"success": False, "error": "Dataset not found"} + + sample_texts = get_sample_texts_from_dataset(session, dataset, project_id) + + if not sample_texts: + logger.error( + f"[execute_batch_submission] No samples found | " + f"run_id: {run_id}, dataset_id: {dataset_id}" + ) + update_tts_run( + session=session, + run_id=run_id, + status="failed", + error_message="No samples found for dataset", + ) + return {"success": False, "error": "No samples found"} + + # Create result records for each sample text and model + results = create_tts_results( session=session, - run_id=run_id, - status="failed", - error_message="No samples found for dataset", + sample_texts=sample_texts, + evaluation_run_id=run.id, + org_id=organization_id, + project_id=project_id, + models=models, ) - return {"success": False, "error": "No samples found"} - - # Create result records for each sample text and model - results = create_tts_results( - session=session, - sample_texts=sample_texts, - evaluation_run_id=run.id, - org_id=organization_id, - project_id=project_id, - models=models, - ) - try: batch_result = start_tts_evaluation_batch( session=session, run=run, @@ -123,6 +126,21 @@ def execute_batch_submission( return batch_result + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_batch_submission] TTS batch submission exceeded soft time limit: {err}" + ) + logger.error( + f"[execute_batch_submission] TTS batch submission timed out | run_id={run_id}" + ) + update_tts_run( + session=session, + run_id=run_id, + status="failed", + error_message=str(timeout_err), + ) + raise + except Exception as e: logger.error( f"[execute_batch_submission] Batch submission failed | " diff --git a/backend/app/services/tts_evaluations/batch_result_processing.py b/backend/app/services/tts_evaluations/batch_result_processing.py index 390945514..5dd673a37 100644 --- a/backend/app/services/tts_evaluations/batch_result_processing.py +++ b/backend/app/services/tts_evaluations/batch_result_processing.py @@ -9,6 +9,7 @@ import uuid from typing import Any +from gevent import Timeout from sqlmodel import Session, select from app.core.batch import BATCH_KEY, GeminiBatchProvider, GeminiClient @@ -239,6 +240,21 @@ def execute_tts_result_processing( "run_status": final_status, } + except Timeout as err: + timeout_err = TimeoutError( + f"[execute_tts_result_processing] TTS result processing exceeded soft time limit: {err}" + ) + logger.error( + f"[execute_tts_result_processing] TTS result processing timed out | run_id={evaluation_run_id}" + ) + update_tts_run( + session=session, + run_id=evaluation_run_id, + status="failed", + error_message=str(timeout_err), + ) + raise + except Exception as e: logger.error( f"[execute_tts_result_processing] Failed | " From 8788ee8cc7db1cb52b6e2dabcab74da3439746d4 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Thu, 30 Apr 2026 09:28:26 +0530 Subject: [PATCH 2/9] coderabbit review and test cases --- backend/app/celery/tasks/job_execution.py | 10 ++++++---- backend/app/celery/utils.py | 4 +--- backend/app/services/collections/create_collection.py | 2 +- backend/app/services/llm/jobs.py | 2 +- backend/app/tests/services/llm/test_jobs.py | 8 ++------ 5 files changed, 11 insertions(+), 15 deletions(-) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index 32226677b..548cd6ab0 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -1,5 +1,7 @@ import logging +from typing import Any +import celery from asgi_correlation_id import correlation_id from celery import current_task @@ -80,10 +82,10 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw def run_create_collection_job( self, project_id: int, job_id: str, trace_id: str, **kwargs ): - from app.services.collections.create_collection import execute_setup_job + from app.services.collections.create_collection import execute_job _set_trace(trace_id) - return execute_setup_job( + return execute_job( project_id=project_id, job_id=job_id, task_id=current_task.request.id, @@ -95,8 +97,8 @@ def run_create_collection_job( @celery_app.task(bind=True, queue="low_priority", priority=1) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job") def run_collection_batch_job( - self, project_id: int, job_id: str, trace_id: str, **kwargs -): + self: celery.Task, project_id: int, job_id: str, trace_id: str, **kwargs: Any +) -> None: from app.services.collections.create_collection import execute_batch_job _set_trace(trace_id) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 60af1905d..5ec358801 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -225,9 +225,7 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Timeout: - logger.error( - f"[{name}] Timed out after {seconds}s — args={args}, kwargs={kwargs}" - ) + logger.error(f"[{name}] Timed out after {seconds}s") raise finally: timeout.cancel() diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 2b52ac06a..5ac68126c 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -322,7 +322,7 @@ def execute_job( except Timeout as err: timeout_err = TimeoutError( - f"[execute_setup_job] Task exceeded soft time limit of {err.seconds}s" + f"[execute_job] Task exceeded soft time limit of {err.seconds}s" ) logger.error( "[create_collection.execute_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index ef31ce9a8..d97f25be9 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -971,7 +971,7 @@ def execute_chain_job( f"[execute_chain_job] Chain job timed out | job_id={job_uuid}, task_id={task_id}" ) callback_response = APIResponse.failure_response( - error=f"[execute_chain_job] Chain job exceeded soft time limit: {err}", + error="Chain job exceeded soft time limit", metadata=request.request_metadata, ) handle_job_error( diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index e838796ad..8912542f7 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -1334,9 +1334,6 @@ def test_chain_status_updated_to_failed_on_error(self, chain_request_data): patch("app.services.llm.jobs.Session") as mock_session, patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, patch("app.services.llm.jobs.get_provider_credential") as mock_creds, - patch( - "app.services.llm.jobs.update_llm_chain_status" - ) as mock_update_status, patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, patch( "app.services.llm.chain.chain.LLMChain", @@ -1357,10 +1354,9 @@ def test_chain_status_updated_to_failed_on_error(self, chain_request_data): result = self._execute_chain_job(chain_request_data) - mock_update_status.assert_called_once() - _, kwargs = mock_update_status.call_args + mock_handle_error.assert_called_once() + _, kwargs = mock_handle_error.call_args assert kwargs["chain_id"] == chain_id - assert kwargs["status"].value == "FAILED" class TestResolveConfigBlob: From 7def1a43d609d2bf16a0f4b9434e0784d8b35837 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Mon, 4 May 2026 09:48:48 +0530 Subject: [PATCH 3/9] adding test cases and prefork --- backend/app/celery/tasks/job_execution.py | 208 ++++++++++++------ .../services/collections/create_collection.py | 7 +- .../services/collections/delete_collection.py | 7 +- backend/app/services/doctransform/job.py | 7 +- backend/app/services/doctransform/registry.py | 4 + .../doctransform/zerox_transformer.py | 4 + backend/app/services/llm/jobs.py | 9 +- .../app/services/stt_evaluations/batch_job.py | 7 +- .../services/stt_evaluations/metric_job.py | 7 +- .../app/services/tts_evaluations/batch_job.py | 7 +- .../batch_result_processing.py | 7 +- .../collections/test_create_collection.py | 93 ++++++++ .../collections/test_delete_collection.py | 110 +++++++++ .../test_job/test_execute_job_errors.py | 102 +++++++++ backend/app/tests/services/llm/test_jobs.py | 103 +++++++++ 15 files changed, 584 insertions(+), 98 deletions(-) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index 548cd6ab0..aede450e2 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -4,6 +4,9 @@ import celery from asgi_correlation_id import correlation_id from celery import current_task +from opentelemetry import context as otel_context +from opentelemetry import trace +from opentelemetry.propagate import extract from app.celery.celery_app import celery_app from app.celery.utils import gevent_timeout @@ -17,18 +20,61 @@ def _set_trace(trace_id: str) -> None: logger.info(f"[_set_trace] Set correlation ID: {trace_id}") +def _extract_parent_context(task_instance) -> otel_context.Context: + """Extract OTel parent context from Celery headers if available.""" + headers = getattr(task_instance.request, "headers", None) or {} + carrier: dict[str, str] = {} + + if isinstance(headers, dict): + for key, value in headers.items(): + if isinstance(value, str): + carrier[str(key)] = value + + nested = headers.get("otel", {}) + if isinstance(nested, dict): + for key, value in nested.items(): + if isinstance(value, str): + carrier[str(key)] = value + + return extract(carrier) + + +def _run_with_otel_parent(task_instance, fn): + """Attach extracted parent context and execute function. + + When Celery auto-instrumentation is active, there is already a current + `run/...` span. Re-attaching extracted parent context here would make + service spans become siblings of `run/...` instead of children. + + We only attach extracted context as a fallback when no active span exists. + """ + current_ctx = trace.get_current_span().get_span_context() + if current_ctx and current_ctx.is_valid: + return fn() + + parent_ctx = _extract_parent_context(task_instance) + token = otel_context.attach(parent_ctx) + try: + return fn() + finally: + otel_context.detach(token) + + @celery_app.task(bind=True, queue="high_priority", priority=9) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_llm_job") def run_llm_job(self, project_id: int, job_id: str, trace_id: str, **kwargs): from app.services.llm.jobs import execute_job _set_trace(trace_id) - return execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -38,12 +84,15 @@ def run_llm_chain_job(self, project_id: int, job_id: str, trace_id: str, **kwarg from app.services.llm.jobs import execute_chain_job _set_trace(trace_id) - return execute_chain_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_chain_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -53,12 +102,15 @@ def run_response_job(self, project_id: int, job_id: str, trace_id: str, **kwargs from app.services.response.jobs import execute_job _set_trace(trace_id) - return execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -68,12 +120,15 @@ def run_doctransform_job(self, project_id: int, job_id: str, trace_id: str, **kw from app.services.doctransform.job import execute_job _set_trace(trace_id) - return execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -85,12 +140,15 @@ def run_create_collection_job( from app.services.collections.create_collection import execute_job _set_trace(trace_id) - return execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -102,12 +160,15 @@ def run_collection_batch_job( from app.services.collections.create_collection import execute_batch_job _set_trace(trace_id) - return execute_batch_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_batch_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -119,12 +180,15 @@ def run_delete_collection_job( from app.services.collections.delete_collection import execute_job _set_trace(trace_id) - return execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_job( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -136,12 +200,15 @@ def run_stt_batch_submission( from app.services.stt_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -153,12 +220,15 @@ def run_stt_metric_computation( from app.services.stt_evaluations.metric_job import execute_metric_computation _set_trace(trace_id) - return execute_metric_computation( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_metric_computation( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -170,12 +240,15 @@ def run_tts_batch_submission( from app.services.tts_evaluations.batch_job import execute_batch_submission _set_trace(trace_id) - return execute_batch_submission( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_batch_submission( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) @@ -189,10 +262,13 @@ def run_tts_result_processing( ) _set_trace(trace_id) - return execute_tts_result_processing( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, + return _run_with_otel_parent( + self, + lambda: execute_tts_result_processing( + project_id=project_id, + job_id=job_id, + task_id=current_task.request.id, + task_instance=self, + **kwargs, + ), ) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 5ac68126c..ba673d6a7 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -4,6 +4,7 @@ from opentelemetry import trace from sqlmodel import Session +from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout from asgi_correlation_id import correlation_id @@ -320,10 +321,8 @@ def execute_job( webhook_secret=webhook_secret, ) - except Timeout as err: - timeout_err = TimeoutError( - f"[execute_job] Task exceeded soft time limit of {err.seconds}s" - ) + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError(f"Task exceeded soft time limit") logger.error( "[create_collection.execute_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", job_id, diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index 1a5ed3353..68a3cea89 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -2,6 +2,7 @@ from uuid import UUID from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded from opentelemetry import trace from sqlmodel import Session from asgi_correlation_id import correlation_id @@ -246,10 +247,8 @@ def execute_job( webhook_secret=webhook_secret, ) - except Timeout as err: - timeout_err = TimeoutError( - f"[execute_job] Collection deletion exceeded soft time limit: {err}" - ) + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError(f"Task exceeded soft time limit") logger.error( "[delete_collection.execute_job] Collection Deletion Timed Out | " "{'collection_id': '%s', 'job_id': '%s'}", diff --git a/backend/app/services/doctransform/job.py b/backend/app/services/doctransform/job.py index dc64a1494..4f9fbb9b6 100644 --- a/backend/app/services/doctransform/job.py +++ b/backend/app/services/doctransform/job.py @@ -7,6 +7,7 @@ from fastapi import UploadFile from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded from tenacity import retry, wait_exponential, stop_after_attempt from sqlmodel import Session from asgi_correlation_id import correlation_id @@ -274,10 +275,8 @@ def execute_job( if callback_url: send_callback(callback_url, success_payload, webhook_secret=webhook_secret) - except Timeout as err: - timeout_err = TimeoutError( - f"[execute_job] Document transformation exceeded soft time limit: {err}" - ) + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError(f"Task exceeded soft time limit") logger.error( "[doc_transform.execute_job] Timed Out | job_id=%s", job_uuid, diff --git a/backend/app/services/doctransform/registry.py b/backend/app/services/doctransform/registry.py index 29c2fc251..46446332c 100644 --- a/backend/app/services/doctransform/registry.py +++ b/backend/app/services/doctransform/registry.py @@ -1,4 +1,6 @@ from pathlib import Path +from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded from app.services.doctransform.transformer import Transformer from app.services.doctransform.zerox_transformer import ZeroxTransformer @@ -124,6 +126,8 @@ def convert_document( transformer = transformer_cls() try: return transformer.transform(input_path, output_path) + except (Timeout, SoftTimeLimitExceeded): + raise except Exception as e: raise TransformationError( f"Error applying transformer '{transformer_name}': {e}" diff --git a/backend/app/services/doctransform/zerox_transformer.py b/backend/app/services/doctransform/zerox_transformer.py index 321a6ba65..bbb1263a2 100644 --- a/backend/app/services/doctransform/zerox_transformer.py +++ b/backend/app/services/doctransform/zerox_transformer.py @@ -3,6 +3,8 @@ from asyncio import Runner, wait_for from pathlib import Path from pyzerox import zerox +from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded from app.services.doctransform.transformer import Transformer @@ -33,6 +35,8 @@ def transform(self, input_path: Path, output_path: Path) -> Path: timeout=10 * 60, # 10 minutes ) ) + except (Timeout, SoftTimeLimitExceeded): + raise except TimeoutError: logger.error( f"ZeroxTransformer timed out for {input_path} (model={self.model})" diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index d97f25be9..39dc84b33 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -6,6 +6,7 @@ from asgi_correlation_id import correlation_id from fastapi import HTTPException +from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout from opentelemetry import trace from sqlmodel import Session @@ -836,12 +837,12 @@ def execute_job( project_id=project_id, ) - except Timeout: + except (Timeout, SoftTimeLimitExceeded): logger.error( f"[execute_job] LLM job timed out | job_id={job_uuid}, task_id={task_id}" ) callback_response = APIResponse.failure_response( - error="LLM job exceeded soft time limit", + error="Task exceeded soft time limit", metadata=request.request_metadata, ) handle_job_error( @@ -966,12 +967,12 @@ def execute_chain_job( executor = ChainExecutor(chain=chain, context=context, request=request) return executor.run() - except Timeout as err: + except (Timeout, SoftTimeLimitExceeded) as err: logger.error( f"[execute_chain_job] Chain job timed out | job_id={job_uuid}, task_id={task_id}" ) callback_response = APIResponse.failure_response( - error="Chain job exceeded soft time limit", + error="Task exceeded soft time limit", metadata=request.request_metadata, ) handle_job_error( diff --git a/backend/app/services/stt_evaluations/batch_job.py b/backend/app/services/stt_evaluations/batch_job.py index 9586c5d33..08e09dac1 100644 --- a/backend/app/services/stt_evaluations/batch_job.py +++ b/backend/app/services/stt_evaluations/batch_job.py @@ -3,6 +3,7 @@ import logging from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded from sqlmodel import Session from app.core.db import engine @@ -98,10 +99,8 @@ def execute_batch_submission( return batch_result - except Timeout as err: - timeout_err = TimeoutError( - f"[execute_batch_submission] STT batch submission exceeded soft time limit: {err}" - ) + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError(f"Task exceeded soft time limit") logger.error( f"[execute_batch_submission] STT batch submission timed out | run_id={run_id}" ) diff --git a/backend/app/services/stt_evaluations/metric_job.py b/backend/app/services/stt_evaluations/metric_job.py index abec245a4..effea895d 100644 --- a/backend/app/services/stt_evaluations/metric_job.py +++ b/backend/app/services/stt_evaluations/metric_job.py @@ -4,6 +4,7 @@ from typing import Any from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded from sqlalchemy import update from sqlmodel import Session, select @@ -154,10 +155,8 @@ def execute_metric_computation( "failed": failed_count, } - except Timeout as err: - timeout_err = TimeoutError( - f"[execute_metric_computation] STT metric computation exceeded soft time limit: {err}" - ) + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError(f"Task exceeded soft time limit") logger.error( f"[execute_metric_computation] STT metric computation timed out | run_id={run_id}" ) diff --git a/backend/app/services/tts_evaluations/batch_job.py b/backend/app/services/tts_evaluations/batch_job.py index 40382b455..51e060ea3 100644 --- a/backend/app/services/tts_evaluations/batch_job.py +++ b/backend/app/services/tts_evaluations/batch_job.py @@ -3,6 +3,7 @@ import logging from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded from sqlmodel import Session from app.core.db import engine @@ -126,10 +127,8 @@ def execute_batch_submission( return batch_result - except Timeout as err: - timeout_err = TimeoutError( - f"[execute_batch_submission] TTS batch submission exceeded soft time limit: {err}" - ) + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError(f"Task exceeded soft time limit") logger.error( f"[execute_batch_submission] TTS batch submission timed out | run_id={run_id}" ) diff --git a/backend/app/services/tts_evaluations/batch_result_processing.py b/backend/app/services/tts_evaluations/batch_result_processing.py index 5dd673a37..5394f2934 100644 --- a/backend/app/services/tts_evaluations/batch_result_processing.py +++ b/backend/app/services/tts_evaluations/batch_result_processing.py @@ -10,6 +10,7 @@ from typing import Any from gevent import Timeout +from celery.exceptions import SoftTimeLimitExceeded from sqlmodel import Session, select from app.core.batch import BATCH_KEY, GeminiBatchProvider, GeminiClient @@ -240,10 +241,8 @@ def execute_tts_result_processing( "run_status": final_status, } - except Timeout as err: - timeout_err = TimeoutError( - f"[execute_tts_result_processing] TTS result processing exceeded soft time limit: {err}" - ) + except (Timeout, SoftTimeLimitExceeded) as err: + timeout_err = TimeoutError(f"Task exceeded soft time limit") logger.error( f"[execute_tts_result_processing] TTS result processing timed out | run_id={evaluation_run_id}" ) diff --git a/backend/app/tests/services/collections/test_create_collection.py b/backend/app/tests/services/collections/test_create_collection.py index a2ffdfa5d..d8ca2829b 100644 --- a/backend/app/tests/services/collections/test_create_collection.py +++ b/backend/app/tests/services/collections/test_create_collection.py @@ -6,6 +6,8 @@ import uuid from uuid import UUID, uuid4 +from gevent import Timeout + import pytest from moto import mock_aws from sqlmodel import Session @@ -464,3 +466,94 @@ def test_execute_job_failure_flow_callback_job_and_marks_failed( assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED assert payload_arg["data"]["collection"] is None assert uuid.UUID(payload_arg["data"]["job_id"]) == job.id + + +@patch("app.services.collections.create_collection.get_llm_provider") +def test_execute_job_timeout_marks_job_failed( + mock_get_llm_provider: MagicMock, db: Session +) -> None: + project = get_project(db) + + job = get_collection_job( + db, + project, + job_id=uuid4(), + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, + ) + + mock_provider = MagicMock() + mock_provider.create.side_effect = Timeout(300) + mock_get_llm_provider.return_value = mock_provider + + req = CreationRequest(documents=[], callback_url=None, provider="openai") + + with patch("app.services.collections.create_collection.Session") as SessionCtor: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + with pytest.raises(Timeout): + execute_job( + request=req.model_dump(), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + with_assistant=False, + job_id=str(job.id), + task_instance=None, + ) + + updated_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert updated_job.status == CollectionJobStatus.FAILED + assert "soft time limit" in (updated_job.error_message or "") + + +@patch("app.services.collections.create_collection.get_llm_provider") +@patch("app.services.collections.create_collection.send_callback") +def test_execute_job_timeout_sends_failure_callback( + mock_send_callback: MagicMock, + mock_get_llm_provider: MagicMock, + db: Session, +) -> None: + project = get_project(db) + callback_url = "https://example.com/collections/timeout" + + job = get_collection_job( + db, + project, + job_id=uuid4(), + action_type=CollectionActionType.CREATE, + status=CollectionJobStatus.PENDING, + collection_id=None, + ) + + mock_provider = MagicMock() + mock_provider.create.side_effect = Timeout(300) + mock_get_llm_provider.return_value = mock_provider + + req = CreationRequest(documents=[], callback_url=callback_url, provider="openai") + + with patch("app.services.collections.create_collection.Session") as SessionCtor: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + with pytest.raises(Timeout): + execute_job( + request=req.model_dump(), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + with_assistant=False, + job_id=str(job.id), + task_instance=None, + ) + + mock_send_callback.assert_called_once() + cb_url_arg, payload_arg = mock_send_callback.call_args.args + assert str(cb_url_arg) == callback_url + assert payload_arg["success"] is False + assert "soft time limit" in (payload_arg["error"] or "") + assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED + assert payload_arg["data"]["collection"] is None + assert uuid.UUID(payload_arg["data"]["job_id"]) == job.id diff --git a/backend/app/tests/services/collections/test_delete_collection.py b/backend/app/tests/services/collections/test_delete_collection.py index 6c7bcd91b..42ee8360f 100644 --- a/backend/app/tests/services/collections/test_delete_collection.py +++ b/backend/app/tests/services/collections/test_delete_collection.py @@ -1,6 +1,8 @@ from unittest.mock import patch, MagicMock from uuid import uuid4, UUID +from gevent import Timeout + import pytest from sqlmodel import Session @@ -499,3 +501,111 @@ def test_execute_job_provider_factory_failure_marks_job_failed( collection_crud_instance.delete_by_id.assert_not_called() mock_get_llm_provider.assert_called_once() + + +@patch("app.services.collections.delete_collection.get_llm_provider") +def test_execute_job_timeout_marks_job_failed( + mock_get_llm_provider: MagicMock, db +) -> None: + project = get_project(db) + + collection = get_vector_store_collection(db, project, vector_store_id="vs_timeout") + + job = get_collection_job( + db, + project, + action_type=CollectionActionType.DELETE, + status=CollectionJobStatus.PENDING, + collection_id=collection.id, + ) + + mock_provider = MagicMock() + mock_provider.delete.side_effect = Timeout(300) + mock_get_llm_provider.return_value = mock_provider + + with patch( + "app.services.collections.delete_collection.Session" + ) as SessionCtor, patch( + "app.services.collections.delete_collection.CollectionCrud" + ) as MockCollectionCrud: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + MockCollectionCrud.return_value.read_one.return_value = collection + + req = DeletionRequest(collection_id=collection.id) + + with pytest.raises(Timeout): + execute_job( + request=req.model_dump(mode="json"), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + collection_id=str(collection.id), + task_instance=None, + ) + + failed_job = CollectionJobCrud(db, project.id).read_one(job.id) + assert failed_job.status == CollectionJobStatus.FAILED + assert "soft time limit" in (failed_job.error_message or "") + + MockCollectionCrud.return_value.delete_by_id.assert_not_called() + + +@patch("app.services.collections.delete_collection.get_llm_provider") +@patch("app.services.collections.delete_collection.send_callback") +def test_execute_job_timeout_sends_failure_callback( + mock_send_callback: MagicMock, + mock_get_llm_provider: MagicMock, + db, +) -> None: + project = get_project(db) + callback_url = "https://example.com/collections/delete-timeout" + + collection = get_vector_store_collection( + db, project, vector_store_id="vs_timeout_cb" + ) + + job = get_collection_job( + db, + project, + action_type=CollectionActionType.DELETE, + status=CollectionJobStatus.PENDING, + collection_id=collection.id, + ) + + mock_provider = MagicMock() + mock_provider.delete.side_effect = Timeout(300) + mock_get_llm_provider.return_value = mock_provider + + with patch( + "app.services.collections.delete_collection.Session" + ) as SessionCtor, patch( + "app.services.collections.delete_collection.CollectionCrud" + ) as MockCollectionCrud: + SessionCtor.return_value.__enter__.return_value = db + SessionCtor.return_value.__exit__.return_value = False + + MockCollectionCrud.return_value.read_one.return_value = collection + + req = DeletionRequest(collection_id=collection.id, callback_url=callback_url) + + with pytest.raises(Timeout): + execute_job( + request=req.model_dump(mode="json"), + project_id=project.id, + organization_id=project.organization_id, + task_id=str(uuid4()), + job_id=str(job.id), + collection_id=str(collection.id), + task_instance=None, + ) + + mock_send_callback.assert_called_once() + cb_url_arg, payload_arg = mock_send_callback.call_args.args + assert str(cb_url_arg) == callback_url + assert payload_arg["success"] is False + assert "soft time limit" in (payload_arg["error"] or "") + assert payload_arg["data"]["status"] == CollectionJobStatus.FAILED + assert UUID(payload_arg["data"]["job_id"]) == job.id diff --git a/backend/app/tests/services/doctransformer/test_job/test_execute_job_errors.py b/backend/app/tests/services/doctransformer/test_job/test_execute_job_errors.py index 24da19cbf..a8f0bf63e 100644 --- a/backend/app/tests/services/doctransformer/test_job/test_execute_job_errors.py +++ b/backend/app/tests/services/doctransformer/test_job/test_execute_job_errors.py @@ -4,11 +4,13 @@ from unittest.mock import patch import pytest +from gevent import Timeout from moto import mock_aws from sqlmodel import Session from app.crud import DocTransformationJobCrud from app.models import Document, Project, TransformationStatus, DocTransformJobCreate +from app.services.doctransform.job import execute_job from app.tests.services.doctransformer.test_job.utils import ( DocTransformTestBase, MockTestTransformer, @@ -227,3 +229,103 @@ def test_execute_job_database_error_during_completion( db.refresh(job) assert job.status == TransformationStatus.FAILED assert "Database error during document creation" in job.error_message + + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_execute_job_timeout_marks_job_failed( + self, + db: Session, + test_document: Tuple[Document, Project], + ) -> None: + """Test that a gevent Timeout marks the job FAILED with a soft-time-limit message.""" + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job_crud = DocTransformationJobCrud(session=db, project_id=project.id) + job = job_crud.create(DocTransformJobCreate(source_document_id=document.id)) + + with patch( + "app.services.doctransform.job.Session" + ) as mock_session_class, patch( + "app.services.doctransform.job.get_cloud_storage" + ) as mock_storage_class, patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ): + mock_session_class.return_value.__enter__.return_value = db + mock_session_class.return_value.__exit__.return_value = None + + mock_storage = mock_storage_class.return_value + mock_storage.stream.side_effect = Timeout(300) + + with pytest.raises(Timeout): + execute_job.__wrapped__( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url=None, + task_instance=None, + ) + + db.refresh(job) + assert job.status == TransformationStatus.FAILED + assert "soft time limit" in job.error_message + + @mock_aws + @pytest.mark.usefixtures("aws_credentials") + def test_execute_job_timeout_sends_failure_callback( + self, + db: Session, + test_document: Tuple[Document, Project], + ) -> None: + """Test that a gevent Timeout sends a failure callback when callback_url is set.""" + document, project = test_document + aws = self.setup_aws_s3() + self.create_s3_document_content(aws, document) + + job_crud = DocTransformationJobCrud(session=db, project_id=project.id) + job = job_crud.create(DocTransformJobCreate(source_document_id=document.id)) + + callback_url = "https://example.com/doctransform/timeout" + + with patch( + "app.services.doctransform.job.Session" + ) as mock_session_class, patch( + "app.services.doctransform.job.get_cloud_storage" + ) as mock_storage_class, patch( + "app.services.doctransform.registry.TRANSFORMERS", + {"test": MockTestTransformer}, + ), patch( + "app.services.doctransform.job.send_callback" + ) as mock_send_callback: + mock_session_class.return_value.__enter__.return_value = db + mock_session_class.return_value.__exit__.return_value = None + + mock_storage = mock_storage_class.return_value + mock_storage.stream.side_effect = Timeout(300) + + with pytest.raises(Timeout): + execute_job.__wrapped__( + project_id=project.id, + job_id=str(job.id), + source_document_id=str(document.id), + transformer_name="test", + target_format="markdown", + task_id=str(uuid4()), + callback_url=callback_url, + task_instance=None, + ) + + db.refresh(job) + assert job.status == TransformationStatus.FAILED + assert "soft time limit" in job.error_message + + mock_send_callback.assert_called_once() + cb_url_arg, payload_arg = mock_send_callback.call_args.args + assert cb_url_arg == callback_url + assert payload_arg["success"] is False + assert "soft time limit" in (payload_arg["error"] or "") diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index 8912542f7..c4608aae8 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -2,6 +2,8 @@ from unittest.mock import patch, MagicMock from uuid import UUID, uuid4 +from gevent import Timeout + from fastapi import HTTPException from sqlmodel import Session, select @@ -1163,6 +1165,36 @@ def test_execute_job_continues_when_no_validator_configs_resolved( env["provider"].execute.assert_called_once() mock_guardrails.assert_not_called() + def test_timeout_reraises_and_marks_job_failed( + self, db, job_env, job_for_execution, request_data + ): + env = job_env + env["provider"].execute.side_effect = Timeout(300) + + with pytest.raises(Timeout): + self._execute_job(job_for_execution, db, request_data) + + db.refresh(job_for_execution) + assert job_for_execution.status == JobStatus.FAILED + + def test_timeout_with_callback_sends_failure_payload( + self, db, job_env, job_for_execution, request_data + ): + env = job_env + request_data["callback_url"] = "https://example.com/callback" + env["provider"].execute.side_effect = Timeout(300) + + with pytest.raises(Timeout): + self._execute_job(job_for_execution, db, request_data) + + env["send_callback"].assert_called_once() + payload = env["send_callback"].call_args.kwargs["data"] + assert payload["success"] is False + assert "soft time limit" in (payload["error"] or "") + + db.refresh(job_for_execution) + assert job_for_execution.status == JobStatus.FAILED + class TestStartChainJob: """Test cases for the start_chain_job function.""" @@ -1358,6 +1390,77 @@ def test_chain_status_updated_to_failed_on_error(self, chain_request_data): _, kwargs = mock_handle_error.call_args assert kwargs["chain_id"] == chain_id + def test_timeout_reraises_and_calls_handle_job_error(self, chain_request_data): + chain_id = uuid4() + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, + patch("app.services.llm.chain.chain.LLMChain"), + patch( + "app.services.llm.chain.executor.ChainExecutor" + ) as mock_executor_class, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = chain_id + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + mock_executor_class.return_value.run.side_effect = Timeout(300) + mock_handle_error.return_value = { + "success": False, + "error": "Chain job exceeded soft time limit", + } + + with pytest.raises(Timeout): + self._execute_chain_job(chain_request_data) + + mock_handle_error.assert_called_once() + _, kwargs = mock_handle_error.call_args + assert kwargs["chain_id"] == chain_id + callback_response = mock_handle_error.call_args.args[2] + assert callback_response.error == "Chain job exceeded soft time limit" + + def test_timeout_with_callback_sends_failure_payload(self, chain_request_data): + chain_id = uuid4() + chain_request_data["callback_url"] = "https://example.com/chain-callback" + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, + patch("app.services.llm.chain.chain.LLMChain"), + patch( + "app.services.llm.chain.executor.ChainExecutor" + ) as mock_executor_class, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = chain_id + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + mock_executor_class.return_value.run.side_effect = Timeout(300) + mock_handle_error.return_value = { + "success": False, + "error": "Chain job exceeded soft time limit", + } + + with pytest.raises(Timeout): + self._execute_chain_job(chain_request_data) + + mock_handle_error.assert_called_once() + call_args = mock_handle_error.call_args + assert call_args.args[1] == "https://example.com/chain-callback" + assert call_args.args[2].error == "Chain job exceeded soft time limit" + assert call_args.kwargs["chain_id"] == chain_id + class TestResolveConfigBlob: """Test suite for resolve_config_blob function.""" From 19c60527fbc3ec27c277ef19d6d73ba73a25e685 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Mon, 4 May 2026 10:36:00 +0530 Subject: [PATCH 4/9] fixing test case --- backend/app/tests/services/llm/test_jobs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index c4608aae8..dd946799c 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -1413,7 +1413,7 @@ def test_timeout_reraises_and_calls_handle_job_error(self, chain_request_data): mock_executor_class.return_value.run.side_effect = Timeout(300) mock_handle_error.return_value = { "success": False, - "error": "Chain job exceeded soft time limit", + "error": "Task exceeded soft time limit", } with pytest.raises(Timeout): @@ -1423,7 +1423,7 @@ def test_timeout_reraises_and_calls_handle_job_error(self, chain_request_data): _, kwargs = mock_handle_error.call_args assert kwargs["chain_id"] == chain_id callback_response = mock_handle_error.call_args.args[2] - assert callback_response.error == "Chain job exceeded soft time limit" + assert callback_response.error == "Task exceeded soft time limit" def test_timeout_with_callback_sends_failure_payload(self, chain_request_data): chain_id = uuid4() @@ -1449,7 +1449,7 @@ def test_timeout_with_callback_sends_failure_payload(self, chain_request_data): mock_executor_class.return_value.run.side_effect = Timeout(300) mock_handle_error.return_value = { "success": False, - "error": "Chain job exceeded soft time limit", + "error": "Task exceeded soft time limit", } with pytest.raises(Timeout): @@ -1458,7 +1458,7 @@ def test_timeout_with_callback_sends_failure_payload(self, chain_request_data): mock_handle_error.assert_called_once() call_args = mock_handle_error.call_args assert call_args.args[1] == "https://example.com/chain-callback" - assert call_args.args[2].error == "Chain job exceeded soft time limit" + assert call_args.args[2].error == "Task exceeded soft time limit" assert call_args.kwargs["chain_id"] == chain_id From fcdc906c2a934fc5d3733a59edc60e05201fc9e6 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Mon, 4 May 2026 11:05:00 +0530 Subject: [PATCH 5/9] increasing code coverage --- .../services/doctransformer/test_registry.py | 20 ++++++++++ .../stt_evaluations/test_metric_job.py | 38 +++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/backend/app/tests/services/doctransformer/test_registry.py b/backend/app/tests/services/doctransformer/test_registry.py index 0b1909f4e..8eeeebcf5 100644 --- a/backend/app/tests/services/doctransformer/test_registry.py +++ b/backend/app/tests/services/doctransformer/test_registry.py @@ -1,4 +1,6 @@ import pytest +from celery.exceptions import SoftTimeLimitExceeded +from gevent import Timeout from app.services.doctransform.registry import ( get_file_format, @@ -95,3 +97,21 @@ def transform(self, input_path, output_path): monkeypatch.setitem(TRANSFORMERS, "fail", FailingTransformer) with pytest.raises(TransformationError): convert_document(input_file, output_file, transformer_name="fail") + + # gevent Timeout propagates without being wrapped + class TimeoutTransformer: + def transform(self, input_path, output_path): + raise Timeout() + + monkeypatch.setitem(TRANSFORMERS, "timeout", TimeoutTransformer) + with pytest.raises(Timeout): + convert_document(input_file, output_file, transformer_name="timeout") + + # Celery SoftTimeLimitExceeded propagates without being wrapped + class SoftLimitTransformer: + def transform(self, input_path, output_path): + raise SoftTimeLimitExceeded() + + monkeypatch.setitem(TRANSFORMERS, "softlimit", SoftLimitTransformer) + with pytest.raises(SoftTimeLimitExceeded): + convert_document(input_file, output_file, transformer_name="softlimit") diff --git a/backend/app/tests/services/stt_evaluations/test_metric_job.py b/backend/app/tests/services/stt_evaluations/test_metric_job.py index 2bc8fef3c..727eda3dd 100644 --- a/backend/app/tests/services/stt_evaluations/test_metric_job.py +++ b/backend/app/tests/services/stt_evaluations/test_metric_job.py @@ -4,6 +4,10 @@ from typing import Any from unittest.mock import MagicMock, patch +import pytest +from celery.exceptions import SoftTimeLimitExceeded +from gevent import Timeout + from app.services.stt_evaluations.metric_job import execute_metric_computation @@ -239,3 +243,37 @@ def test_aggregate_scores_not_stored_when_nothing_scored( mock_update_run.assert_not_called() session.execute.assert_not_called() + + def test_gevent_timeout_marks_run_failed_and_reraises( + self, mock_session_cls, _mock_now, _mock_calc, mock_update_run + ) -> None: + """Test that a gevent Timeout marks the run as failed and re-raises.""" + session = mock_session_cls.return_value.__enter__.return_value + session.exec.side_effect = Timeout() + + with pytest.raises(Timeout): + execute_metric_computation(**BASE_KWARGS) + + mock_update_run.assert_called_once_with( + session=session, + run_id=BASE_KWARGS["run_id"], + status="failed", + error_message="Task exceeded soft time limit", + ) + + def test_soft_time_limit_exceeded_marks_run_failed_and_reraises( + self, mock_session_cls, _mock_now, _mock_calc, mock_update_run + ) -> None: + """Test that SoftTimeLimitExceeded marks the run as failed and re-raises.""" + session = mock_session_cls.return_value.__enter__.return_value + session.exec.side_effect = SoftTimeLimitExceeded() + + with pytest.raises(SoftTimeLimitExceeded): + execute_metric_computation(**BASE_KWARGS) + + mock_update_run.assert_called_once_with( + session=session, + run_id=BASE_KWARGS["run_id"], + status="failed", + error_message="Task exceeded soft time limit", + ) From 49e5e5f5b9b670f83dc910e6e0fce893dda8181b Mon Sep 17 00:00:00 2001 From: nishika26 Date: Mon, 4 May 2026 23:17:04 +0530 Subject: [PATCH 6/9] test cases and coderabbit pr comments --- backend/app/celery/tasks/job_execution.py | 4 +- backend/app/celery/utils.py | 24 ++++- .../services/collections/create_collection.py | 2 +- .../services/collections/delete_collection.py | 2 +- backend/app/services/doctransform/job.py | 2 +- backend/app/services/llm/jobs.py | 22 ++--- .../app/services/stt_evaluations/batch_job.py | 2 +- .../services/stt_evaluations/metric_job.py | 2 +- .../app/services/tts_evaluations/batch_job.py | 2 +- .../batch_result_processing.py | 2 +- backend/app/tests/services/llm/test_jobs.py | 92 +++++++++++++++++++ 11 files changed, 131 insertions(+), 25 deletions(-) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index aede450e2..4888c7f2e 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -157,12 +157,12 @@ def run_create_collection_job( def run_collection_batch_job( self: celery.Task, project_id: int, job_id: str, trace_id: str, **kwargs: Any ) -> None: - from app.services.collections.create_collection import execute_batch_job + from app.services.collections.create_collection import execute_job _set_trace(trace_id) return _run_with_otel_parent( self, - lambda: execute_batch_job( + lambda: execute_job( project_id=project_id, job_id=job_id, task_id=current_task.request.id, diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 5ec358801..535789bb3 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -4,7 +4,8 @@ """ import logging import functools -from typing import Any, Dict +from typing import Any, Dict, ParamSpec, TypeVar +from collections.abc import Callable from celery.result import AsyncResult from gevent import Timeout @@ -14,6 +15,9 @@ logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") + def _enqueue_with_trace_context(task, **kwargs) -> str: """Publish Celery task with explicit trace context headers.""" @@ -215,16 +219,20 @@ def revoke_task(task_id: str, terminate: bool = False) -> bool: return False -def gevent_timeout(seconds, task_name=None): - def decorator(func): +def gevent_timeout( + seconds: float | None, task_name: str | None = None +) -> Callable[[Callable[P, R]], Callable[P, R]]: + def decorator(func: Callable[P, R]) -> Callable[P, R]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: name = task_name or func.__name__ timeout = Timeout(seconds) timeout.start() try: return func(*args, **kwargs) - except Timeout: + except Timeout as err: + if err is not timeout: + raise logger.error(f"[{name}] Timed out after {seconds}s") raise finally: @@ -233,3 +241,9 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +# In gevent mode, Celery's soft and hard time limits fire during task cleanup, +# producing a misleading "Hard time limit exceeded" log. The task has already +# completed at this point (Pool POST fires first). This is a known gevent/Celery +# interaction and is harmless. diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index ba673d6a7..92af7e7d7 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -322,7 +322,7 @@ def execute_job( ) except (Timeout, SoftTimeLimitExceeded) as err: - timeout_err = TimeoutError(f"Task exceeded soft time limit") + timeout_err = TimeoutError("Task exceeded soft time limit") logger.error( "[create_collection.execute_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", job_id, diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index 68a3cea89..f751075ca 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -248,7 +248,7 @@ def execute_job( ) except (Timeout, SoftTimeLimitExceeded) as err: - timeout_err = TimeoutError(f"Task exceeded soft time limit") + timeout_err = TimeoutError("Task exceeded soft time limit") logger.error( "[delete_collection.execute_job] Collection Deletion Timed Out | " "{'collection_id': '%s', 'job_id': '%s'}", diff --git a/backend/app/services/doctransform/job.py b/backend/app/services/doctransform/job.py index 4f9fbb9b6..e64fc9a7a 100644 --- a/backend/app/services/doctransform/job.py +++ b/backend/app/services/doctransform/job.py @@ -276,7 +276,7 @@ def execute_job( send_callback(callback_url, success_payload, webhook_secret=webhook_secret) except (Timeout, SoftTimeLimitExceeded) as err: - timeout_err = TimeoutError(f"Task exceeded soft time limit") + timeout_err = TimeoutError("Task exceeded soft time limit") logger.error( "[doc_transform.execute_job] Timed Out | job_id=%s", job_uuid, diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 39dc84b33..5574262bf 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -191,17 +191,6 @@ def handle_job_error( chain_id: UUID | None = None, ) -> dict: """Handle job failure uniformly — send callback and update DB.""" - if callback_url: - webhook_secret = get_webhook_secret(project_id, organization_id) - with tracer.start_as_current_span("llm.send_callback") as cb_span: - cb_span.set_attribute("callback.url", callback_url) - cb_span.set_attribute("callback.status", "failure") - send_callback( - callback_url=callback_url, - data=callback_response.model_dump(), - webhook_secret=webhook_secret, - ) - with Session(engine) as session: JobCrud(session=session).update( job_id=job_id, @@ -225,6 +214,17 @@ def handle_job_error( exc_info=True, ) + if callback_url: + webhook_secret = get_webhook_secret(project_id, organization_id) + with tracer.start_as_current_span("llm.send_callback") as cb_span: + cb_span.set_attribute("callback.url", callback_url) + cb_span.set_attribute("callback.status", "failure") + send_callback( + callback_url=callback_url, + data=callback_response.model_dump(), + webhook_secret=webhook_secret, + ) + return callback_response.model_dump() diff --git a/backend/app/services/stt_evaluations/batch_job.py b/backend/app/services/stt_evaluations/batch_job.py index 08e09dac1..1a49eb6e5 100644 --- a/backend/app/services/stt_evaluations/batch_job.py +++ b/backend/app/services/stt_evaluations/batch_job.py @@ -100,7 +100,7 @@ def execute_batch_submission( return batch_result except (Timeout, SoftTimeLimitExceeded) as err: - timeout_err = TimeoutError(f"Task exceeded soft time limit") + timeout_err = TimeoutError("Task exceeded soft time limit") logger.error( f"[execute_batch_submission] STT batch submission timed out | run_id={run_id}" ) diff --git a/backend/app/services/stt_evaluations/metric_job.py b/backend/app/services/stt_evaluations/metric_job.py index effea895d..32f3a9821 100644 --- a/backend/app/services/stt_evaluations/metric_job.py +++ b/backend/app/services/stt_evaluations/metric_job.py @@ -156,7 +156,7 @@ def execute_metric_computation( } except (Timeout, SoftTimeLimitExceeded) as err: - timeout_err = TimeoutError(f"Task exceeded soft time limit") + timeout_err = TimeoutError("Task exceeded soft time limit") logger.error( f"[execute_metric_computation] STT metric computation timed out | run_id={run_id}" ) diff --git a/backend/app/services/tts_evaluations/batch_job.py b/backend/app/services/tts_evaluations/batch_job.py index 51e060ea3..7c77f4a53 100644 --- a/backend/app/services/tts_evaluations/batch_job.py +++ b/backend/app/services/tts_evaluations/batch_job.py @@ -128,7 +128,7 @@ def execute_batch_submission( return batch_result except (Timeout, SoftTimeLimitExceeded) as err: - timeout_err = TimeoutError(f"Task exceeded soft time limit") + timeout_err = TimeoutError("Task exceeded soft time limit") logger.error( f"[execute_batch_submission] TTS batch submission timed out | run_id={run_id}" ) diff --git a/backend/app/services/tts_evaluations/batch_result_processing.py b/backend/app/services/tts_evaluations/batch_result_processing.py index 5394f2934..53c90f6f0 100644 --- a/backend/app/services/tts_evaluations/batch_result_processing.py +++ b/backend/app/services/tts_evaluations/batch_result_processing.py @@ -242,7 +242,7 @@ def execute_tts_result_processing( } except (Timeout, SoftTimeLimitExceeded) as err: - timeout_err = TimeoutError(f"Task exceeded soft time limit") + timeout_err = TimeoutError("Task exceeded soft time limit") logger.error( f"[execute_tts_result_processing] TTS result processing timed out | run_id={evaluation_run_id}" ) diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index dd946799c..40ef22110 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -221,6 +221,98 @@ def test_handle_job_error_callback_failure_still_updates_job(self, db: Session): assert "Callback service unavailable" in str(exc_info.value) + def test_handle_job_error_with_chain_id_updates_chain_status(self, db: Session): + """Test that chain status is updated to FAILED when chain_id is provided.""" + from app.models.llm.request import ChainStatus + + job = JobCrud(session=db).create( + job_type=JobType.LLM_API, trace_id="test-trace" + ) + db.commit() + + chain_id = uuid4() + callback_response = APIResponse.failure_response(error="Chain failed") + + with ( + patch("app.services.llm.jobs.Session") as mock_session_class, + patch("app.services.llm.jobs.send_callback"), + patch("app.services.llm.jobs.update_llm_chain_status") as mock_update_chain, + ): + mock_session_class.return_value.__enter__.return_value = db + mock_session_class.return_value.__exit__.return_value = None + + handle_job_error( + job_id=job.id, + callback_url=None, + callback_response=callback_response, + chain_id=chain_id, + ) + + mock_update_chain.assert_called_once_with( + db, + chain_id=chain_id, + status=ChainStatus.FAILED, + error="Chain failed", + ) + + def test_handle_job_error_without_chain_id_skips_chain_update(self, db: Session): + """Test that update_llm_chain_status is not called when chain_id is None.""" + job = JobCrud(session=db).create( + job_type=JobType.LLM_API, trace_id="test-trace" + ) + db.commit() + + callback_response = APIResponse.failure_response(error="Job failed") + + with ( + patch("app.services.llm.jobs.Session") as mock_session_class, + patch("app.services.llm.jobs.send_callback"), + patch("app.services.llm.jobs.update_llm_chain_status") as mock_update_chain, + ): + mock_session_class.return_value.__enter__.return_value = db + mock_session_class.return_value.__exit__.return_value = None + + handle_job_error( + job_id=job.id, + callback_url=None, + callback_response=callback_response, + chain_id=None, + ) + + mock_update_chain.assert_not_called() + + def test_handle_job_error_chain_update_failure_is_swallowed(self, db: Session): + """Test that an exception from update_llm_chain_status doesn't propagate.""" + job = JobCrud(session=db).create( + job_type=JobType.LLM_API, trace_id="test-trace" + ) + db.commit() + + chain_id = uuid4() + callback_response = APIResponse.failure_response(error="Chain failed") + + with ( + patch("app.services.llm.jobs.Session") as mock_session_class, + patch("app.services.llm.jobs.send_callback"), + patch( + "app.services.llm.jobs.update_llm_chain_status", + side_effect=Exception("DB error updating chain"), + ), + ): + mock_session_class.return_value.__enter__.return_value = db + mock_session_class.return_value.__exit__.return_value = None + + # Should not raise — the exception is caught and logged + result = handle_job_error( + job_id=job.id, + callback_url=None, + callback_response=callback_response, + chain_id=chain_id, + ) + + assert result["success"] is False + assert result["error"] == "Chain failed" + class TestExecuteJob: """Test suite for execute_job.""" From cdf034b295fe4bca09d4a330afdd213f42bf12c8 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Tue, 5 May 2026 08:40:17 +0530 Subject: [PATCH 7/9] increasing test coverage --- backend/app/celery/tasks/job_execution.py | 20 ---- backend/app/services/llm/jobs.py | 2 + backend/app/tests/services/llm/test_jobs.py | 106 ++++++++++++++++++++ 3 files changed, 108 insertions(+), 20 deletions(-) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index 4888c7f2e..adadf1c9c 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -152,26 +152,6 @@ def run_create_collection_job( ) -@celery_app.task(bind=True, queue="low_priority", priority=1) -@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_collection_batch_job") -def run_collection_batch_job( - self: celery.Task, project_id: int, job_id: str, trace_id: str, **kwargs: Any -) -> None: - from app.services.collections.create_collection import execute_job - - _set_trace(trace_id) - return _run_with_otel_parent( - self, - lambda: execute_job( - project_id=project_id, - job_id=job_id, - task_id=current_task.request.id, - task_instance=self, - **kwargs, - ), - ) - - @celery_app.task(bind=True, queue="low_priority", priority=1) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_delete_collection_job") def run_delete_collection_job( diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 5574262bf..e37cdb4bd 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -730,6 +730,8 @@ def execute_llm_call( error_message = error or "Unknown error occurred" return BlockResult(error=error_message, llm_call_id=llm_call_id) + except (Timeout, SoftTimeLimitExceeded): + raise except Exception as e: logger.error( f"[execute_llm_call] Unexpected error: {e} | job_id={job_id}", diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index 40ef22110..c15f28fcb 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -2,6 +2,7 @@ from unittest.mock import patch, MagicMock from uuid import UUID, uuid4 +from celery.exceptions import SoftTimeLimitExceeded from gevent import Timeout from fastapi import HTTPException @@ -1287,6 +1288,36 @@ def test_timeout_with_callback_sends_failure_payload( db.refresh(job_for_execution) assert job_for_execution.status == JobStatus.FAILED + def test_soft_time_limit_reraises_and_marks_job_failed( + self, db, job_env, job_for_execution, request_data + ): + env = job_env + env["provider"].execute.side_effect = SoftTimeLimitExceeded() + + with pytest.raises(SoftTimeLimitExceeded): + self._execute_job(job_for_execution, db, request_data) + + db.refresh(job_for_execution) + assert job_for_execution.status == JobStatus.FAILED + + def test_soft_time_limit_with_callback_sends_failure_payload( + self, db, job_env, job_for_execution, request_data + ): + env = job_env + request_data["callback_url"] = "https://example.com/callback" + env["provider"].execute.side_effect = SoftTimeLimitExceeded() + + with pytest.raises(SoftTimeLimitExceeded): + self._execute_job(job_for_execution, db, request_data) + + env["send_callback"].assert_called_once() + payload = env["send_callback"].call_args.kwargs["data"] + assert payload["success"] is False + assert "soft time limit" in (payload["error"] or "") + + db.refresh(job_for_execution) + assert job_for_execution.status == JobStatus.FAILED + class TestStartChainJob: """Test cases for the start_chain_job function.""" @@ -1553,6 +1584,81 @@ def test_timeout_with_callback_sends_failure_payload(self, chain_request_data): assert call_args.args[2].error == "Task exceeded soft time limit" assert call_args.kwargs["chain_id"] == chain_id + def test_soft_time_limit_reraises_and_calls_handle_job_error( + self, chain_request_data + ): + chain_id = uuid4() + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, + patch("app.services.llm.chain.chain.LLMChain"), + patch( + "app.services.llm.chain.executor.ChainExecutor" + ) as mock_executor_class, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = chain_id + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + mock_executor_class.return_value.run.side_effect = SoftTimeLimitExceeded() + mock_handle_error.return_value = { + "success": False, + "error": "Task exceeded soft time limit", + } + + with pytest.raises(SoftTimeLimitExceeded): + self._execute_chain_job(chain_request_data) + + mock_handle_error.assert_called_once() + _, kwargs = mock_handle_error.call_args + assert kwargs["chain_id"] == chain_id + callback_response = mock_handle_error.call_args.args[2] + assert callback_response.error == "Task exceeded soft time limit" + + def test_soft_time_limit_with_callback_sends_failure_payload( + self, chain_request_data + ): + chain_id = uuid4() + chain_request_data["callback_url"] = "https://example.com/chain-callback" + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, + patch("app.services.llm.chain.chain.LLMChain"), + patch( + "app.services.llm.chain.executor.ChainExecutor" + ) as mock_executor_class, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = chain_id + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + mock_executor_class.return_value.run.side_effect = SoftTimeLimitExceeded() + mock_handle_error.return_value = { + "success": False, + "error": "Task exceeded soft time limit", + } + + with pytest.raises(SoftTimeLimitExceeded): + self._execute_chain_job(chain_request_data) + + mock_handle_error.assert_called_once() + call_args = mock_handle_error.call_args + assert call_args.args[1] == "https://example.com/chain-callback" + assert call_args.args[2].error == "Task exceeded soft time limit" + assert call_args.kwargs["chain_id"] == chain_id + class TestResolveConfigBlob: """Test suite for resolve_config_blob function.""" From 8e390df78b2142e6f287605a80b39849facead87 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 6 May 2026 09:35:59 +0530 Subject: [PATCH 8/9] logger warning instead of logger error --- backend/app/services/collections/create_collection.py | 2 +- backend/app/services/collections/delete_collection.py | 2 +- backend/app/services/doctransform/job.py | 2 +- backend/app/services/doctransform/zerox_transformer.py | 2 +- backend/app/services/llm/jobs.py | 4 ++-- backend/app/services/stt_evaluations/batch_job.py | 4 ++-- backend/app/services/stt_evaluations/metric_job.py | 2 +- backend/app/services/tts_evaluations/batch_job.py | 8 ++++---- .../services/tts_evaluations/batch_result_processing.py | 2 +- 9 files changed, 14 insertions(+), 14 deletions(-) diff --git a/backend/app/services/collections/create_collection.py b/backend/app/services/collections/create_collection.py index 92af7e7d7..a9b787f6b 100644 --- a/backend/app/services/collections/create_collection.py +++ b/backend/app/services/collections/create_collection.py @@ -323,7 +323,7 @@ def execute_job( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") - logger.error( + logger.warning( "[create_collection.execute_job] Collection Creation Timed Out | {'collection_job_id': '%s', 'error': '%s'}", job_id, str(timeout_err), diff --git a/backend/app/services/collections/delete_collection.py b/backend/app/services/collections/delete_collection.py index f751075ca..1b420ca58 100644 --- a/backend/app/services/collections/delete_collection.py +++ b/backend/app/services/collections/delete_collection.py @@ -249,7 +249,7 @@ def execute_job( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") - logger.error( + logger.warning( "[delete_collection.execute_job] Collection Deletion Timed Out | " "{'collection_id': '%s', 'job_id': '%s'}", str(collection_uuid), diff --git a/backend/app/services/doctransform/job.py b/backend/app/services/doctransform/job.py index e64fc9a7a..b72440b6f 100644 --- a/backend/app/services/doctransform/job.py +++ b/backend/app/services/doctransform/job.py @@ -277,7 +277,7 @@ def execute_job( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") - logger.error( + logger.warning( "[doc_transform.execute_job] Timed Out | job_id=%s", job_uuid, ) diff --git a/backend/app/services/doctransform/zerox_transformer.py b/backend/app/services/doctransform/zerox_transformer.py index bbb1263a2..82f12ece7 100644 --- a/backend/app/services/doctransform/zerox_transformer.py +++ b/backend/app/services/doctransform/zerox_transformer.py @@ -38,7 +38,7 @@ def transform(self, input_path: Path, output_path: Path) -> Path: except (Timeout, SoftTimeLimitExceeded): raise except TimeoutError: - logger.error( + logger.warning( f"ZeroxTransformer timed out for {input_path} (model={self.model})" ) raise RuntimeError( diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index e37cdb4bd..27fe4e28b 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -840,7 +840,7 @@ def execute_job( ) except (Timeout, SoftTimeLimitExceeded): - logger.error( + logger.warning( f"[execute_job] LLM job timed out | job_id={job_uuid}, task_id={task_id}" ) callback_response = APIResponse.failure_response( @@ -970,7 +970,7 @@ def execute_chain_job( return executor.run() except (Timeout, SoftTimeLimitExceeded) as err: - logger.error( + logger.warning( f"[execute_chain_job] Chain job timed out | job_id={job_uuid}, task_id={task_id}" ) callback_response = APIResponse.failure_response( diff --git a/backend/app/services/stt_evaluations/batch_job.py b/backend/app/services/stt_evaluations/batch_job.py index 1a49eb6e5..e68ea9731 100644 --- a/backend/app/services/stt_evaluations/batch_job.py +++ b/backend/app/services/stt_evaluations/batch_job.py @@ -57,7 +57,7 @@ def execute_batch_submission( ) if not run: - logger.error( + logger.warning( f"[execute_batch_submission] Run not found | run_id: {run_id}" ) return {"success": False, "error": "Run not found"} @@ -101,7 +101,7 @@ def execute_batch_submission( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") - logger.error( + logger.warning( f"[execute_batch_submission] STT batch submission timed out | run_id={run_id}" ) update_stt_run( diff --git a/backend/app/services/stt_evaluations/metric_job.py b/backend/app/services/stt_evaluations/metric_job.py index 32f3a9821..8880da767 100644 --- a/backend/app/services/stt_evaluations/metric_job.py +++ b/backend/app/services/stt_evaluations/metric_job.py @@ -157,7 +157,7 @@ def execute_metric_computation( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") - logger.error( + logger.warning( f"[execute_metric_computation] STT metric computation timed out | run_id={run_id}" ) update_stt_run( diff --git a/backend/app/services/tts_evaluations/batch_job.py b/backend/app/services/tts_evaluations/batch_job.py index 7c77f4a53..ad27d8d51 100644 --- a/backend/app/services/tts_evaluations/batch_job.py +++ b/backend/app/services/tts_evaluations/batch_job.py @@ -61,7 +61,7 @@ def execute_batch_submission( ) if not run: - logger.error( + logger.warning( f"[execute_batch_submission] Run not found | run_id: {run_id}" ) return {"success": False, "error": "Run not found"} @@ -74,7 +74,7 @@ def execute_batch_submission( ) if not dataset: - logger.error( + logger.warning( f"[execute_batch_submission] Dataset not found | " f"run_id: {run_id}, dataset_id: {dataset_id}" ) @@ -89,7 +89,7 @@ def execute_batch_submission( sample_texts = get_sample_texts_from_dataset(session, dataset, project_id) if not sample_texts: - logger.error( + logger.warning( f"[execute_batch_submission] No samples found | " f"run_id: {run_id}, dataset_id: {dataset_id}" ) @@ -129,7 +129,7 @@ def execute_batch_submission( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") - logger.error( + logger.warning( f"[execute_batch_submission] TTS batch submission timed out | run_id={run_id}" ) update_tts_run( diff --git a/backend/app/services/tts_evaluations/batch_result_processing.py b/backend/app/services/tts_evaluations/batch_result_processing.py index 53c90f6f0..0ff7651e9 100644 --- a/backend/app/services/tts_evaluations/batch_result_processing.py +++ b/backend/app/services/tts_evaluations/batch_result_processing.py @@ -243,7 +243,7 @@ def execute_tts_result_processing( except (Timeout, SoftTimeLimitExceeded) as err: timeout_err = TimeoutError("Task exceeded soft time limit") - logger.error( + logger.warning( f"[execute_tts_result_processing] TTS result processing timed out | run_id={evaluation_run_id}" ) update_tts_run( From cb42eb9fc8df30bfcfadaab5848fc3d9cd001053 Mon Sep 17 00:00:00 2001 From: nishika26 Date: Wed, 6 May 2026 11:14:02 +0530 Subject: [PATCH 9/9] making type hint liberal --- backend/app/celery/utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py index 535789bb3..288cba7c4 100644 --- a/backend/app/celery/utils.py +++ b/backend/app/celery/utils.py @@ -4,7 +4,7 @@ """ import logging import functools -from typing import Any, Dict, ParamSpec, TypeVar +from typing import Any, Dict, TypeVar from collections.abc import Callable from celery.result import AsyncResult @@ -15,8 +15,7 @@ logger = logging.getLogger(__name__) -P = ParamSpec("P") -R = TypeVar("R") +F = TypeVar("F", bound=Callable[..., Any]) def _enqueue_with_trace_context(task, **kwargs) -> str: @@ -221,10 +220,10 @@ def revoke_task(task_id: str, terminate: bool = False) -> bool: def gevent_timeout( seconds: float | None, task_name: str | None = None -) -> Callable[[Callable[P, R]], Callable[P, R]]: - def decorator(func: Callable[P, R]) -> Callable[P, R]: +) -> Callable[[F], F]: + def decorator(func: F) -> F: @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + def wrapper(*args: Any, **kwargs: Any) -> Any: name = task_name or func.__name__ timeout = Timeout(seconds) timeout.start() @@ -238,7 +237,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: finally: timeout.cancel() - return wrapper + return wrapper # type: ignore[return-value] return decorator