Skip to content
Open
76 changes: 76 additions & 0 deletions materializationengine/celery_worker.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import datetime
import logging
import os
import signal
import sys
import threading
import time
import warnings
from typing import Any, Callable, Dict

import redis
from celery import signals
from celery.app.builtins import add_backend_cleanup_task
from celery.schedules import crontab
from celery.signals import after_setup_logger
Expand All @@ -22,6 +26,61 @@
celery_logger = get_task_logger(__name__)


# Shared auto-shutdown state.
# NOTE(review): with the prefork pool these module globals are copied into
# each forked child, so the count is per child process, NOT aggregated
# across the worker (concurrency C can run up to C * max_tasks tasks before
# any child trips the limit). If a true worker-wide limit is required, the
# counter must live in a shared store (e.g. redis) or Celery's built-in
# ``worker_max_tasks_per_child`` should be considered instead — TODO confirm
# which semantics the deployment expects.
_task_execution_count = 0
_shutdown_requested = False
# Guards the two globals above so concurrent task callbacks (threads /
# eventlet / gevent pools) cannot both pass the "already requested" check
# and spawn duplicate shutdown threads.
_shutdown_lock = threading.Lock()


def _request_worker_shutdown(delay_seconds: int, observed_count: int) -> None:
    """Sleep briefly, then SIGTERM the main worker process.

    Args:
        delay_seconds: Seconds to wait before signalling, so task result
            propagation can finish first. Negative values are treated as 0.
        observed_count: Task count at the moment shutdown was triggered;
            used only for logging.
    """
    time.sleep(max(delay_seconds, 0))
    celery_logger.info(
        "Auto-shutdown: terminating worker (parent of PID %s) after %s tasks",
        os.getpid(),
        observed_count,
    )
    try:
        # task_postrun fires inside the prefork child; signalling our own
        # PID would only kill the child, which the main worker simply
        # replaces. Signal the parent — the main worker process — so the
        # whole worker shuts down. NOTE(review): assumes the prefork pool
        # (this project's configuration); under the solo pool getppid()
        # would be the supervising process instead — confirm pool type.
        os.kill(os.getppid(), signal.SIGTERM)
    except Exception as exc:  # pragma: no cover - best-effort shutdown
        celery_logger.error("Failed to terminate worker: %s", exc)


def _auto_shutdown_handler(sender=None, **kwargs):
    """task_postrun hook: count finished tasks and request worker shutdown.

    When ``worker_autoshutdown_enabled`` is set, increments the (process
    local) task counter under a lock and, exactly once, spawns a daemon
    thread that terminates the worker after
    ``worker_autoshutdown_delay_seconds``.
    """
    if not celery.conf.get("worker_autoshutdown_enabled", False):
        return

    max_tasks = celery.conf.get("worker_autoshutdown_max_tasks", 1)
    if max_tasks <= 0:
        # Non-positive limit disables the feature rather than firing on
        # every task.
        return

    global _task_execution_count, _shutdown_requested
    # Read-modify-write of the shared counters must be atomic; without the
    # lock two concurrent callbacks could both see _shutdown_requested as
    # False and each spawn a shutdown thread.
    with _shutdown_lock:
        if _shutdown_requested:
            return

        _task_execution_count += 1

        if _task_execution_count < max_tasks:
            return

        _shutdown_requested = True
        # Capture the count inside the lock so the log line is accurate.
        observed_count = _task_execution_count

    delay = celery.conf.get("worker_autoshutdown_delay_seconds", 2)
    celery_logger.info(
        "Auto-shutdown triggered after %s tasks; terminating in %ss",
        observed_count,
        delay,
    )
    # Daemon thread: the delay must not block the task_postrun signal, and
    # the thread must not keep the process alive during shutdown.
    shutdown_thread = threading.Thread(
        target=_request_worker_shutdown,
        args=(delay, observed_count),
        daemon=True,
    )
    shutdown_thread.start()


signals.task_postrun.connect(_auto_shutdown_handler, weak=False)


def create_celery(app=None):
celery.conf.broker_url = app.config["CELERY_BROKER_URL"]
celery.conf.result_backend = app.config["CELERY_RESULT_BACKEND"]
Expand All @@ -32,6 +91,23 @@ def create_celery(app=None):
celery.conf.result_backend_transport_options = {
"master_name": app.config["MASTER_NAME"]
}

celery.conf.worker_autoshutdown_enabled = app.config.get(
"CELERY_WORKER_AUTOSHUTDOWN_ENABLED", False
)
celery.conf.worker_autoshutdown_max_tasks = app.config.get(
"CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS", 1
)
celery.conf.worker_autoshutdown_delay_seconds = app.config.get(
"CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS", 2
)

if celery.conf.worker_autoshutdown_enabled:
celery_logger.info(
"Worker auto-shutdown enabled: max_tasks=%s delay=%ss",
celery.conf.worker_autoshutdown_max_tasks,
celery.conf.worker_autoshutdown_delay_seconds,
)
# Configure Celery and related loggers
log_level = app.config["LOGGING_LEVEL"]
celery_logger.setLevel(log_level)
Expand Down
10 changes: 7 additions & 3 deletions materializationengine/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ class BaseConfig:
MERGE_TABLES = True
AUTH_SERVICE_NAMESPACE = "datastack"

REDIS_HOST="localhost"
REDIS_PORT=6379
REDIS_PASSWORD=""
CELERY_WORKER_AUTOSHUTDOWN_ENABLED = False
CELERY_WORKER_AUTOSHUTDOWN_MAX_TASKS = 1
CELERY_WORKER_AUTOSHUTDOWN_DELAY_SECONDS = 2

REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_PASSWORD = ""
SESSION_TYPE = "redis"
PERMANENT_SESSION_LIFETIME = timedelta(hours=24)
SESSION_PREFIX = "annotation_upload_"
Expand Down