diff --git a/app/infra/database.py b/app/infra/database.py index 10c15a5..956662b 100644 --- a/app/infra/database.py +++ b/app/infra/database.py @@ -6,7 +6,11 @@ DATABASE_URL = f"postgresql+asyncpg://{settings.POSTGRES_USER}:{settings.POSTGRES_PASSWORD}@{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/{settings.POSTGRES_DB}" -engine = sqlalchemy.ext.asyncio.create_async_engine(DATABASE_URL, echo=settings.debug) +engine = sqlalchemy.ext.asyncio.create_async_engine( + DATABASE_URL, + echo=settings.debug, + pool_pre_ping=True, +) async def get_db() -> AsyncGenerator[sqlalchemy.ext.asyncio.AsyncConnection, None]: diff --git a/app/service/face_match.py b/app/service/face_match.py index 672c17e..346cce0 100644 --- a/app/service/face_match.py +++ b/app/service/face_match.py @@ -46,45 +46,45 @@ async def process_detected_face( created_face_match_id: UUID | None = None matched_user: ClosestUserMatch | None = None + # Writes run inside the caller's transaction; the worker owns commit/rollback. try: - async with self.conn.begin(): - if not await self.Check_photo_exists(job.photo_id): - logger.warning("Photo not found: %s", job.photo_id) - return + if not await self.Check_photo_exists(job.photo_id): + logger.warning("Photo not found: %s", job.photo_id) + return - if await self._match_exists_for_photo(job.photo_id): - logger.info("Photo %s already matched; skipping", job.photo_id) - return + if await self._match_exists_for_photo(job.photo_id): + logger.info("Photo %s already matched; skipping", job.photo_id) + return - matched_user = await self.user_match_service.find_closest_user( - embedding_literal=embedding_literal, - ) - if await self._autoapprove_if_unmatchable(job, matched_user): - return - assert matched_user is not None - - params = photo_face_queries.PhotoFacesEnsureFaceMatchParams( - photo_id=job.photo_id, - face_index=job.face_index, - column_3=embedding_literal, - bbox=bbox_payload, - user_id=matched_user.user_id, - confidence=matched_user.distance, + matched_user = await self.user_match_service.find_closest_user( + embedding_literal=embedding_literal, + ) + if await self._autoapprove_if_unmatchable(job, matched_user): + return + assert matched_user is not None + + params = photo_face_queries.PhotoFacesEnsureFaceMatchParams( + photo_id=job.photo_id, + face_index=job.face_index, + column_3=embedding_literal, + bbox=bbox_payload, + user_id=matched_user.user_id, + confidence=matched_user.distance, + ) + result = await self.photo_face_querier.photo_faces_ensure_face_match(params) + if result is None: + logger.warning("Failed to ensure face match for photo %s", job.photo_id) + return + + if result.face_match_id is None: + logger.info("Match already exists for photo %s; skipping", job.photo_id) + else: + created_face_match_id = result.face_match_id + logger.info( + "Inserted face match %s for photo %s", + created_face_match_id, + job.photo_id, ) - result = await self.photo_face_querier.photo_faces_ensure_face_match(params) - if result is None: - logger.warning("Failed to ensure face match for photo %s", job.photo_id) - return - - if result.face_match_id is None: - logger.info("Match already exists for photo %s; skipping", job.photo_id) - else: - created_face_match_id = result.face_match_id - logger.info( - "Inserted face match %s for photo %s", - created_face_match_id, - job.photo_id, - ) except (DBAPIError, SQLAlchemyError) as exc: logger.warning("DB write failed for photo %s: %s", job.photo_id, exc) return diff --git a/app/worker/audit/main.py b/app/worker/audit/main.py index 95ceae8..8037c26 100644 --- a/app/worker/audit/main.py +++ b/app/worker/audit/main.py @@ -1,7 +1,6 @@ import asyncio import json from typing import Any -import sqlalchemy.ext.asyncio from pydantic import ValidationError from app.core.constant import AUDIT_EVENT_SUBJECT from app.core.logger import logger @@ -18,34 +17,20 @@ async def init_worker() -> None: class AuditDeliveryWorker: - def __init__(self) -> None: - self._conn: sqlalchemy.ext.asyncio.AsyncConnection | None = None - self._audit_service: AuditService | None = None - - async def start(self) -> None: - if self._conn is not None: - return - self._conn = await engine.connect() - self._audit_service = AuditService( - audit_queries.AsyncQuerier(self._conn), - user_queries.AsyncQuerier(self._conn), - ) - - async def stop(self) -> None: - if self._conn is not None: - await self._conn.close() - self._conn = None - self._audit_service = None - async def persist(self, payload: AuditEventMessage) -> None: - if self._audit_service is None: - logger.warning("Audit service is unavailable for %s", payload.event_type) - return - await self._audit_service.record_event( - event_type=payload.event_type, - user_id=payload.user_id, - metadata=payload.metadata, - ) + # Fresh connection and transaction per event. engine.begin() commits on + # success and rolls back on error, with pool_pre_ping revalidating the + # connection on checkout so a Postgres restart recovers automatically. + async with engine.begin() as conn: + service = AuditService( + audit_queries.AsyncQuerier(conn), + user_queries.AsyncQuerier(conn), + ) + await service.record_event( + event_type=payload.event_type, + user_id=payload.user_id, + metadata=payload.metadata, + ) logger.info("Persisted audit %s for %s", payload.event_type, payload.user_id) @@ -87,13 +72,11 @@ async def listen_nats_event(worker: AuditDeliveryWorker) -> None: async def main() -> None: await init_worker() worker = AuditDeliveryWorker() - await worker.start() await NatsClient.connect() try: await listen_nats_event(worker) await asyncio.Event().wait() finally: - await worker.stop() await NatsClient.close() diff --git a/app/worker/notification/invalid_tokens.py b/app/worker/notification/invalid_tokens.py index 02e24c9..6d1c7c4 100644 --- a/app/worker/notification/invalid_tokens.py +++ b/app/worker/notification/invalid_tokens.py @@ -6,6 +6,7 @@ from app.core.constant import RedisKey from app.core.logger import logger +from app.infra.database import engine from app.infra.redis import RedisClient from app.worker.notification.settings import NotifSetting @@ -43,19 +44,22 @@ async def remove(self, tokens: Sequence[str]) -> None: class DeviceInvalidationStore: - def __init__(self, device_querier: device_queries.AsyncQuerier) -> None: - self._device_querier = device_querier - async def mark_invalid(self, tokens: Iterable[str]) -> None: normalized: list[str] = [t for t in tokens if t] if not normalized: return + # One transaction per token. Handlers run concurrently, so each write + # gets its own connection rather than sharing one, and a failure on one + # token does not abort the rest. failed: list[str] = [] for token in normalized: try: - await self._device_querier.mark_device_token_invalid(push_token=token) + async with engine.begin() as conn: + await device_queries.AsyncQuerier(conn).mark_device_token_invalid( + push_token=token + ) except Exception: failed.append(token) logger.exception("Failed to flag device for invalid token %s", token) diff --git a/app/worker/notification/main.py b/app/worker/notification/main.py index 3f4711d..a79d7c2 100644 --- a/app/worker/notification/main.py +++ b/app/worker/notification/main.py @@ -3,8 +3,6 @@ import asyncio from typing import Sequence -from db.generated import devices as device_queries - from app.core.logger import logger from app.worker.notification.firebase import ( NotificationDeliveryError, @@ -18,7 +16,6 @@ from app.worker.notification.notification_queue import NotificationQueue, NotificationQueueEntry from app.worker.notification.rate_limiter import RateLimiter from app.worker.notification.settings import NotifSetting -from app.infra.database import engine from app.infra.redis import RedisClient from app.infra.nats import NatsClient @@ -141,16 +138,13 @@ async def main() -> None: queue = NotificationQueue(settings=NotifSetting) invalid_tokens = InvalidTokenStore(redis) - db_conn = await engine.connect() - device_querier = device_queries.AsyncQuerier(db_conn) - invalid_devices = DeviceInvalidationStore(device_querier) + invalid_devices = DeviceInvalidationStore() try: await run_worker(queue, invalid_tokens, invalid_devices) finally: await redis.close() - await db_conn.close() logger.info("Worker shutdown") diff --git a/app/worker/photo_worker/main.py b/app/worker/photo_worker/main.py index 651b606..ba7e960 100644 --- a/app/worker/photo_worker/main.py +++ b/app/worker/photo_worker/main.py @@ -4,11 +4,11 @@ import json from enum import Enum -from sqlalchemy.exc import DBAPIError, SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncConnection from app.container import Container from app.core.config import settings +from app.deps.ai_deps import get_face_embedding_service from app.core.constant import MINIO_URL_PREFIX from app.core.logger import logger from app.infra.database import engine @@ -60,6 +60,9 @@ async def handle_message(self, data: bytes) -> None: if event is None: return + # The transaction is owned by run_worker, which opens a fresh + # engine.begin() per message. DB errors are allowed to propagate so that + # transaction rolls back cleanly instead of being committed half-applied. job = await self._create_job(event) try: @@ -133,23 +136,16 @@ async def _handle_group_photo(self, event: PhotoProcessEvent, faces: list[Detect embedding_literal = "[" + ", ".join(str(x) for x in face.embedding) + "]" - try: - approval = await self._photo_face_querier.insert_photo_face_with_approval( - InsertPhotoFaceWithApprovalParams( - photo_id=event.photo_id, - face_index=face_index, - column_3=embedding_literal, - face_embedding=worker_settings.similarity_threshold, - bbox=bbox_json, - decision=PhotoApprovalDecision.PENDING.value, - ) - ) - except (DBAPIError, SQLAlchemyError) as exc: - logger.warning( - "DB error inserting face %d for photo %s: %s", - face_index, event.photo_id, exc, + approval = await self._photo_face_querier.insert_photo_face_with_approval( + InsertPhotoFaceWithApprovalParams( + photo_id=event.photo_id, + face_index=face_index, + column_3=embedding_literal, + face_embedding=worker_settings.similarity_threshold, + bbox=bbox_json, + decision=PhotoApprovalDecision.PENDING.value, ) - continue + ) if approval is None: logger.info("No match for face %d in photo %s", face_index, event.photo_id) @@ -182,21 +178,14 @@ async def _handle_group_photo(self, event: PhotoProcessEvent, faces: list[Detect async def _create_job(self, event: PhotoProcessEvent) -> models.ProcessingJob | None: if self._pj_querier is None: return None - try: - return await self._pj_querier.create_processing_job( - photo_id=event.photo_id, job_type="face_detection", - ) - except Exception as exc: - logger.warning("Failed to create processing job for photo %s: %s", event.photo_id, exc) - return None + return await self._pj_querier.create_processing_job( + photo_id=event.photo_id, job_type="face_detection", + ) async def _update_job(self, job: models.ProcessingJob | None, status: str) -> None: if job is None or self._pj_querier is None: return - try: - await self._pj_querier.update_processing_job_status(id=job.id, status=status) - except Exception as exc: - logger.warning("Failed to update processing job: %s", exc) + await self._pj_querier.update_processing_job_status(id=job.id, status=status) @staticmethod async def _publish_audit(event: PhotoProcessEvent, faces_count: int) -> None: @@ -271,46 +260,58 @@ async def run_worker() -> None: minio_root_user=settings.MINIO_ROOT_USER, minio_root_password=settings.MINIO_ROOT_PASSWORD, ) - RedisClient( + RedisClient.init( host=settings.REDIS_HOST, port=settings.REDIS_PORT, password=settings.REDIS_PASSWORD, ) - async with engine.connect() as conn: - container = Container(conn) + # Load the embedding model once; it is a process-wide singleton. + get_face_embedding_service() - single_face_service = SingleFaceMatchService( - conn=conn, - photo_face_querier=container.photo_face_querier, - photo_querier=container.photo_querier, - user_match_service=container.auth_service, - user_notification_service=container.user_notifications_service, - ) - - worker = PhotoWorker( - conn=conn, - face_embedding_service=container.face_embedding_service, - single_face_service=single_face_service, - user_notification_service=container.user_notifications_service, - photo_face_querier=container.photo_face_querier, - photo_querier=container.photo_querier, - processing_job_querier=container.processing_job_querier, - ) - - await NatsClient.js_subscribe( - subject=NatsSubjects.PHOTO_PROCESS, - callback=worker.handle_message, - stream_name=worker_settings.stream_name, - durable_name=worker_settings.durable_name, - ) - - logger.info("PhotoWorker subscribed on %s; waiting for jobs", NatsSubjects.PHOTO_PROCESS.value) + async def handle(data: bytes) -> None: + # Fresh connection and transaction per message. engine.begin() commits on + # success and rolls back on any error, so a failed message can never leave + # a half-applied or aborted transaction behind for the next one. With + # pool_pre_ping the connection is also revalidated on checkout, so a + # Postgres restart is recovered automatically. Errors are logged here so a + # single bad message does not tear down the subscription. try: - await asyncio.Event().wait() - finally: - await _close_minio() - await NatsClient.close() + async with engine.begin() as conn: + container = Container(conn) + single_face_service = SingleFaceMatchService( + conn=conn, + photo_face_querier=container.photo_face_querier, + photo_querier=container.photo_querier, + user_match_service=container.auth_service, + user_notification_service=container.user_notifications_service, + ) + worker = PhotoWorker( + conn=conn, + face_embedding_service=container.face_embedding_service, + single_face_service=single_face_service, + user_notification_service=container.user_notifications_service, + photo_face_querier=container.photo_face_querier, + photo_querier=container.photo_querier, + processing_job_querier=container.processing_job_querier, + ) + await worker.handle_message(data) + except Exception: + logger.exception("Failed to process photo message") + + await NatsClient.js_subscribe( + subject=NatsSubjects.PHOTO_PROCESS, + callback=handle, + stream_name=worker_settings.stream_name, + durable_name=worker_settings.durable_name, + ) + + logger.info("PhotoWorker subscribed on %s; waiting for jobs", NatsSubjects.PHOTO_PROCESS.value) + try: + await asyncio.Event().wait() + finally: + await _close_minio() + await NatsClient.close() async def _close_minio() -> None: