|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import abc |
| 4 | +import typing as tp |
| 5 | + |
| 6 | +from taskiq import AsyncBroker, AsyncResultBackend |
| 7 | + |
| 8 | + |
| 9 | +if tp.TYPE_CHECKING: |
| 10 | + import asyncio |
| 11 | + |
| 12 | + |
| 13 | +_T = tp.TypeVar("_T") |
| 14 | + |
| 15 | + |
| 16 | +class BasePostgresBroker(AsyncBroker, abc.ABC): |
| 17 | + """Base class for Postgres brokers.""" |
| 18 | + |
| 19 | + def __init__( |
| 20 | + self, |
| 21 | + dsn: str | tp.Callable[[], str] = "postgresql://postgres:postgres@localhost:5432/postgres", |
| 22 | + result_backend: AsyncResultBackend[_T] | None = None, |
| 23 | + task_id_generator: tp.Callable[[], str] | None = None, |
| 24 | + channel_name: str = "taskiq", |
| 25 | + table_name: str = "taskiq_messages", |
| 26 | + max_retry_attempts: int = 5, |
| 27 | + read_kwargs: dict[str, tp.Any] | None = None, |
| 28 | + write_kwargs: dict[str, tp.Any] | None = None, |
| 29 | + ) -> None: |
| 30 | + """ |
| 31 | + Construct a new broker. |
| 32 | +
|
| 33 | + Args: |
| 34 | + dsn: connection string to PostgreSQL, or callable returning one. |
| 35 | + result_backend: Custom result backend. |
| 36 | + task_id_generator: Custom task_id generator. |
| 37 | + channel_name: Name of the channel to listen on. |
| 38 | + table_name: Name of the table to store messages. |
| 39 | + max_retry_attempts: Maximum number of message processing attempts. |
| 40 | + read_kwargs: Additional arguments for read connection creation. |
| 41 | + write_kwargs: Additional arguments for write pool creation. |
| 42 | +
|
| 43 | + """ |
| 44 | + super().__init__( |
| 45 | + result_backend=result_backend, |
| 46 | + task_id_generator=task_id_generator, |
| 47 | + ) |
| 48 | + self._dsn: str | tp.Callable[[], str] = dsn |
| 49 | + self.channel_name: str = channel_name |
| 50 | + self.table_name: str = table_name |
| 51 | + self.read_kwargs: dict[str, tp.Any] = read_kwargs or {} |
| 52 | + self.write_kwargs: dict[str, tp.Any] = write_kwargs or {} |
| 53 | + self.max_retry_attempts: int = max_retry_attempts |
| 54 | + self._queue: asyncio.Queue[str] | None = None |
| 55 | + |
| 56 | + @property |
| 57 | + def dsn(self) -> str: |
| 58 | + """ |
| 59 | + Get the DSN string. |
| 60 | +
|
| 61 | + Returns: |
| 62 | + A string with dsn or None if dsn isn't set yet. |
| 63 | +
|
| 64 | + """ |
| 65 | + if callable(self._dsn): |
| 66 | + return self._dsn() |
| 67 | + return self._dsn |
0 commit comments