Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion app/infra/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
DATABASE_URL = f"postgresql+asyncpg://{settings.POSTGRES_USER}:{settings.POSTGRES_PASSWORD}@{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/{settings.POSTGRES_DB}"


engine = sqlalchemy.ext.asyncio.create_async_engine(DATABASE_URL, echo=settings.debug)
engine = sqlalchemy.ext.asyncio.create_async_engine(
DATABASE_URL,
echo=settings.debug,
pool_pre_ping=True,
)


async def get_db() -> AsyncGenerator[sqlalchemy.ext.asyncio.AsyncConnection, None]:
Expand Down
70 changes: 35 additions & 35 deletions app/service/face_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,45 +46,45 @@ async def process_detected_face(
created_face_match_id: UUID | None = None
matched_user: ClosestUserMatch | None = None

# Writes run inside the caller's transaction; the worker owns commit/rollback.
try:
async with self.conn.begin():
if not await self.Check_photo_exists(job.photo_id):
logger.warning("Photo not found: %s", job.photo_id)
return
if not await self.Check_photo_exists(job.photo_id):
logger.warning("Photo not found: %s", job.photo_id)
return

if await self._match_exists_for_photo(job.photo_id):
logger.info("Photo %s already matched; skipping", job.photo_id)
return
if await self._match_exists_for_photo(job.photo_id):
logger.info("Photo %s already matched; skipping", job.photo_id)
return

matched_user = await self.user_match_service.find_closest_user(
embedding_literal=embedding_literal,
)
if await self._autoapprove_if_unmatchable(job, matched_user):
return
assert matched_user is not None

params = photo_face_queries.PhotoFacesEnsureFaceMatchParams(
photo_id=job.photo_id,
face_index=job.face_index,
column_3=embedding_literal,
bbox=bbox_payload,
user_id=matched_user.user_id,
confidence=matched_user.distance,
matched_user = await self.user_match_service.find_closest_user(
embedding_literal=embedding_literal,
)
if await self._autoapprove_if_unmatchable(job, matched_user):
return
assert matched_user is not None

params = photo_face_queries.PhotoFacesEnsureFaceMatchParams(
photo_id=job.photo_id,
face_index=job.face_index,
column_3=embedding_literal,
bbox=bbox_payload,
user_id=matched_user.user_id,
confidence=matched_user.distance,
)
result = await self.photo_face_querier.photo_faces_ensure_face_match(params)
if result is None:
logger.warning("Failed to ensure face match for photo %s", job.photo_id)
return

if result.face_match_id is None:
logger.info("Match already exists for photo %s; skipping", job.photo_id)
else:
created_face_match_id = result.face_match_id
logger.info(
"Inserted face match %s for photo %s",
created_face_match_id,
job.photo_id,
)
result = await self.photo_face_querier.photo_faces_ensure_face_match(params)
if result is None:
logger.warning("Failed to ensure face match for photo %s", job.photo_id)
return

if result.face_match_id is None:
logger.info("Match already exists for photo %s; skipping", job.photo_id)
else:
created_face_match_id = result.face_match_id
logger.info(
"Inserted face match %s for photo %s",
created_face_match_id,
job.photo_id,
)
except (DBAPIError, SQLAlchemyError) as exc:
logger.warning("DB write failed for photo %s: %s", job.photo_id, exc)
return
Expand Down
43 changes: 13 additions & 30 deletions app/worker/audit/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
from typing import Any
import sqlalchemy.ext.asyncio
from pydantic import ValidationError
from app.core.constant import AUDIT_EVENT_SUBJECT
from app.core.logger import logger
Expand All @@ -18,34 +17,20 @@ async def init_worker() -> None:


class AuditDeliveryWorker:
def __init__(self) -> None:
self._conn: sqlalchemy.ext.asyncio.AsyncConnection | None = None
self._audit_service: AuditService | None = None

async def start(self) -> None:
if self._conn is not None:
return
self._conn = await engine.connect()
self._audit_service = AuditService(
audit_queries.AsyncQuerier(self._conn),
user_queries.AsyncQuerier(self._conn),
)

async def stop(self) -> None:
if self._conn is not None:
await self._conn.close()
self._conn = None
self._audit_service = None

async def persist(self, payload: AuditEventMessage) -> None:
if self._audit_service is None:
logger.warning("Audit service is unavailable for %s", payload.event_type)
return
await self._audit_service.record_event(
event_type=payload.event_type,
user_id=payload.user_id,
metadata=payload.metadata,
)
# Fresh connection and transaction per event. engine.begin() commits on
# success and rolls back on error, with pool_pre_ping revalidating the
# connection on checkout so a Postgres restart recovers automatically.
async with engine.begin() as conn:
service = AuditService(
audit_queries.AsyncQuerier(conn),
user_queries.AsyncQuerier(conn),
)
await service.record_event(
event_type=payload.event_type,
user_id=payload.user_id,
metadata=payload.metadata,
)
logger.info("Persisted audit %s for %s", payload.event_type, payload.user_id)


Expand Down Expand Up @@ -87,13 +72,11 @@ async def listen_nats_event(worker: AuditDeliveryWorker) -> None:
async def main() -> None:
await init_worker()
worker = AuditDeliveryWorker()
await worker.start()
await NatsClient.connect()
try:
await listen_nats_event(worker)
await asyncio.Event().wait()
finally:
await worker.stop()
await NatsClient.close()


Expand Down
12 changes: 8 additions & 4 deletions app/worker/notification/invalid_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from app.core.constant import RedisKey
from app.core.logger import logger
from app.infra.database import engine
from app.infra.redis import RedisClient
from app.worker.notification.settings import NotifSetting

Expand Down Expand Up @@ -43,19 +44,22 @@ async def remove(self, tokens: Sequence[str]) -> None:


class DeviceInvalidationStore:
def __init__(self, device_querier: device_queries.AsyncQuerier) -> None:
self._device_querier = device_querier

async def mark_invalid(self, tokens: Iterable[str]) -> None:
normalized: list[str] = [t for t in tokens if t]

if not normalized:
return

# One transaction per token. Handlers run concurrently, so each write
# gets its own connection rather than sharing one, and a failure on one
# token does not abort the rest.
failed: list[str] = []
for token in normalized:
try:
await self._device_querier.mark_device_token_invalid(push_token=token)
async with engine.begin() as conn:
await device_queries.AsyncQuerier(conn).mark_device_token_invalid(
push_token=token
)
except Exception:
failed.append(token)
logger.exception("Failed to flag device for invalid token %s", token)
Expand Down
8 changes: 1 addition & 7 deletions app/worker/notification/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import asyncio
from typing import Sequence

from db.generated import devices as device_queries

from app.core.logger import logger
from app.worker.notification.firebase import (
NotificationDeliveryError,
Expand All @@ -18,7 +16,6 @@
from app.worker.notification.notification_queue import NotificationQueue, NotificationQueueEntry
from app.worker.notification.rate_limiter import RateLimiter
from app.worker.notification.settings import NotifSetting
from app.infra.database import engine
from app.infra.redis import RedisClient
from app.infra.nats import NatsClient

Expand Down Expand Up @@ -141,16 +138,13 @@ async def main() -> None:

queue = NotificationQueue(settings=NotifSetting)
invalid_tokens = InvalidTokenStore(redis)
db_conn = await engine.connect()
device_querier = device_queries.AsyncQuerier(db_conn)
invalid_devices = DeviceInvalidationStore(device_querier)
invalid_devices = DeviceInvalidationStore()

try:
await run_worker(queue, invalid_tokens, invalid_devices)

finally:
await redis.close()
await db_conn.close()
logger.info("Worker shutdown")


Expand Down
Loading