From b3013c32fd5eea5bbc19a97dbb395d0cf4767c99 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Thu, 29 Jan 2026 21:12:25 +0000 Subject: [PATCH 1/4] feat(amgi-redis): add message send manager so it can be used by other servers, and send elsewhere --- packages/amgi-redis/pyproject.toml | 1 + .../amgi-redis/src/amgi_redis/__init__.py | 77 +++++++++++++++---- uv.lock | 2 + 3 files changed, 66 insertions(+), 14 deletions(-) diff --git a/packages/amgi-redis/pyproject.toml b/packages/amgi-redis/pyproject.toml index c0bc15e..1a9bf69 100644 --- a/packages/amgi-redis/pyproject.toml +++ b/packages/amgi-redis/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "amgi-common==0.30.0", "amgi-types==0.30.0", "redis>=7.0.1", + "typing-extensions>=4.15.0; python_full_version<'3.11'", ] entry-points.amgi_server.amgi-redis = "amgi_redis:_run_cli" diff --git a/packages/amgi-redis/src/amgi_redis/__init__.py b/packages/amgi-redis/src/amgi_redis/__init__.py index 8b96949..d9ec3a5 100644 --- a/packages/amgi-redis/src/amgi_redis/__init__.py +++ b/packages/amgi-redis/src/amgi_redis/__init__.py @@ -1,6 +1,11 @@ import asyncio +import sys from asyncio import Task +from collections.abc import Awaitable +from collections.abc import Callable +from types import TracebackType from typing import Any +from typing import AsyncContextManager from amgi_common import Lifespan from amgi_common import server_serve @@ -9,13 +14,27 @@ from amgi_types import AMGISendEvent from amgi_types import MessageReceiveEvent from amgi_types import MessageScope +from amgi_types import MessageSendEvent from redis.asyncio import from_url from redis.asyncio.client import PubSub from redis.asyncio.client import Redis +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self -def run(app: AMGIApplication, *channels: str, url: str = "redis://localhost") -> None: - server = Server(app, *channels, url=url) +_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]] +_MessageSendManagerT = AsyncContextManager[_MessageSendT] + + +def run( + app: AMGIApplication, + *channels: str, + url: str = "redis://localhost", + message_send: _MessageSendManagerT | None = None, +) -> None: + server = Server(app, *channels, url=url, message_send=message_send) server_serve(server) @@ -40,45 +59,75 @@ async def __call__(self) -> MessageReceiveEvent: class _Send: + def __init__(self, message_send: _MessageSendT) -> None: + self._message_send = message_send + + async def __call__(self, event: AMGISendEvent) -> None: + if event["type"] == "message.send": + await self._message_send(event) + + +class MessageSend: def __init__(self, redis: Redis) -> None: self._redis = redis - async def __call__(self, message: AMGISendEvent) -> None: - if message["type"] == "message.send": - await self._redis.publish(message["address"], message["payload"]) + async def __aenter__(self) -> Self: + return self + + async def __call__(self, event: MessageSendEvent) -> None: + await self._redis.publish(event["address"], event["payload"]) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._redis.aclose() class Server: - def __init__(self, app: AMGIApplication, *channels: str, url: str): + def __init__( + self, + app: AMGIApplication, + *channels: str, + url: str, + message_send: _MessageSendManagerT | None = None, + ) -> None: self._app = app self._channels = channels - self._url = url + self._redis = from_url(url) + self._message_send = message_send or MessageSend(self._redis) self._stoppable = Stoppable() self._tasks = set[Task[None]]() async def serve(self) -> None: - redis = from_url(self._url) - async with redis.pubsub() as pubsub: + async with self._redis.pubsub() as pubsub, self._message_send as message_send: await pubsub.subscribe(*self._channels) async with Lifespan(self._app) as state: - await self._main_loop(redis, pubsub, state) + await self._main_loop(message_send, pubsub, state) await asyncio.gather(*self._tasks, return_exceptions=True) async def _main_loop( - self, redis: Redis, pubsub: PubSub, state: dict[str, Any] + self, message_send: _MessageSendT, pubsub: PubSub, state: dict[str, Any] ) -> None: loop = asyncio.get_event_loop() async for message in self._stoppable.call( pubsub.get_message, ignore_subscribe_messages=True, timeout=None ): if message is not None: - task = loop.create_task(self._handle_message(message, redis, state)) + task = loop.create_task( + self._handle_message(message, message_send, state) + ) self._tasks.add(task) task.add_done_callback(self._tasks.discard) async def _handle_message( - self, message: dict[str, Any], redis: Redis, state: dict[str, Any] + self, + message: dict[str, Any], + message_send: _MessageSendT, + state: dict[str, Any], ) -> None: scope: MessageScope = { "type": "message", @@ -86,7 +135,7 @@ async def _handle_message( "address": message["channel"].decode(), "state": state.copy(), } - await self._app(scope, _Receive(message), _Send(redis)) + await self._app(scope, _Receive(message), _Send(message_send)) def stop(self) -> None: self._stoppable.stop() diff --git a/uv.lock b/uv.lock index f569004..a187bec 100644 --- a/uv.lock +++ b/uv.lock @@ -453,6 +453,7 @@ dependencies = [ { name = "amgi-common" }, { name = "amgi-types" }, { name = "redis" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] [package.dev-dependencies] @@ -470,6 +471,7 @@ requires-dist = [ { name = "amgi-common", editable = "packages/amgi-common" }, { name = "amgi-types", editable = "packages/amgi-types" }, { name = "redis", specifier = ">=7.0.1" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'", specifier = ">=4.15.0" }, ] [package.metadata.requires-dev] From d74208c1d8a28a7bee910170d8118519c709e423 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Thu, 29 Jan 2026 21:19:45 +0000 Subject: [PATCH 2/4] feat(amgi-aiobotocore): add message send manager so it can be used by other servers, and send elsewhere --- packages/amgi-aiobotocore/pyproject.toml | 1 + .../src/amgi_aiobotocore/sqs.py | 118 +++++++++++++----- uv.lock | 2 + 3 files changed, 91 insertions(+), 30 deletions(-) diff --git a/packages/amgi-aiobotocore/pyproject.toml b/packages/amgi-aiobotocore/pyproject.toml index fa9df29..80f8b2f 100644 --- a/packages/amgi-aiobotocore/pyproject.toml +++ b/packages/amgi-aiobotocore/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "aiobotocore>=2.25.0", "amgi-common==0.30.0", "amgi-types==0.30.0", + "typing-extensions>=4.15.0; python_full_version<'3.11'", ] entry-points.amgi_server.amgi-aiobotocore-sqs = "amgi_aiobotocore.sqs:_run_cli" diff --git a/packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py b/packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py index 605defc..7fb85dc 100644 --- a/packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py +++ b/packages/amgi-aiobotocore/src/amgi_aiobotocore/sqs.py @@ -1,9 +1,14 @@ import asyncio +import sys from collections import deque +from collections.abc import Awaitable +from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterable from collections.abc import Sequence +from types import TracebackType from typing import Any +from typing import AsyncContextManager from aiobotocore.session import get_session from amgi_common import Lifespan @@ -15,6 +20,15 @@ from amgi_types import AMGISendEvent from amgi_types import MessageReceiveEvent from amgi_types import MessageScope +from amgi_types import MessageSendEvent + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]] +_MessageSendManagerT = AsyncContextManager[_MessageSendT] def run( @@ -24,6 +38,7 @@ def run( endpoint_url: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, + message_send: _MessageSendManagerT | None = None, ) -> None: server = Server( app, @@ -32,6 +47,7 @@ def run( endpoint_url=endpoint_url, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, + message_send=message_send, ) server_serve(server) @@ -84,20 +100,23 @@ async def __call__(self) -> MessageReceiveEvent: } +async def _get_queue_url(client: Any, queue_name: str) -> str: + queue_url_response = await client.get_queue_url(QueueName=queue_name) + queue_url = queue_url_response["QueueUrl"] + assert isinstance(queue_url, str) + return queue_url + + class _QueueUrlCache: def __init__(self, client: Any) -> None: self._client = client - self._operation_cacher = OperationCacher(self._get_queue_url) + self._operation_cacher = OperationCacher[str, str]( + lambda queue_name: _get_queue_url(client, queue_name) + ) async def get_queue_url(self, queue_name: str) -> str: return await self._operation_cacher.get(queue_name) - async def _get_queue_url(self, queue_name: str) -> str: - queue_url_response = await self._client.get_queue_url(QueueName=queue_name) - queue_url = queue_url_response["QueueUrl"] - assert isinstance(queue_url, str) - return queue_url - class SqsBatchFailureError(IOError): def __init__(self, sender_fault: bool, code: str, message: str): @@ -202,18 +221,56 @@ async def send_message( await self._operation_batcher.enqueue((queue_url, payload, headers)) +class MessageSend: + def __init__( + self, + region_name: str | None = None, + endpoint_url: str | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + ) -> None: + session = get_session() + + self._client_context = session.create_client( + "sqs", + region_name=region_name, + endpoint_url=endpoint_url, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) + + async def __aenter__(self) -> Self: + self._client = await self._client_context.__aenter__() + self._send_batcher = _SendBatcher(self._client) + self._queue_url_cache = _QueueUrlCache(self._client) + + return self + + async def __call__(self, event: MessageSendEvent) -> None: + queue_url = await self._queue_url_cache.get_queue_url(event["address"]) + await self._send_batcher.send_message( + queue_url, event["payload"], event["headers"] + ) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._client_context.__aexit__(exc_type, exc_val, exc_tb) + + class _Send: def __init__( self, - send_batcher: _SendBatcher, delete_batcher: _DeleteBatcher, - queue_url_cache: _QueueUrlCache, queue_url: str, + message_send: _MessageSendT, ) -> None: self._queue_url = queue_url - self._queue_url_cache = queue_url_cache self._delete_batcher = delete_batcher - self._send_batcher = send_batcher + self._message_send = message_send async def __call__(self, event: AMGISendEvent) -> None: if event["type"] == "message.ack": @@ -222,10 +279,7 @@ async def __call__(self, event: AMGISendEvent) -> None: event["id"], ) if event["type"] == "message.send": - queue_url = await self._queue_url_cache.get_queue_url(event["address"]) - await self._send_batcher.send_message( - queue_url, event["payload"], event["headers"] - ) + await self._message_send(event) class Server: @@ -237,6 +291,7 @@ def __init__( endpoint_url: str | None = None, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, + message_send: _MessageSendManagerT | None = None, ) -> None: self._app = app self._queues = queues @@ -244,25 +299,30 @@ def __init__( self._endpoint_url = endpoint_url self._aws_access_key_id = aws_access_key_id self._aws_secret_access_key = aws_secret_access_key + self._message_send = message_send or MessageSend( + region_name, endpoint_url, aws_access_key_id, aws_secret_access_key + ) + self._stoppable = Stoppable() async def serve(self) -> None: session = get_session() - async with session.create_client( - "sqs", - region_name=self._region_name, - endpoint_url=self._endpoint_url, - aws_access_key_id=self._aws_access_key_id, - aws_secret_access_key=self._aws_secret_access_key, - ) as client: - queue_url_cache = _QueueUrlCache(client) + async with ( + session.create_client( + "sqs", + region_name=self._region_name, + endpoint_url=self._endpoint_url, + aws_access_key_id=self._aws_access_key_id, + aws_secret_access_key=self._aws_secret_access_key, + ) as client, + self._message_send as message_send, + ): delete_batcher = _DeleteBatcher(client) - send_batcher = _SendBatcher(client) queue_urls = zip( await asyncio.gather( - *(queue_url_cache.get_queue_url(queue) for queue in self._queues) + *(_get_queue_url(client, queue) for queue in self._queues) ), self._queues, ) @@ -273,9 +333,8 @@ async def serve(self) -> None: client, queue_url, queue, - queue_url_cache, delete_batcher, - send_batcher, + message_send, state, ) for queue_url, queue in queue_urls @@ -287,9 +346,8 @@ async def _queue_loop( client: Any, queue_url: str, queue_name: str, - queue_url_cache: _QueueUrlCache, delete_batcher: _DeleteBatcher, - send_batcher: _SendBatcher, + message_send: _MessageSendT, state: dict[str, Any], ) -> None: async for messages_response in self._stoppable.call( @@ -310,7 +368,7 @@ async def _queue_loop( await self._app( scope, _Receive(messages), - _Send(send_batcher, delete_batcher, queue_url_cache, queue_url), + _Send(delete_batcher, queue_url, message_send), ) def stop(self) -> None: diff --git a/uv.lock b/uv.lock index a187bec..671819d 100644 --- a/uv.lock +++ b/uv.lock @@ -309,6 +309,7 @@ dependencies = [ { name = "aiobotocore" }, { name = "amgi-common" }, { name = "amgi-types" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] [package.dev-dependencies] @@ -326,6 +327,7 @@ requires-dist = [ { name = "aiobotocore", specifier = ">=2.25.0" }, { name = "amgi-common", editable = "packages/amgi-common" }, { name = "amgi-types", editable = "packages/amgi-types" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'", specifier = ">=4.15.0" }, ] [package.metadata.requires-dev] From 7bf3b8487ecc541c29529f26f270c82273077f7e Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Thu, 29 Jan 2026 21:21:52 +0000 Subject: [PATCH 3/4] feat(amgi-aiokafka): add message send manager so it can be used by other servers, and send elsewhere --- packages/amgi-aiokafka/pyproject.toml | 1 + .../src/amgi_aiokafka/__init__.py | 100 ++++++++++++------ uv.lock | 2 + 3 files changed, 72 insertions(+), 31 deletions(-) diff --git a/packages/amgi-aiokafka/pyproject.toml b/packages/amgi-aiokafka/pyproject.toml index c818605..e27b448 100644 --- a/packages/amgi-aiokafka/pyproject.toml +++ b/packages/amgi-aiokafka/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "aiokafka>=0.12", "amgi-common==0.30.0", "amgi-types==0.30.0", + "typing-extensions>=4.15.0; python_full_version<'3.11'", ] urls.Changelog = "https://github.com/asyncfast/amgi/blob/main/CHANGELOG.md" diff --git a/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py b/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py index e3b60e8..6c6ebbd 100644 --- a/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py +++ b/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py @@ -1,11 +1,14 @@ import asyncio import logging +import sys from asyncio import Lock from collections import deque from collections.abc import Awaitable from collections.abc import Callable from collections.abc import Iterable +from types import TracebackType from typing import Any +from typing import AsyncContextManager from typing import Literal from aiokafka import AIOKafkaConsumer @@ -21,6 +24,15 @@ from amgi_types import MessageScope from amgi_types import MessageSendEvent +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]] +_MessageSendManagerT = AsyncContextManager[_MessageSendT] + + logger = logging.getLogger("amgi-aiokafka.error") @@ -33,6 +45,7 @@ def run( bootstrap_servers: str | list[str] = "localhost", group_id: str | None = None, auto_offset_reset: AutoOffsetReset = "latest", + message_send: _MessageSendManagerT | None = None, ) -> None: server = Server( app, @@ -40,6 +53,7 @@ def run( bootstrap_servers=bootstrap_servers, group_id=group_id, auto_offset_reset=auto_offset_reset, + message_send=message_send, ) server_serve(server) @@ -86,8 +100,8 @@ def __init__( self, consumer: AIOKafkaConsumer, message_receive_ids: dict[str, dict[TopicPartition, int]], - message_send: Callable[[MessageSendEvent], Awaitable[None]], ackable_consumer: bool, + message_send: _MessageSendT, ) -> None: self._consumer = consumer self._message_send = message_send @@ -102,6 +116,48 @@ async def __call__(self, event: AMGISendEvent) -> None: await self._message_send(event) +class MessageSend: + def __init__(self, bootstrap_servers: str | list[str]) -> None: + self._bootstrap_servers = bootstrap_servers + self._producer = None + self._producer_lock = Lock() + + async def __aenter__(self) -> Self: + return self + + async def __call__(self, event: MessageSendEvent) -> None: + producer = await self._get_producer() + encoded_headers = [(key.decode(), value) for key, value in event["headers"]] + + key = event.get("bindings", {}).get("kafka", {}).get("key") + await producer.send( + event["address"], + headers=encoded_headers, + value=event.get("payload"), + key=key, + ) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self._producer is not None: + await self._producer.stop() + + async def _get_producer(self) -> AIOKafkaProducer: + if self._producer is None: + async with self._producer_lock: + if self._producer is None: + producer = AIOKafkaProducer( + bootstrap_servers=self._bootstrap_servers + ) + await producer.start() + self._producer = producer + return self._producer + + class Server: _consumer: AIOKafkaConsumer @@ -112,15 +168,15 @@ def __init__( bootstrap_servers: str | list[str], group_id: str | None, auto_offset_reset: AutoOffsetReset = "latest", + message_send: _MessageSendManagerT | None = None, ) -> None: self._app = app self._topics = topics self._bootstrap_servers = bootstrap_servers self._group_id = group_id self._auto_offset_reset = auto_offset_reset + self._message_send = message_send or MessageSend(bootstrap_servers) self._ackable_consumer = self._group_id is not None - self._producer: AIOKafkaProducer | None = None - self._producer_lock = Lock() self._stoppable = Stoppable() async def serve(self) -> None: @@ -131,20 +187,21 @@ async def serve(self) -> None: enable_auto_commit=False, auto_offset_reset=self._auto_offset_reset, ) - async with self._consumer: + async with self._consumer, self._message_send as message_send: async with Lifespan(self._app) as state: - await self._main_loop(state) - - if self._producer is not None: - await self._producer.stop() + await self._main_loop(state, message_send) - async def _main_loop(self, state: dict[str, Any]) -> None: + async def _main_loop( + self, state: dict[str, Any], message_send: _MessageSendT + ) -> None: async for messages in self._stoppable.call( self._consumer.getmany, timeout_ms=1000 ): await asyncio.gather( *[ - self._handle_partition_records(topic_partition, records, state) + self._handle_partition_records( + topic_partition, records, message_send, state + ) for topic_partition, records in messages.items() ] ) @@ -153,6 +210,7 @@ async def _handle_partition_records( self, topic_partition: TopicPartition, records: list[ConsumerRecord], + message_send: _MessageSendT, state: dict[str, Any], ) -> None: if records: @@ -176,30 +234,10 @@ async def _handle_partition_records( _Send( self._consumer, message_receive_ids, - self._message_send, self._ackable_consumer, + message_send, ), ) - async def _get_producer(self) -> AIOKafkaProducer: - async with self._producer_lock: - if self._producer is None: - producer = AIOKafkaProducer(bootstrap_servers=self._bootstrap_servers) - await producer.start() - self._producer = producer - return self._producer - - async def _message_send(self, event: MessageSendEvent) -> None: - producer = await self._get_producer() - encoded_headers = [(key.decode(), value) for key, value in event["headers"]] - - key = event.get("bindings", {}).get("kafka", {}).get("key") - await producer.send( - event["address"], - headers=encoded_headers, - value=event.get("payload"), - key=key, - ) - def stop(self) -> None: self._stoppable.stop() diff --git a/uv.lock b/uv.lock index 671819d..5fdf934 100644 --- a/uv.lock +++ b/uv.lock @@ -348,6 +348,7 @@ dependencies = [ { name = "aiokafka" }, { name = "amgi-common" }, { name = "amgi-types" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] [package.dev-dependencies] @@ -365,6 +366,7 @@ requires-dist = [ { name = "aiokafka", specifier = ">=0.12" }, { name = "amgi-common", editable = "packages/amgi-common" }, { name = "amgi-types", editable = "packages/amgi-types" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'", specifier = ">=4.15.0" }, ] [package.metadata.requires-dev] From 1e8359199df040ac9d4ba6ccf001124c507e7f58 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Thu, 29 Jan 2026 21:31:18 +0000 Subject: [PATCH 4/4] test(amgi-aiokafka): use earliest offset with consumers to stabilize kafka tests --- .../tests_amgi_aiokafka/test_kafka_message_integration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/amgi-aiokafka/tests_amgi_aiokafka/test_kafka_message_integration.py b/packages/amgi-aiokafka/tests_amgi_aiokafka/test_kafka_message_integration.py index d2293df..aa7e22f 100644 --- a/packages/amgi-aiokafka/tests_amgi_aiokafka/test_kafka_message_integration.py +++ b/packages/amgi-aiokafka/tests_amgi_aiokafka/test_kafka_message_integration.py @@ -116,7 +116,7 @@ async def test_message_send( await producer.send_and_wait(receive_topic, b"") async with AIOKafkaConsumer( - send_topic, bootstrap_servers=bootstrap_server + send_topic, bootstrap_servers=bootstrap_server, auto_offset_reset="earliest" ) as consumer: async with app.call() as (scope, receive, send): await send( @@ -146,7 +146,7 @@ async def test_message_send_kafka_key( await producer.send_and_wait(receive_topic, b"") async with AIOKafkaConsumer( - send_topic, bootstrap_servers=bootstrap_server + send_topic, bootstrap_servers=bootstrap_server, auto_offset_reset="earliest" ) as consumer: async with app.call() as (scope, receive, send): await send(