From cc7ffaa26a930e62c82b9776ce908d1612e9c879 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sun, 21 Dec 2025 17:23:59 -0800 Subject: [PATCH 1/4] chore: Organize test fixtures Rewrite the test fixtures to have a more clear split between local and mqtt fixtures. This is in prepration for running both at once for e2 tests for device manager. All fixtures are moved into a fixtures subdirectory. The helper classes that are imported into other tests are added in separate files for importing, and to avoid import warnings from pytests. This renames all the fixtures to have mqtt prefixed names and local fixtures to have local prefixed names. There is one minor change to make the local asyncio tests uses asyncio Queues rather than blocking queues. --- tests/conftest.py | 532 ------------------ tests/devices/test_a01_channel.py | 3 +- tests/devices/test_v1_channel.py | 5 +- tests/devices/traits/a01/test_init.py | 2 +- tests/devices/traits/b01/test_init.py | 2 +- tests/e2e/__init__.py | 7 + tests/e2e/test_local_session.py | 64 +-- tests/e2e/test_mqtt_session.py | 25 +- tests/fixtures/__init__.py | 0 .../aiomqtt_fixtures.py} | 43 +- tests/fixtures/channel_fixtures.py | 53 ++ tests/fixtures/local_async_fixtures.py | 77 +++ tests/fixtures/logging.py | 60 ++ tests/fixtures/logging_fixtures.py | 126 +++++ tests/fixtures/mqtt.py | 101 ++++ tests/fixtures/pahomqtt_fixtures.py | 97 ++++ tests/fixtures/web_api_fixtures.py | 141 +++++ tests/mqtt/test_roborock_session.py | 122 ++-- tests/test_a01_api.py | 66 ++- tests/test_api.py | 66 ++- tests/test_local_api_v1.py | 137 ++++- tests/test_web_api.py | 12 + 22 files changed, 989 insertions(+), 752 deletions(-) delete mode 100644 tests/conftest.py create mode 100644 tests/fixtures/__init__.py rename tests/{mqtt_fixtures.py => fixtures/aiomqtt_fixtures.py} (67%) create mode 100644 tests/fixtures/channel_fixtures.py create mode 100644 tests/fixtures/local_async_fixtures.py create mode 100644 tests/fixtures/logging.py create mode 100644 tests/fixtures/logging_fixtures.py create mode 100644 tests/fixtures/mqtt.py create mode 100644 tests/fixtures/pahomqtt_fixtures.py create mode 100644 tests/fixtures/web_api_fixtures.py diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index bbdc2fae..00000000 --- a/tests/conftest.py +++ /dev/null @@ -1,532 +0,0 @@ -import asyncio -import io -import logging -import re -from asyncio import Protocol -from collections.abc import AsyncGenerator, Callable, Generator -from queue import Queue -from typing import Any -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -import pytest -from aioresponses import aioresponses - -from roborock import HomeData, UserData -from roborock.data import DeviceData -from roborock.mqtt.health_manager import HealthManager -from roborock.protocols.v1_protocol import LocalProtocolVersion -from roborock.roborock_message import RoborockMessage -from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 -from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1 -from tests.mock_data import HOME_DATA_RAW, HOME_DATA_SCENES_RAW, TEST_LOCAL_API_HOST, USER_DATA - -# Fixtures for the newer APIs in subdirectories -pytest_plugins = [ - "tests.mqtt_fixtures", -] - -_LOGGER = logging.getLogger(__name__) - - -# Used by fixtures to handle incoming requests and prepare responses -RequestHandler = Callable[[bytes], bytes | None] -QUEUE_TIMEOUT = 10 - -# Fixed timestamp for deterministic tests for asserting on message contents -FAKE_TIMESTAMP = 1755750946.721395 - - -class CapturedRequestLog: - """Log of requests and responses for snapshot assertions. - - The log captures the raw bytes of each request and response along with - a label indicating the direction of the message. - """ - - def __init__(self) -> None: - """Initialize the request log.""" - self.entries: list[tuple[str, bytes]] = [] - - def add_log_entry(self, label: str, data: bytes) -> None: - """Add a request entry.""" - self.entries.append((label, data)) - - def __repr__(self): - """Return a string representation of the log entries. - - This assumes that the client will behave in a request-response manner, - so each request is followed by a response. If a test uses non-deterministic - message order, this may not be accurate and the test would need to decode - the raw messages and remove any ordering assumptions. - """ - lines = [] - for label, data in self.entries: - lines.append(label) - lines.extend(self._hexdump(data)) - return "\n".join(lines) - - def _hexdump(self, data: bytes, bytes_per_line: int = 16) -> list[str]: - """Print a hexdump of the given bytes object in a tcpdump/hexdump -C style. - - This makes the packets easier to read and compare in test snapshots. - - Args: - data: The bytes object to print. - bytes_per_line: The number of bytes to display per line (default is 16). - """ - - # Use '.' for non-printable characters (ASCII < 32 or > 126) - def to_printable_ascii(byte_val): - return chr(byte_val) if 32 <= byte_val <= 126 else "." - - offset = 0 - lines = [] - while offset < len(data): - chunk = data[offset : offset + bytes_per_line] - # Format the hex values, space-padded to ensure alignment - hex_values = " ".join(f"{byte:02x}" for byte in chunk) - # Pad hex string to a fixed width so ASCII column lines up - # 3 chars per byte ('xx ') for a full line of 16 bytes - padded_hex = f"{hex_values:<{bytes_per_line * 3}}" - # Format the ASCII values - ascii_values = "".join(to_printable_ascii(byte) for byte in chunk) - lines.append(f"{offset:08x} {padded_hex} |{ascii_values}|") - offset += bytes_per_line - return lines - - -@pytest.fixture -def deterministic_message_fixtures() -> Generator[None, None, None]: - """Fixture to use predictable get_next_int and timestamp values for each test. - - This test mocks out the functions used to generate requests that have some - entropy such as the nonces, timestamps, and request IDs. This makes the - generated messages deterministic so we can snapshot them in a test. - """ - - # Pick an arbitrary sequence number used for outgoing requests - next_int = 9090 - - def get_next_int(min_value: int, max_value: int) -> int: - nonlocal next_int - result = next_int - next_int += 1 - if next_int > max_value: - next_int = min_value - return result - - # Pick an arbitrary timestamp used for the message encryption - timestamp = FAKE_TIMESTAMP - - def get_timestamp() -> int: - """Get a monotonically increasing timestamp for testing.""" - nonlocal timestamp - timestamp += 1 - return int(timestamp) - - # Use predictable seeds for token_bytes - token_chr = "A" - - def get_token_bytes(n: int) -> bytes: - nonlocal token_chr - result = token_chr.encode() * n - # Cycle to the next character - token_chr = chr(ord(token_chr) + 1) - if token_chr > "Z": - token_chr = "A" - return result - - with ( - patch("roborock.api.get_next_int", side_effect=get_next_int), - patch("roborock.devices.local_channel.get_next_int", side_effect=get_next_int), - patch("roborock.protocols.v1_protocol.get_next_int", side_effect=get_next_int), - patch("roborock.protocols.v1_protocol.get_timestamp", side_effect=get_timestamp), - patch("roborock.protocols.v1_protocol.secrets.token_bytes", side_effect=get_token_bytes), - patch("roborock.version_1_apis.roborock_local_client_v1.get_next_int", side_effect=get_next_int), - patch("roborock.roborock_message.get_next_int", side_effect=get_next_int), - patch("roborock.roborock_message.get_timestamp", side_effect=get_timestamp), - ): - yield - - -@pytest.fixture(name="log") -def log_fixture(deterministic_message_fixtures: None) -> CapturedRequestLog: - """Fixture that creates a captured request log.""" - return CapturedRequestLog() - - -class FakeSocketHandler: - """Fake socket used by the test to simulate a connection to the broker. - - The socket handler is used to intercept the socket send and recv calls and - populate the response buffer with data to be sent back to the client. The - handle request callback handles the incoming requests and prepares the responses. - """ - - def __init__(self, handle_request: RequestHandler, response_queue: Queue[bytes], log: CapturedRequestLog) -> None: - self.response_buf = io.BytesIO() - self.handle_request = handle_request - self.response_queue = response_queue - self.log = log - - def pending(self) -> int: - """Return the number of bytes in the response buffer.""" - return len(self.response_buf.getvalue()) - - def handle_socket_recv(self, read_size: int) -> bytes: - """Intercept a client recv() and populate the buffer.""" - if self.pending() == 0: - raise BlockingIOError("No response queued") - - self.response_buf.seek(0) - data = self.response_buf.read(read_size) - _LOGGER.debug("Response: 0x%s", data.hex()) - # Consume the rest of the data in the buffer - remaining_data = self.response_buf.read() - self.response_buf = io.BytesIO(remaining_data) - return data - - def handle_socket_send(self, client_request: bytes) -> int: - """Receive an incoming request from the client.""" - _LOGGER.debug("Request: 0x%s", client_request.hex()) - self.log.add_log_entry("[mqtt >]", client_request) - if (response := self.handle_request(client_request)) is not None: - # Enqueue a response to be sent back to the client in the buffer. - # The buffer will be emptied when the client calls recv() on the socket - _LOGGER.debug("Queued: 0x%s", response.hex()) - self.log.add_log_entry("[mqtt <]", response) - self.response_buf.write(response) - return len(client_request) - - def push_response(self) -> None: - """Push a response to the client.""" - if not self.response_queue.empty(): - response = self.response_queue.get() - # Enqueue a response to be sent back to the client in the buffer. - # The buffer will be emptied when the client calls recv() on the socket - _LOGGER.debug("Queued: 0x%s", response.hex()) - self.response_buf.write(response) - - -@pytest.fixture(name="received_requests") -def received_requests_fixture() -> Queue[bytes]: - """Fixture that provides access to the received requests.""" - return Queue() - - -@pytest.fixture(name="response_queue") -def response_queue_fixture() -> Generator[Queue[bytes], None, None]: - """Fixture that provides access to the received requests.""" - response_queue: Queue[bytes] = Queue() - yield response_queue - assert response_queue.empty(), "Not all fake responses were consumed" - - -@pytest.fixture(name="request_handler") -def request_handler_fixture(received_requests: Queue[bytes], response_queue: Queue[bytes]) -> RequestHandler: - """Fixture records incoming requests and replies with responses from the queue.""" - - def handle_request(client_request: bytes) -> bytes | None: - """Handle an incoming request from the client.""" - received_requests.put(client_request) - - # Insert a prepared response into the response buffer - if not response_queue.empty(): - return response_queue.get() - return None - - return handle_request - - -@pytest.fixture(name="fake_socket_handler") -def fake_socket_handler_fixture( - request_handler: RequestHandler, response_queue: Queue[bytes], log: CapturedRequestLog -) -> FakeSocketHandler: - """Fixture that creates a fake MQTT broker.""" - return FakeSocketHandler(request_handler, response_queue, log) - - -@pytest.fixture(name="mock_sock") -def mock_sock_fixture(fake_socket_handler: FakeSocketHandler) -> Mock: - """Fixture that creates a mock socket connection and wires it to the handler.""" - mock_sock = Mock() - mock_sock.recv = fake_socket_handler.handle_socket_recv - mock_sock.send = fake_socket_handler.handle_socket_send - mock_sock.pending = fake_socket_handler.pending - return mock_sock - - -@pytest.fixture(name="mock_create_connection") -def create_connection_fixture(mock_sock: Mock) -> Generator[None, None, None]: - """Fixture that overrides the MQTT socket creation to wire it up to the mock socket.""" - with patch("paho.mqtt.client.socket.create_connection", return_value=mock_sock): - yield - - -@pytest.fixture(name="mock_select") -def select_fixture(mock_sock: Mock, fake_socket_handler: FakeSocketHandler) -> Generator[None, None, None]: - """Fixture that overrides the MQTT client select calls to make select work on the mock socket. - - This patch select to activate our mock socket when ready with data. Internal mqtt sockets are - always ready since they are used internally to wake the select loop. Ours is ready if there - is data in the buffer. - """ - - def is_ready(sock: Any) -> bool: - return sock is not mock_sock or (fake_socket_handler.pending() > 0) - - def handle_select(rlist: list, wlist: list, *args: Any) -> list: - return [list(filter(is_ready, rlist)), list(filter(is_ready, wlist))] - - with patch("paho.mqtt.client.select.select", side_effect=handle_select): - yield - - -@pytest.fixture(name="mqtt_client") -async def mqtt_client(mock_create_connection: None, mock_select: None) -> AsyncGenerator[RoborockMqttClientV1, None]: - user_data = UserData.from_dict(USER_DATA) - home_data = HomeData.from_dict(HOME_DATA_RAW) - device_info = DeviceData( - device=home_data.devices[0], - model=home_data.products[0].model, - ) - client = RoborockMqttClientV1(user_data, device_info, queue_timeout=QUEUE_TIMEOUT) - try: - yield client - finally: - if not client.is_connected(): - try: - await client.async_release() - except Exception: - pass - - -@pytest.fixture(name="mock_rest", autouse=True) -def mock_rest() -> aioresponses: - """Mock all rest endpoints so they won't hit real endpoints""" - with aioresponses() as mocked: - # Match the base URL and allow any query params - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v1/getUrlByEmail.*"), - status=200, - payload={ - "code": 200, - "data": {"country": "US", "countrycode": "1", "url": "https://usiot.roborock.com"}, - "msg": "success", - }, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v1/login.*"), - status=200, - payload={"code": 200, "data": USER_DATA, "msg": "success"}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v1/loginWithCode.*"), - status=200, - payload={"code": 200, "data": USER_DATA, "msg": "success"}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v1/sendEmailCode.*"), - status=200, - payload={"code": 200, "data": None, "msg": "success"}, - ) - mocked.get( - re.compile(r"https://.*iot\.roborock\.com/api/v1/getHomeDetail.*"), - status=200, - payload={ - "code": 200, - "data": {"deviceListOrder": None, "id": 123456, "name": "My Home", "rrHomeId": 123456, "tuyaHomeId": 0}, - "msg": "success", - }, - ) - mocked.get( - re.compile(r"https://api-.*\.roborock\.com/v2/user/homes*"), - status=200, - payload={"api": None, "code": 200, "result": HOME_DATA_RAW, "status": "ok", "success": True}, - ) - mocked.post( - re.compile(r"https://api-.*\.roborock\.com/nc/prepare"), - status=200, - payload={ - "api": None, - "result": {"r": "US", "s": "ffffff", "t": "eOf6d2BBBB"}, - "status": "ok", - "success": True, - }, - ) - - mocked.get( - re.compile(r"https://api-.*\.roborock\.com/user/devices/newadd/*"), - status=200, - payload={ - "api": "获取新增设备信息", - "result": { - "activeTime": 1737724598, - "attribute": None, - "cid": None, - "createTime": 0, - "deviceStatus": None, - "duid": "rand_duid", - "extra": "{}", - "f": False, - "featureSet": "0", - "fv": "02.16.12", - "iconUrl": "", - "lat": None, - "localKey": "random_lk", - "lon": None, - "name": "S7", - "newFeatureSet": "0000000000002000", - "online": True, - "productId": "rand_prod_id", - "pv": "1.0", - "roomId": None, - "runtimeEnv": None, - "setting": None, - "share": False, - "shareTime": None, - "silentOtaSwitch": False, - "sn": "Rand_sn", - "timeZoneId": "America/New_York", - "tuyaMigrated": False, - "tuyaUuid": None, - }, - "status": "ok", - "success": True, - }, - ) - mocked.get( - re.compile(r"https://api-.*\.roborock\.com/user/scene/device/.*"), - status=200, - payload={"api": None, "code": 200, "result": HOME_DATA_SCENES_RAW, "status": "ok", "success": True}, - ) - mocked.post( - re.compile(r"https://api-.*\.roborock\.com/user/scene/.*/execute"), - status=200, - payload={"api": None, "code": 200, "result": None, "status": "ok", "success": True}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v4/email/code/send.*"), - status=200, - payload={"code": 200, "data": None, "msg": "success"}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v3/key/sign.*"), - status=200, - payload={"code": 200, "data": {"k": "mock_k"}, "msg": "success"}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v4/auth/email/login/code.*"), - status=200, - payload={"code": 200, "data": USER_DATA, "msg": "success"}, - ) - yield mocked - - -@pytest.fixture(autouse=True) -def skip_rate_limit(): - """Don't rate limit tests as they aren't actually hitting the api.""" - with ( - patch("roborock.web_api.RoborockApiClient._login_limiter.try_acquire"), - patch("roborock.web_api.RoborockApiClient._home_data_limiter.try_acquire"), - ): - yield - - -@pytest.fixture(name="mock_create_local_connection") -def create_local_connection_fixture( - request_handler: RequestHandler, log: CapturedRequestLog -) -> Generator[None, None, None]: - """Fixture that overrides the transport creation to wire it up to the mock socket.""" - - async def create_connection(protocol_factory: Callable[[], Protocol], *args) -> tuple[Any, Any]: - protocol = protocol_factory() - - def handle_write(data: bytes) -> None: - _LOGGER.debug("Received: %s", data) - response = request_handler(data) - log.add_log_entry("[local >]", data) - if response is not None: - _LOGGER.debug("Replying with %s", response) - log.add_log_entry("[local <]", response) - loop = asyncio.get_running_loop() - loop.call_soon(protocol.data_received, response) - - closed = asyncio.Event() - - mock_transport = Mock() - mock_transport.write = handle_write - mock_transport.close = closed.set - mock_transport.is_reading = lambda: not closed.is_set() - - return (mock_transport, "proto") - - with patch("roborock.version_1_apis.roborock_local_client_v1.get_running_loop") as mock_loop: - mock_loop.return_value.create_connection.side_effect = create_connection - yield - - -@pytest.fixture(name="local_client") -async def local_client_fixture(mock_create_local_connection: None) -> AsyncGenerator[RoborockLocalClientV1, None]: - home_data = HomeData.from_dict(HOME_DATA_RAW) - device_info = DeviceData( - device=home_data.devices[0], - model=home_data.products[0].model, - host=TEST_LOCAL_API_HOST, - ) - client = RoborockLocalClientV1(device_info, queue_timeout=QUEUE_TIMEOUT) - try: - yield client - finally: - if not client.is_connected(): - try: - await client.async_release() - except Exception: - pass - - -class FakeChannel: - """A fake channel that handles publish and subscribe calls.""" - - def __init__(self): - """Initialize the fake channel.""" - self.subscribers: list[Callable[[RoborockMessage], None]] = [] - self.published_messages: list[RoborockMessage] = [] - self.response_queue: list[RoborockMessage] = [] - self._is_connected = False - self.publish_side_effect: Exception | None = None - self.publish = AsyncMock(side_effect=self._publish) - self.subscribe = AsyncMock(side_effect=self._subscribe) - self.connect = AsyncMock(side_effect=self._connect) - self.close = MagicMock(side_effect=self._close) - self.protocol_version = LocalProtocolVersion.V1 - self.restart = AsyncMock() - self.health_manager = HealthManager(self.restart) - - async def _connect(self) -> None: - self._is_connected = True - - def _close(self) -> None: - self._is_connected = False - - @property - def is_connected(self) -> bool: - """Return true if connected.""" - return self._is_connected - - async def _publish(self, message: RoborockMessage) -> None: - """Simulate publishing a message and triggering a response.""" - self.published_messages.append(message) - if self.publish_side_effect: - raise self.publish_side_effect - # When a message is published, simulate a response - if self.response_queue: - response = self.response_queue.pop(0) - # Give a chance for the subscriber to be registered - for subscriber in list(self.subscribers): - subscriber(response) - - async def _subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: - """Simulate subscribing to messages.""" - self.subscribers.append(callback) - return lambda: self.subscribers.remove(callback) diff --git a/tests/devices/test_a01_channel.py b/tests/devices/test_a01_channel.py index e5bf112a..befa8951 100644 --- a/tests/devices/test_a01_channel.py +++ b/tests/devices/test_a01_channel.py @@ -11,8 +11,7 @@ RoborockMessage, RoborockMessageProtocol, ) - -from ..conftest import FakeChannel +from tests.fixtures.channel_fixtures import FakeChannel @pytest.fixture diff --git a/tests/devices/test_v1_channel.py b/tests/devices/test_v1_channel.py index 1152675f..c9c40cad 100644 --- a/tests/devices/test_v1_channel.py +++ b/tests/devices/test_v1_channel.py @@ -25,9 +25,8 @@ from roborock.protocols.v1_protocol import MapResponse, SecurityData, V1RpcChannel from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand - -from .. import mock_data -from ..conftest import FakeChannel +from tests import mock_data +from tests.fixtures.channel_fixtures import FakeChannel USER_DATA = UserData.from_dict(mock_data.USER_DATA) TEST_DEVICE_UID = "abc123" diff --git a/tests/devices/traits/a01/test_init.py b/tests/devices/traits/a01/test_init.py index 9b9c83b3..8e1cb7dd 100644 --- a/tests/devices/traits/a01/test_init.py +++ b/tests/devices/traits/a01/test_init.py @@ -8,7 +8,7 @@ from roborock.devices.traits.a01 import DyadApi, ZeoApi from roborock.roborock_message import RoborockDyadDataProtocol, RoborockMessageProtocol, RoborockZeoProtocol -from tests.conftest import FakeChannel +from tests.fixtures.channel_fixtures import FakeChannel from tests.protocols.common import build_a01_message diff --git a/tests/devices/traits/b01/test_init.py b/tests/devices/traits/b01/test_init.py index 7fa89f48..0eb0f5f2 100644 --- a/tests/devices/traits/b01/test_init.py +++ b/tests/devices/traits/b01/test_init.py @@ -12,7 +12,7 @@ from roborock.exceptions import RoborockException from roborock.protocols.b01_protocol import B01_VERSION from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol -from tests.conftest import FakeChannel +from tests.fixtures.channel_fixtures import FakeChannel def build_b01_message(message: dict[Any, Any], msg_id: str = "123456789", seq: int = 2020) -> RoborockMessage: diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py index c50cd2b3..38fd7fb6 100644 --- a/tests/e2e/__init__.py +++ b/tests/e2e/__init__.py @@ -1 +1,8 @@ """End-to-end tests package.""" + +pytest_plugins = [ + "tests.fixtures.logging_fixtures", + "tests.fixtures.local_async_fixtures", + "tests.fixtures.pahomqtt_fixtures", + "tests.fixtures.aiomqtt_fixtures", +] diff --git a/tests/e2e/test_local_session.py b/tests/e2e/test_local_session.py index 392279ff..347923da 100644 --- a/tests/e2e/test_local_session.py +++ b/tests/e2e/test_local_session.py @@ -1,17 +1,14 @@ """End-to-end tests for LocalChannel using fake sockets.""" import asyncio -from collections.abc import AsyncGenerator, Callable, Generator -from queue import Queue -from typing import Any -from unittest.mock import Mock, patch +from collections.abc import AsyncGenerator +from unittest.mock import patch import pytest from roborock.devices.local_channel import LocalChannel from roborock.protocol import create_local_decoder, create_local_encoder from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol -from tests.conftest import RequestHandler from tests.mock_data import LOCAL_KEY TEST_HOST = "192.168.1.100" @@ -21,35 +18,8 @@ TEST_RANDOM = 13579 -@pytest.fixture(name="mock_create_local_connection") -def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]: - """Fixture that overrides the transport creation to wire it up to the mock socket.""" - - async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]: - protocol = protocol_factory() - - def handle_write(data: bytes) -> None: - response = request_handler(data) - if response is not None: - # Call data_received directly to avoid loop scheduling issues in test - protocol.data_received(response) - - closed = asyncio.Event() - - mock_transport = Mock() - mock_transport.write = handle_write - mock_transport.close = closed.set - mock_transport.is_reading = lambda: not closed.is_set() - - return (mock_transport, protocol) - - with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop: - mock_loop.return_value.create_connection.side_effect = create_connection - yield - - @pytest.fixture(name="local_channel") -async def local_channel_fixture(mock_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]: +async def local_channel_fixture(mock_async_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]: with patch( "roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE, device_uid=TEST_DEVICE_UID ): @@ -80,19 +50,23 @@ def build_response( async def test_connect( - local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes] + local_channel: LocalChannel, + local_response_queue: asyncio.Queue[bytes], + local_received_requests: asyncio.Queue[bytes], ) -> None: """Test connecting to the device.""" # Queue HELLO response with payload to ensure it can be parsed - response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)) + local_response_queue.put_nowait( + build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM) + ) await local_channel.connect() assert local_channel.is_connected - assert received_requests.qsize() == 1 + assert local_received_requests.qsize() == 1 # Verify HELLO request - request_bytes = received_requests.get() + request_bytes = await local_received_requests.get() # Note: We cannot use create_local_decoder here because HELLO_REQUEST has payload=None # which causes MessageParser to fail parsing. For now we verify the raw bytes. @@ -104,17 +78,21 @@ async def test_connect( async def test_send_command( - local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes] + local_channel: LocalChannel, + local_response_queue: asyncio.Queue[bytes], + local_received_requests: asyncio.Queue[bytes], ) -> None: """Test sending a command.""" # Queue HELLO response - response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)) + local_response_queue.put_nowait( + build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM) + ) await local_channel.connect() # Clear requests from handshake - while not received_requests.empty(): - received_requests.get() + while not local_received_requests.empty(): + await local_received_requests.get() # Send command cmd_seq = 123 @@ -127,8 +105,8 @@ async def test_send_command( await local_channel.publish(msg) # Verify request - assert received_requests.qsize() == 1 - request_bytes = received_requests.get() + request_bytes = await local_received_requests.get() + assert local_received_requests.empty() # Decode request decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE) diff --git a/tests/e2e/test_mqtt_session.py b/tests/e2e/test_mqtt_session.py index e8a63843..294bf5e8 100644 --- a/tests/e2e/test_mqtt_session.py +++ b/tests/e2e/test_mqtt_session.py @@ -18,12 +18,12 @@ from roborock.protocol import MessageParser from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from tests import mqtt_packet +from tests.fixtures.mqtt import FAKE_PARAMS, Subscriber from tests.mock_data import LOCAL_KEY -from tests.mqtt_fixtures import FAKE_PARAMS, Subscriber @pytest.fixture(autouse=True) -def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None: +def auto_mock_mqtt_client(mock_aiomqtt_client: None) -> None: """Automatically use the mock mqtt client fixture.""" @@ -33,7 +33,7 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None: @pytest.fixture(autouse=True) -def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None: +def mqtt_server_fixture(mock_paho_mqtt_create_connection: None, mock_paho_mqtt_select: None) -> None: """Fixture to mock the MQTT connection. This is here to pull in the mock socket pixtures into all tests used here. @@ -42,10 +42,10 @@ def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None @pytest.fixture(name="session") async def session_fixture( - push_response: Callable[[bytes], None], + push_mqtt_response: Callable[[bytes], None], ) -> AsyncGenerator[MqttSession, None]: """Fixture to create a new connected MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) assert session.connected try: @@ -54,12 +54,12 @@ async def session_fixture( await session.close() -async def test_session_e2e_receive_message(push_response: Callable[[bytes], None], session: MqttSession) -> None: +async def test_session_e2e_receive_message(push_mqtt_response: Callable[[bytes], None], session: MqttSession) -> None: """Test receiving a real Roborock message through the session.""" assert session.connected # Subscribe to the topic. We'll next construct and push a message. - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) subscriber = Subscriber() await session.subscribe("topic-1", subscriber.append) @@ -71,12 +71,13 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False) # Simulate receiving the message from the broker - push_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload)) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload)) # Verify it was dispatched to the subscriber await subscriber.wait() assert len(subscriber.messages) == 1 received_payload = subscriber.messages[0] + assert isinstance(received_payload, bytes) assert received_payload == payload # Verify the message payload contents @@ -90,8 +91,8 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None async def test_session_e2e_publish_message( - push_response: Callable[[bytes], None], - received_requests: Queue, + push_mqtt_response: Callable[[bytes], None], + mqtt_received_requests: Queue, session: MqttSession, ) -> None: """Test publishing a real Roborock message.""" @@ -109,8 +110,8 @@ async def test_session_e2e_publish_message( # Verify what was sent to the broker # We expect the payload to be present in the sent bytes found = False - while not received_requests.empty(): - request = received_requests.get() + while not mqtt_received_requests.empty(): + request = mqtt_received_requests.get() if payload in request: found = True break diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mqtt_fixtures.py b/tests/fixtures/aiomqtt_fixtures.py similarity index 67% rename from tests/mqtt_fixtures.py rename to tests/fixtures/aiomqtt_fixtures.py index cb5ff17f..d9e10e74 100644 --- a/tests/mqtt_fixtures.py +++ b/tests/fixtures/aiomqtt_fixtures.py @@ -10,40 +10,11 @@ import paho.mqtt.client as mqtt import pytest -from roborock.mqtt.session import MqttParams -from tests.conftest import FakeSocketHandler +from .mqtt import FakeMqttSocketHandler -FAKE_PARAMS = MqttParams( - host="localhost", - port=1883, - tls=False, - username="username", - password="password", - timeout=10.0, -) - -class Subscriber: - """Mock subscriber class. - - We use this to hold on to received messages for verification. - """ - - def __init__(self) -> None: - self.messages: list[bytes] = [] - self._event = asyncio.Event() - - def append(self, message: bytes) -> None: - self.messages.append(message) - self._event.set() - - async def wait(self) -> None: - await asyncio.wait_for(self._event.wait(), timeout=1.0) - self._event.clear() - - -@pytest.fixture -async def mock_mqtt_client_fixture() -> AsyncGenerator[None, None]: +@pytest.fixture(name="mock_aiomqtt_client") +async def mock_aiomqtt_client_fixture() -> AsyncGenerator[None, None]: """Fixture to patch the MQTT underlying sync client. The tests use fake sockets, so this ensures that the async mqtt client does not @@ -94,11 +65,13 @@ def fast_backoff_fixture() -> Generator[None, None, None]: @pytest.fixture -def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]: +def push_mqtt_response( + mqtt_response_queue: Queue, fake_mqtt_socket_handler: FakeMqttSocketHandler +) -> Callable[[bytes], None]: """Fixture to push a response to the client.""" def _push(data: bytes) -> None: - response_queue.put(data) - fake_socket_handler.push_response() + mqtt_response_queue.put(data) + fake_mqtt_socket_handler.push_response() return _push diff --git a/tests/fixtures/channel_fixtures.py b/tests/fixtures/channel_fixtures.py new file mode 100644 index 00000000..1faae11c --- /dev/null +++ b/tests/fixtures/channel_fixtures.py @@ -0,0 +1,53 @@ +from collections.abc import Callable +from unittest.mock import AsyncMock, MagicMock + +from roborock.mqtt.health_manager import HealthManager +from roborock.protocols.v1_protocol import LocalProtocolVersion +from roborock.roborock_message import RoborockMessage + + +class FakeChannel: + """A fake channel that handles publish and subscribe calls.""" + + def __init__(self): + """Initialize the fake channel.""" + self.subscribers: list[Callable[[RoborockMessage], None]] = [] + self.published_messages: list[RoborockMessage] = [] + self.response_queue: list[RoborockMessage] = [] + self._is_connected = False + self.publish_side_effect: Exception | None = None + self.publish = AsyncMock(side_effect=self._publish) + self.subscribe = AsyncMock(side_effect=self._subscribe) + self.connect = AsyncMock(side_effect=self._connect) + self.close = MagicMock(side_effect=self._close) + self.protocol_version = LocalProtocolVersion.V1 + self.restart = AsyncMock() + self.health_manager = HealthManager(self.restart) + + async def _connect(self) -> None: + self._is_connected = True + + def _close(self) -> None: + self._is_connected = False + + @property + def is_connected(self) -> bool: + """Return true if connected.""" + return self._is_connected + + async def _publish(self, message: RoborockMessage) -> None: + """Simulate publishing a message and triggering a response.""" + self.published_messages.append(message) + if self.publish_side_effect: + raise self.publish_side_effect + # When a message is published, simulate a response + if self.response_queue: + response = self.response_queue.pop(0) + # Give a chance for the subscriber to be registered + for subscriber in list(self.subscribers): + subscriber(response) + + async def _subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: + """Simulate subscribing to messages.""" + self.subscribers.append(callback) + return lambda: self.subscribers.remove(callback) diff --git a/tests/fixtures/local_async_fixtures.py b/tests/fixtures/local_async_fixtures.py new file mode 100644 index 00000000..7be1f7e5 --- /dev/null +++ b/tests/fixtures/local_async_fixtures.py @@ -0,0 +1,77 @@ +import asyncio +from collections.abc import Awaitable, Callable, Generator +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +AsyncLocalRequestHandler = Callable[[bytes], Awaitable[bytes | None]] + + +@pytest.fixture(name="local_received_requests") +def received_requests_fixture() -> asyncio.Queue[bytes]: + """Fixture that provides access to the received requests.""" + return asyncio.Queue() + + +@pytest.fixture(name="local_response_queue") +def response_queue_fixture() -> Generator[asyncio.Queue[bytes], None, None]: + """Fixture that provides access to the received requests.""" + response_queue: asyncio.Queue[bytes] = asyncio.Queue() + yield response_queue + # assert response_queue.empty(), "Not all fake responses were consumed" + + +@pytest.fixture(name="local_async_request_handler") +def local_request_handler_fixture( + local_received_requests: asyncio.Queue[bytes], local_response_queue: asyncio.Queue[bytes] +) -> AsyncLocalRequestHandler: + """Fixture records incoming requests and replies with responses from the queue.""" + + async def handle_request(client_request: bytes) -> bytes | None: + """Handle an incoming request from the client.""" + local_received_requests.put_nowait(client_request) + + # Insert a prepared response into the response buffer + if not local_response_queue.empty(): + return await local_response_queue.get() + return None + + return handle_request + + +@pytest.fixture(name="mock_async_create_local_connection") +def create_local_connection_fixture( + local_async_request_handler: AsyncLocalRequestHandler, +) -> Generator[None, None, None]: + """Fixture that overrides the transport creation to wire it up to the mock socket.""" + + tasks = [] + + async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]: + protocol = protocol_factory() + + async def handle_write(data: bytes) -> None: + response = await local_async_request_handler(data) + if response is not None: + # Call data_received directly to avoid loop scheduling issues in test + protocol.data_received(response) + + def start_handle_write(data: bytes) -> None: + tasks.append(asyncio.create_task(handle_write(data))) + + closed = asyncio.Event() + + mock_transport = Mock() + mock_transport.write = start_handle_write + mock_transport.close = closed.set + mock_transport.is_reading = lambda: not closed.is_set() + + return (mock_transport, protocol) + + with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.create_connection.side_effect = create_connection + yield + + for task in tasks: + task.cancel() diff --git a/tests/fixtures/logging.py b/tests/fixtures/logging.py new file mode 100644 index 00000000..2d0279a6 --- /dev/null +++ b/tests/fixtures/logging.py @@ -0,0 +1,60 @@ +"""Logging utilities for tests.""" + + +class CapturedRequestLog: + """Log of requests and responses for snapshot assertions. + + The log captures the raw bytes of each request and response along with + a label indicating the direction of the message. + """ + + def __init__(self) -> None: + """Initialize the request log.""" + self.entries: list[tuple[str, bytes]] = [] + + def add_log_entry(self, label: str, data: bytes) -> None: + """Add a request entry.""" + self.entries.append((label, data)) + + def __repr__(self): + """Return a string representation of the log entries. + + This assumes that the client will behave in a request-response manner, + so each request is followed by a response. If a test uses non-deterministic + message order, this may not be accurate and the test would need to decode + the raw messages and remove any ordering assumptions. + """ + lines = [] + for label, data in self.entries: + lines.append(label) + lines.extend(self._hexdump(data)) + return "\n".join(lines) + + def _hexdump(self, data: bytes, bytes_per_line: int = 16) -> list[str]: + """Print a hexdump of the given bytes object in a tcpdump/hexdump -C style. + + This makes the packets easier to read and compare in test snapshots. + + Args: + data: The bytes object to print. + bytes_per_line: The number of bytes to display per line (default is 16). + """ + + # Use '.' for non-printable characters (ASCII < 32 or > 126) + def to_printable_ascii(byte_val): + return chr(byte_val) if 32 <= byte_val <= 126 else "." + + offset = 0 + lines = [] + while offset < len(data): + chunk = data[offset : offset + bytes_per_line] + # Format the hex values, space-padded to ensure alignment + hex_values = " ".join(f"{byte:02x}" for byte in chunk) + # Pad hex string to a fixed width so ASCII column lines up + # 3 chars per byte ('xx ') for a full line of 16 bytes + padded_hex = f"{hex_values:<{bytes_per_line * 3}}" + # Format the ASCII values + ascii_values = "".join(to_printable_ascii(byte) for byte in chunk) + lines.append(f"{offset:08x} {padded_hex} |{ascii_values}|") + offset += bytes_per_line + return lines diff --git a/tests/fixtures/logging_fixtures.py b/tests/fixtures/logging_fixtures.py new file mode 100644 index 00000000..0a6ff112 --- /dev/null +++ b/tests/fixtures/logging_fixtures.py @@ -0,0 +1,126 @@ +from collections.abc import Generator +from unittest.mock import patch + +import pytest + +# Fixed timestamp for deterministic tests for asserting on message contents +FAKE_TIMESTAMP = 1755750946.721395 + + +class CapturedRequestLog: + """Log of requests and responses for snapshot assertions. + + The log captures the raw bytes of each request and response along with + a label indicating the direction of the message. + """ + + def __init__(self) -> None: + """Initialize the request log.""" + self.entries: list[tuple[str, bytes]] = [] + + def add_log_entry(self, label: str, data: bytes) -> None: + """Add a request entry.""" + self.entries.append((label, data)) + + def __repr__(self): + """Return a string representation of the log entries. + + This assumes that the client will behave in a request-response manner, + so each request is followed by a response. If a test uses non-deterministic + message order, this may not be accurate and the test would need to decode + the raw messages and remove any ordering assumptions. + """ + lines = [] + for label, data in self.entries: + lines.append(label) + lines.extend(self._hexdump(data)) + return "\n".join(lines) + + def _hexdump(self, data: bytes, bytes_per_line: int = 16) -> list[str]: + """Print a hexdump of the given bytes object in a tcpdump/hexdump -C style. + + This makes the packets easier to read and compare in test snapshots. + + Args: + data: The bytes object to print. + bytes_per_line: The number of bytes to display per line (default is 16). + """ + + # Use '.' for non-printable characters (ASCII < 32 or > 126) + def to_printable_ascii(byte_val): + return chr(byte_val) if 32 <= byte_val <= 126 else "." + + offset = 0 + lines = [] + while offset < len(data): + chunk = data[offset : offset + bytes_per_line] + # Format the hex values, space-padded to ensure alignment + hex_values = " ".join(f"{byte:02x}" for byte in chunk) + # Pad hex string to a fixed width so ASCII column lines up + # 3 chars per byte ('xx ') for a full line of 16 bytes + padded_hex = f"{hex_values:<{bytes_per_line * 3}}" + # Format the ASCII values + ascii_values = "".join(to_printable_ascii(byte) for byte in chunk) + lines.append(f"{offset:08x} {padded_hex} |{ascii_values}|") + offset += bytes_per_line + return lines + + +@pytest.fixture +def deterministic_message_fixtures() -> Generator[None, None, None]: + """Fixture to use predictable get_next_int and timestamp values for each test. + + This test mocks out the functions used to generate requests that have some + entropy such as the nonces, timestamps, and request IDs. This makes the + generated messages deterministic so we can snapshot them in a test. + """ + + # Pick an arbitrary sequence number used for outgoing requests + next_int = 9090 + + def get_next_int(min_value: int, max_value: int) -> int: + nonlocal next_int + result = next_int + next_int += 1 + if next_int > max_value: + next_int = min_value + return result + + # Pick an arbitrary timestamp used for the message encryption + timestamp = FAKE_TIMESTAMP + + def get_timestamp() -> int: + """Get a monotonically increasing timestamp for testing.""" + nonlocal timestamp + timestamp += 1 + return int(timestamp) + + # Use predictable seeds for token_bytes + token_chr = "A" + + def get_token_bytes(n: int) -> bytes: + nonlocal token_chr + result = token_chr.encode() * n + # Cycle to the next character + token_chr = chr(ord(token_chr) + 1) + if token_chr > "Z": + token_chr = "A" + return result + + with ( + patch("roborock.api.get_next_int", side_effect=get_next_int), + patch("roborock.devices.local_channel.get_next_int", side_effect=get_next_int), + patch("roborock.protocols.v1_protocol.get_next_int", side_effect=get_next_int), + patch("roborock.protocols.v1_protocol.get_timestamp", side_effect=get_timestamp), + patch("roborock.protocols.v1_protocol.secrets.token_bytes", side_effect=get_token_bytes), + patch("roborock.version_1_apis.roborock_local_client_v1.get_next_int", side_effect=get_next_int), + patch("roborock.roborock_message.get_next_int", side_effect=get_next_int), + patch("roborock.roborock_message.get_timestamp", side_effect=get_timestamp), + ): + yield + + +@pytest.fixture(name="log") +def log_fixture(deterministic_message_fixtures: None) -> CapturedRequestLog: + """Fixture that creates a captured request log.""" + return CapturedRequestLog() diff --git a/tests/fixtures/mqtt.py b/tests/fixtures/mqtt.py new file mode 100644 index 00000000..793bb3f4 --- /dev/null +++ b/tests/fixtures/mqtt.py @@ -0,0 +1,101 @@ +"""Common code for MQTT tests.""" + +import asyncio +import io +import logging +from collections.abc import Callable +from queue import Queue + +from roborock.mqtt.session import MqttParams +from roborock.roborock_message import RoborockMessage + +from .logging import CapturedRequestLog + +_LOGGER = logging.getLogger(__name__) + +# Used by fixtures to handle incoming requests and prepare responses +MqttRequestHandler = Callable[[bytes], bytes | None] + + +class FakeMqttSocketHandler: + """Fake socket used by the test to simulate a connection to the broker. + + The socket handler is used to intercept the socket send and recv calls and + populate the response buffer with data to be sent back to the client. The + handle request callback handles the incoming requests and prepares the responses. + """ + + def __init__( + self, handle_request: MqttRequestHandler, response_queue: Queue[bytes], log: CapturedRequestLog + ) -> None: + self.response_buf = io.BytesIO() + self.handle_request = handle_request + self.response_queue = response_queue + self.log = log + + def pending(self) -> int: + """Return the number of bytes in the response buffer.""" + return len(self.response_buf.getvalue()) + + def handle_socket_recv(self, read_size: int) -> bytes: + """Intercept a client recv() and populate the buffer.""" + if self.pending() == 0: + raise BlockingIOError("No response queued") + + self.response_buf.seek(0) + data = self.response_buf.read(read_size) + _LOGGER.debug("Response: 0x%s", data.hex()) + # Consume the rest of the data in the buffer + remaining_data = self.response_buf.read() + self.response_buf = io.BytesIO(remaining_data) + return data + + def handle_socket_send(self, client_request: bytes) -> int: + """Receive an incoming request from the client.""" + _LOGGER.debug("Request: 0x%s", client_request.hex()) + self.log.add_log_entry("[mqtt >]", client_request) + if (response := self.handle_request(client_request)) is not None: + # Enqueue a response to be sent back to the client in the buffer. + # The buffer will be emptied when the client calls recv() on the socket + _LOGGER.debug("Queued: 0x%s", response.hex()) + self.log.add_log_entry("[mqtt <]", response) + self.response_buf.write(response) + return len(client_request) + + def push_response(self) -> None: + """Push a response to the client.""" + if not self.response_queue.empty(): + response = self.response_queue.get() + # Enqueue a response to be sent back to the client in the buffer. + # The buffer will be emptied when the client calls recv() on the socket + _LOGGER.debug("Queued: 0x%s", response.hex()) + self.response_buf.write(response) + + +FAKE_PARAMS = MqttParams( + host="localhost", + port=1883, + tls=False, + username="username", + password="password", + timeout=10.0, +) + + +class Subscriber: + """Mock subscriber class. + + We use this to hold on to received messages for verification. + """ + + def __init__(self) -> None: + self.messages: list[RoborockMessage | bytes] = [] + self._event = asyncio.Event() + + def append(self, message: RoborockMessage | bytes) -> None: + self.messages.append(message) + self._event.set() + + async def wait(self) -> None: + await asyncio.wait_for(self._event.wait(), timeout=1.0) + self._event.clear() diff --git a/tests/fixtures/pahomqtt_fixtures.py b/tests/fixtures/pahomqtt_fixtures.py new file mode 100644 index 00000000..b0a50038 --- /dev/null +++ b/tests/fixtures/pahomqtt_fixtures.py @@ -0,0 +1,97 @@ +"""Common code for MQTT tests.""" + +import logging +from collections.abc import Callable, Generator +from queue import Queue +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from .logging import CapturedRequestLog +from .mqtt import FakeMqttSocketHandler + +pytest_plugins = [ + "tests.fixtures.logging_fixtures", +] + +_LOGGER = logging.getLogger(__name__) + +# Used by fixtures to handle incoming requests and prepare responses +MqttRequestHandler = Callable[[bytes], bytes | None] + + +@pytest.fixture(name="mock_paho_mqtt_create_connection") +def create_connection_fixture(mock_sock: Mock) -> Generator[None, None, None]: + """Fixture that overrides the MQTT socket creation to wire it up to the mock socket.""" + with patch("paho.mqtt.client.socket.create_connection", return_value=mock_sock): + yield + + +@pytest.fixture(name="mock_paho_mqtt_select") +def select_fixture(mock_sock: Mock, fake_mqtt_socket_handler: FakeMqttSocketHandler) -> Generator[None, None, None]: + """Fixture that overrides the MQTT client select calls to make select work on the mock socket. + + This patch select to activate our mock socket when ready with data. Internal mqtt sockets are + always ready since they are used internally to wake the select loop. Ours is ready if there + is data in the buffer. + """ + + def is_ready(sock: Any) -> bool: + return sock is not mock_sock or (fake_mqtt_socket_handler.pending() > 0) + + def handle_select(rlist: list, wlist: list, *args: Any) -> list: + return [list(filter(is_ready, rlist)), list(filter(is_ready, wlist))] + + with patch("paho.mqtt.client.select.select", side_effect=handle_select): + yield + + +@pytest.fixture(name="fake_mqtt_socket_handler") +def fake_mqtt_socket_handler_fixture( + mqtt_request_handler: MqttRequestHandler, mqtt_response_queue: Queue[bytes], log: CapturedRequestLog +) -> FakeMqttSocketHandler: + """Fixture that creates a fake MQTT broker.""" + return FakeMqttSocketHandler(mqtt_request_handler, mqtt_response_queue, log) + + +@pytest.fixture(name="mock_sock") +def mock_sock_fixture(fake_mqtt_socket_handler: FakeMqttSocketHandler) -> Mock: + """Fixture that creates a mock socket connection and wires it to the handler.""" + mock_sock = Mock() + mock_sock.recv = fake_mqtt_socket_handler.handle_socket_recv + mock_sock.send = fake_mqtt_socket_handler.handle_socket_send + mock_sock.pending = fake_mqtt_socket_handler.pending + return mock_sock + + +@pytest.fixture(name="mqtt_received_requests") +def received_requests_fixture() -> Queue[bytes]: + """Fixture that provides access to the received requests.""" + return Queue() + + +@pytest.fixture(name="mqtt_response_queue") +def response_queue_fixture() -> Generator[Queue[bytes], None, None]: + """Fixture that provides access to the received requests.""" + response_queue: Queue[bytes] = Queue() + yield response_queue + assert response_queue.empty(), "Not all fake responses were consumed" + + +@pytest.fixture(name="mqtt_request_handler") +def mqtt_request_handler_fixture( + mqtt_received_requests: Queue[bytes], mqtt_response_queue: Queue[bytes] +) -> MqttRequestHandler: + """Fixture records incoming requests and replies with responses from the queue.""" + + def handle_request(client_request: bytes) -> bytes | None: + """Handle an incoming request from the client.""" + mqtt_received_requests.put(client_request) + + # Insert a prepared response into the response buffer + if not mqtt_response_queue.empty(): + return mqtt_response_queue.get() + return None + + return handle_request diff --git a/tests/fixtures/web_api_fixtures.py b/tests/fixtures/web_api_fixtures.py new file mode 100644 index 00000000..0b071387 --- /dev/null +++ b/tests/fixtures/web_api_fixtures.py @@ -0,0 +1,141 @@ +import re +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +import pytest +from aioresponses import aioresponses + +from tests.mock_data import HOME_DATA_RAW, HOME_DATA_SCENES_RAW, USER_DATA + + +@pytest.fixture +def skip_rate_limit() -> Generator[None, None, None]: + """Don't rate limit tests as they aren't actually hitting the api.""" + with ( + patch("roborock.web_api.RoborockApiClient._login_limiter.try_acquire"), + patch("roborock.web_api.RoborockApiClient._home_data_limiter.try_acquire"), + ): + yield + + +@pytest.fixture(name="mock_rest") +def mock_rest_fixture(skip_rate_limit: Any) -> aioresponses: + """Mock all rest endpoints so they won't hit real endpoints""" + with aioresponses() as mocked: + # Match the base URL and allow any query params + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v1/getUrlByEmail.*"), + status=200, + payload={ + "code": 200, + "data": {"country": "US", "countrycode": "1", "url": "https://usiot.roborock.com"}, + "msg": "success", + }, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v1/login.*"), + status=200, + payload={"code": 200, "data": USER_DATA, "msg": "success"}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v1/loginWithCode.*"), + status=200, + payload={"code": 200, "data": USER_DATA, "msg": "success"}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v1/sendEmailCode.*"), + status=200, + payload={"code": 200, "data": None, "msg": "success"}, + ) + mocked.get( + re.compile(r"https://.*iot\.roborock\.com/api/v1/getHomeDetail.*"), + status=200, + payload={ + "code": 200, + "data": {"deviceListOrder": None, "id": 123456, "name": "My Home", "rrHomeId": 123456, "tuyaHomeId": 0}, + "msg": "success", + }, + ) + mocked.get( + re.compile(r"https://api-.*\.roborock\.com/v2/user/homes*"), + status=200, + payload={"api": None, "code": 200, "result": HOME_DATA_RAW, "status": "ok", "success": True}, + ) + mocked.post( + re.compile(r"https://api-.*\.roborock\.com/nc/prepare"), + status=200, + payload={ + "api": None, + "result": {"r": "US", "s": "ffffff", "t": "eOf6d2BBBB"}, + "status": "ok", + "success": True, + }, + ) + + mocked.get( + re.compile(r"https://api-.*\.roborock\.com/user/devices/newadd/*"), + status=200, + payload={ + "api": "获取新增设备信息", + "result": { + "activeTime": 1737724598, + "attribute": None, + "cid": None, + "createTime": 0, + "deviceStatus": None, + "duid": "rand_duid", + "extra": "{}", + "f": False, + "featureSet": "0", + "fv": "02.16.12", + "iconUrl": "", + "lat": None, + "localKey": "random_lk", + "lon": None, + "name": "S7", + "newFeatureSet": "0000000000002000", + "online": True, + "productId": "rand_prod_id", + "pv": "1.0", + "roomId": None, + "runtimeEnv": None, + "setting": None, + "share": False, + "shareTime": None, + "silentOtaSwitch": False, + "sn": "Rand_sn", + "timeZoneId": "America/New_York", + "tuyaMigrated": False, + "tuyaUuid": None, + }, + "status": "ok", + "success": True, + }, + ) + mocked.get( + re.compile(r"https://api-.*\.roborock\.com/user/scene/device/.*"), + status=200, + payload={"api": None, "code": 200, "result": HOME_DATA_SCENES_RAW, "status": "ok", "success": True}, + ) + mocked.post( + re.compile(r"https://api-.*\.roborock\.com/user/scene/.*/execute"), + status=200, + payload={"api": None, "code": 200, "result": None, "status": "ok", "success": True}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v4/email/code/send.*"), + status=200, + payload={"code": 200, "data": None, "msg": "success"}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v3/key/sign.*"), + status=200, + payload={"code": 200, "data": {"k": "mock_k"}, "msg": "success"}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v4/auth/email/login/code.*"), + status=200, + payload={"code": 200, "data": USER_DATA, "msg": "success"}, + ) + yield mocked diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py index 3f8a9485..24e6fbbd 100644 --- a/tests/mqtt/test_roborock_session.py +++ b/tests/mqtt/test_roborock_session.py @@ -13,16 +13,27 @@ from roborock.mqtt.roborock_session import RoborockMqttSession, create_mqtt_session from roborock.mqtt.session import MqttSessionException, MqttSessionUnauthorized from tests import mqtt_packet -from tests.mqtt_fixtures import FAKE_PARAMS, Subscriber +from tests.fixtures.mqtt import FAKE_PARAMS, Subscriber + +pytest_plugins = [ + "tests.fixtures.logging_fixtures", + "tests.fixtures.pahomqtt_fixtures", + "tests.fixtures.aiomqtt_fixtures", +] @pytest.fixture(autouse=True) -def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None: +def mqtt_server_fixture( + mock_paho_mqtt_create_connection: None, + mock_paho_mqtt_select: None, +) -> None: """Fixture to prepare a fake MQTT server.""" @pytest.fixture(autouse=True) -def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None: +def auto_mock_aiomqtt_client( + mock_aiomqtt_client: None, +) -> None: """Automatically use the mock mqtt client fixture.""" @@ -31,9 +42,12 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None: """Automatically use the fast backoff fixture.""" -@pytest.fixture -def mock_mqtt_client() -> Generator[AsyncMock, None, None]: - """Fixture to create a mock MQTT client with patched aiomqtt.Client.""" +@pytest.fixture(name="mqtt_client_lite") +def mqtt_client_lite_fixture() -> Generator[AsyncMock, None, None]: + """A fixture that provides a mocked aiomqtt Client. + + This is lighter weight that `mock_aiomqtt_client` that uses real sockets. + """ mock_client = AsyncMock() mock_client.messages = FakeAsyncIterator() @@ -48,38 +62,38 @@ def mock_mqtt_client() -> Generator[AsyncMock, None, None]: yield mock_client -async def test_session(push_response: Callable[[bytes], None]) -> None: +async def test_session(push_mqtt_response: Callable[[bytes], None]) -> None: """Test the MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) assert session.connected - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) subscriber1 = Subscriber() unsub1 = await session.subscribe("topic-1", subscriber1.append) - push_response(mqtt_packet.gen_suback(mid=2)) + push_mqtt_response(mqtt_packet.gen_suback(mid=2)) subscriber2 = Subscriber() await session.subscribe("topic-2", subscriber2.append) - push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) await subscriber1.wait() assert subscriber1.messages == [b"12345"] assert not subscriber2.messages - push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) + push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) await subscriber2.wait() assert subscriber2.messages == [b"67890"] - push_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC")) await subscriber1.wait() assert subscriber1.messages == [b"12345", b"ABC"] assert subscriber2.messages == [b"67890"] # Messages are no longer received after unsubscribing unsub1() - push_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) assert subscriber1.messages == [b"12345", b"ABC"] assert session.connected @@ -87,12 +101,12 @@ async def test_session(push_response: Callable[[bytes], None]) -> None: assert not session.connected -async def test_session_no_subscribers(push_response: Callable[[bytes], None]) -> None: +async def test_session_no_subscribers(push_mqtt_response: Callable[[bytes], None]) -> None: """Test the MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) - push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) - push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) session = await create_mqtt_session(FAKE_PARAMS) assert session.connected @@ -100,13 +114,13 @@ async def test_session_no_subscribers(push_response: Callable[[bytes], None]) -> assert not session.connected -async def test_publish_command(push_response: Callable[[bytes], None]) -> None: +async def test_publish_command(push_mqtt_response: Callable[[bytes], None]) -> None: """Test publishing during an MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) - push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) await session.publish("topic-1", message=b"payload") assert session.connected @@ -129,13 +143,13 @@ async def __anext__(self) -> None: await asyncio.sleep(1) -async def test_publish_failure(mock_mqtt_client: AsyncMock) -> None: +async def test_publish_failure(mqtt_client_lite: AsyncMock) -> None: """Test an MQTT error is received when publishing a message.""" session = await create_mqtt_session(FAKE_PARAMS) assert session.connected - mock_mqtt_client.publish.side_effect = aiomqtt.MqttError + mqtt_client_lite.publish.side_effect = aiomqtt.MqttError with pytest.raises(MqttSessionException, match="Error publishing message"): await session.publish("topic-1", message=b"payload") @@ -143,13 +157,13 @@ async def test_publish_failure(mock_mqtt_client: AsyncMock) -> None: await session.close() -async def test_subscribe_failure(mock_mqtt_client: AsyncMock) -> None: +async def test_subscribe_failure(mqtt_client_lite: AsyncMock) -> None: """Test an MQTT error while subscribing.""" session = await create_mqtt_session(FAKE_PARAMS) assert session.connected - mock_mqtt_client.subscribe.side_effect = aiomqtt.MqttError + mqtt_client_lite.subscribe.side_effect = aiomqtt.MqttError subscriber1 = Subscriber() with pytest.raises(MqttSessionException, match="Error subscribing to topic"): @@ -159,20 +173,20 @@ async def test_subscribe_failure(mock_mqtt_client: AsyncMock) -> None: await session.close() -async def test_restart(push_response: Callable[[bytes], None]) -> None: +async def test_restart(push_mqtt_response: Callable[[bytes], None]) -> None: """Test restarting the MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) assert session.connected # Subscribe to a topic - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) subscriber = Subscriber() await session.subscribe("topic-1", subscriber.append) # Verify we can receive messages - push_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=b"12345")) await subscriber.wait() assert subscriber.messages == [b"12345"] @@ -184,20 +198,20 @@ async def test_restart(push_response: Callable[[bytes], None]) -> None: await asyncio.sleep(0.01) # We need to queue up a new connack for the reconnection - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) # And a suback for the resubscription. Since we created a new client, # the message ID resets to 1. - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) - push_response(mqtt_packet.gen_publish("topic-1", mid=4, payload=b"67890")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=4, payload=b"67890")) await subscriber.wait() assert subscriber.messages == [b"12345", b"67890"] await session.close() -async def test_idle_timeout_resubscribe(mock_mqtt_client: AsyncMock) -> None: +async def test_idle_timeout_resubscribe(mqtt_client_lite: AsyncMock) -> None: """Test that resubscribing before idle timeout cancels the unsubscribe.""" # Create session with idle timeout @@ -220,12 +234,12 @@ async def test_idle_timeout_resubscribe(mock_mqtt_client: AsyncMock) -> None: await asyncio.sleep(0.01) # unsubscribe should NOT have been called because we resubscribed - mock_mqtt_client.unsubscribe.assert_not_called() + mqtt_client_lite.unsubscribe.assert_not_called() await session.close() -async def test_idle_timeout_unsubscribe(mock_mqtt_client: AsyncMock) -> None: +async def test_idle_timeout_unsubscribe(mqtt_client_lite: AsyncMock) -> None: """Test that unsubscribe happens after idle timeout expires.""" # Create session with very short idle timeout for fast test @@ -244,12 +258,12 @@ async def test_idle_timeout_unsubscribe(mock_mqtt_client: AsyncMock) -> None: await asyncio.sleep(0.1) # unsubscribe should have been called after idle timeout - mock_mqtt_client.unsubscribe.assert_called_once_with(topic) + mqtt_client_lite.unsubscribe.assert_called_once_with(topic) await session.close() -async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> None: +async def test_idle_timeout_multiple_callbacks(mqtt_client_lite: AsyncMock) -> None: """Test that unsubscribe is delayed when multiple subscribers exist.""" # Create session with very short idle timeout for fast test @@ -271,7 +285,7 @@ async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> N await asyncio.sleep(0.1) # unsubscribe should NOT have been called because subscriber2 is still active - mock_mqtt_client.unsubscribe.assert_not_called() + mqtt_client_lite.unsubscribe.assert_not_called() # Unsubscribe second callback (NOW timer should start) unsub2() @@ -280,12 +294,12 @@ async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> N await asyncio.sleep(0.1) # Now unsubscribe should have been called - mock_mqtt_client.unsubscribe.assert_called_once_with(topic) + mqtt_client_lite.unsubscribe.assert_called_once_with(topic) await session.close() -async def test_subscription_reuse(mock_mqtt_client: AsyncMock) -> None: +async def test_subscription_reuse(mqtt_client_lite: AsyncMock) -> None: """Test that subscriptions are reused and not duplicated.""" session = RoborockMqttSession(FAKE_PARAMS) await session.start() @@ -296,32 +310,32 @@ async def test_subscription_reuse(mock_mqtt_client: AsyncMock) -> None: unsub1 = await session.subscribe("topic1", cb1) # Verify subscribe called - mock_mqtt_client.subscribe.assert_called_with("topic1") - mock_mqtt_client.subscribe.reset_mock() + mqtt_client_lite.subscribe.assert_called_with("topic1") + mqtt_client_lite.subscribe.reset_mock() # 2. Second subscription (same topic) cb2 = Mock() unsub2 = await session.subscribe("topic1", cb2) # Verify subscribe NOT called - mock_mqtt_client.subscribe.assert_not_called() + mqtt_client_lite.subscribe.assert_not_called() # 3. Unsubscribe one unsub1() # Verify unsubscribe NOT called (still have cb2) - mock_mqtt_client.unsubscribe.assert_not_called() + mqtt_client_lite.unsubscribe.assert_not_called() # 4. Unsubscribe second (starts idle timer) unsub2() # Verify unsubscribe NOT called yet (idle) - mock_mqtt_client.unsubscribe.assert_not_called() + mqtt_client_lite.unsubscribe.assert_not_called() # 5. Resubscribe during idle cb3 = Mock() _ = await session.subscribe("topic1", cb3) # Verify subscribe NOT called (reused) - mock_mqtt_client.subscribe.assert_not_called() + mqtt_client_lite.subscribe.assert_not_called() await session.close() @@ -365,7 +379,7 @@ async def test_connect_failure( await create_mqtt_session(FAKE_PARAMS) -async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None: +async def test_diagnostics_data(push_mqtt_response: Callable[[bytes], None]) -> None: """Test the MQTT session.""" diagnostics = Diagnostics() @@ -373,7 +387,7 @@ async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None: params = copy.deepcopy(FAKE_PARAMS) params.diagnostics = diagnostics - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(params) assert session.connected @@ -386,24 +400,24 @@ async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None: assert data.get("dispatch_message_count") is None assert data.get("close") is None - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) subscriber1 = Subscriber() unsub1 = await session.subscribe("topic-1", subscriber1.append) - push_response(mqtt_packet.gen_suback(mid=2)) + push_mqtt_response(mqtt_packet.gen_suback(mid=2)) subscriber2 = Subscriber() await session.subscribe("topic-2", subscriber2.append) - push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) await subscriber1.wait() assert subscriber1.messages == [b"12345"] assert not subscriber2.messages - push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) + push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) await subscriber2.wait() assert subscriber2.messages == [b"67890"] - push_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC")) await subscriber1.wait() assert subscriber1.messages == [b"12345", b"ABC"] assert subscriber2.messages == [b"67890"] @@ -418,7 +432,7 @@ async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None: # Messages are no longer received after unsubscribing unsub1() - push_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) assert subscriber1.messages == [b"12345", b"ABC"] assert session.connected diff --git a/tests/test_a01_api.py b/tests/test_a01_api.py index 6c1f7dac..911a1791 100644 --- a/tests/test_a01_api.py +++ b/tests/test_a01_api.py @@ -20,9 +20,10 @@ RoborockZeoProtocol, ) from roborock.version_a01_apis import RoborockMqttClientA01 +from tests.fixtures.logging import CapturedRequestLog +from tests.protocols.common import build_a01_message from . import mqtt_packet -from .conftest import QUEUE_TIMEOUT, CapturedRequestLog from .mock_data import ( HOME_DATA_RAW, LOCAL_KEY, @@ -31,11 +32,16 @@ WASHER_PRODUCT, ZEO_ONE_DEVICE, ) -from .protocols.common import build_a01_message +QUEUE_TIMEOUT = 10 RELEASE_TIMEOUT = 2 +pytest_plugins = [ + "tests.fixtures.pahomqtt_fixtures", +] + + @pytest.fixture(name="category") def category_fixture() -> RoborockCategory: return RoborockCategory.WASHING_MACHINE @@ -43,8 +49,8 @@ def category_fixture() -> RoborockCategory: @pytest.fixture(name="a01_mqtt_client") async def a01_mqtt_client_fixture( - mock_create_connection: None, - mock_select: None, + mock_paho_mqtt_create_connection: None, + mock_paho_mqtt_select: None, category: RoborockCategory, ) -> AsyncGenerator[RoborockMqttClientA01, None]: user_data = UserData.from_dict(USER_DATA) @@ -74,15 +80,15 @@ async def a01_mqtt_client_fixture( @pytest.fixture(name="connected_a01_mqtt_client") async def connected_a01_mqtt_client_fixture( - response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 + mqtt_response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 ) -> AsyncGenerator[RoborockMqttClientA01, None]: - response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) - response_queue.put(mqtt_packet.gen_suback(1, 0)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) + mqtt_response_queue.put(mqtt_packet.gen_suback(1, 0)) await a01_mqtt_client.async_connect() yield a01_mqtt_client -async def test_async_connect(received_requests: Queue, connected_a01_mqtt_client: RoborockMqttClientA01) -> None: +async def test_async_connect(mqtt_received_requests: Queue, connected_a01_mqtt_client: RoborockMqttClientA01) -> None: """Test connecting to the MQTT broker.""" assert connected_a01_mqtt_client.is_connected() @@ -95,20 +101,20 @@ async def test_async_connect(received_requests: Queue, connected_a01_mqtt_client # Broker received a connect and subscribe. Disconnect packet is not # guaranteed to be captured by the time the async_disconnect returns - assert received_requests.qsize() >= 2 # Connect and Subscribe + assert mqtt_received_requests.qsize() >= 2 # Connect and Subscribe async def test_connect_failure( - received_requests: Queue, response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 + mqtt_received_requests: Queue, mqtt_response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 ) -> None: """Test the broker responding with a connect failure.""" - response_queue.put(mqtt_packet.gen_connack(rc=1)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=1)) with pytest.raises(RoborockException, match="Failed to connect"): await a01_mqtt_client.async_connect() assert not a01_mqtt_client.is_connected() - assert received_requests.qsize() == 1 # Connect attempt + assert mqtt_received_requests.qsize() == 1 # Connect attempt async def test_disconnect_already_disconnected(connected_a01_mqtt_client: RoborockMqttClientA01) -> None: @@ -141,11 +147,11 @@ async def test_async_release(connected_a01_mqtt_client: RoborockMqttClientA01) - async def test_subscribe_failure( - received_requests: Queue, response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 + mqtt_received_requests: Queue, mqtt_response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 ) -> None: """Test the broker responding with the wrong message type on subscribe.""" - response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) with ( patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)), @@ -153,7 +159,7 @@ async def test_subscribe_failure( ): await a01_mqtt_client.async_connect() - assert received_requests.qsize() == 1 # Connect attempt + assert mqtt_received_requests.qsize() == 1 # Connect attempt # NOTE: The client is "connected" but not "subscribed" and cannot recover # from this state without disconnecting first. This can likely be improved. @@ -162,7 +168,7 @@ async def test_subscribe_failure( # Attempting to reconnect is a no-op since the client already thinks it is connected await a01_mqtt_client.async_connect() assert a01_mqtt_client.is_connected() - assert received_requests.qsize() == 1 + assert mqtt_received_requests.qsize() == 1 def build_rpc_response(message: dict[Any, Any]) -> bytes: @@ -171,8 +177,8 @@ def build_rpc_response(message: dict[Any, Any]) -> bytes: async def test_update_zeo_values( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_a01_mqtt_client: RoborockMqttClientA01, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, @@ -189,7 +195,7 @@ async def test_update_zeo_values( 218: 0, # Washing left. Testing zero int value } ) - response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) + mqtt_response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) data = await connected_a01_mqtt_client.update_values( [ @@ -214,8 +220,8 @@ async def test_update_zeo_values( @pytest.mark.parametrize("category", [RoborockCategory.WET_DRY_VAC]) async def test_update_dyad_values( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_a01_mqtt_client: RoborockMqttClientA01, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, @@ -231,7 +237,7 @@ async def test_update_dyad_values( 224: 0, # AUTO_DRY_MODE off } ) - response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) + mqtt_response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) data = await connected_a01_mqtt_client.update_values( [ @@ -253,25 +259,25 @@ async def test_update_dyad_values( async def test_set_value( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_a01_mqtt_client: RoborockMqttClientA01, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, ) -> None: """Test sending an arbitrary MQTT message and parsing the response.""" # Clear existing messages received during setup - assert received_requests.qsize() == 2 - assert received_requests.get(block=True, timeout=QUEUE_TIMEOUT) - assert received_requests.get(block=True, timeout=QUEUE_TIMEOUT) - assert received_requests.empty() + assert mqtt_received_requests.qsize() == 2 + assert mqtt_received_requests.get(block=True, timeout=QUEUE_TIMEOUT) + assert mqtt_received_requests.get(block=True, timeout=QUEUE_TIMEOUT) + assert mqtt_received_requests.empty() # Prepare the response message message = build_rpc_response({}) - response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) + mqtt_response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) await connected_a01_mqtt_client.set_value(RoborockZeoProtocol.STATE, "spinning") - assert received_requests.get(block=True) + assert mqtt_received_requests.get(block=True) assert snapshot == log diff --git a/tests/test_api.py b/tests/test_api.py index b428d2fd..2f1b3c92 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -23,6 +23,7 @@ from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from roborock.version_1_apis import RoborockMqttClientV1 from roborock.web_api import PreparedRequest, RoborockApiClient +from tests.fixtures.logging import CapturedRequestLog from tests.mock_data import ( BASE_URL_REQUEST, GET_CODE_RESPONSE, @@ -34,7 +35,33 @@ ) from . import mqtt_packet -from .conftest import CapturedRequestLog + +QUEUE_TIMEOUT = 10 + +pytest_plugins = [ + "tests.fixtures.pahomqtt_fixtures", +] + + +@pytest.fixture(name="mqtt_client") +async def mqtt_client( + mock_paho_mqtt_create_connection: None, mock_paho_mqtt_select: None +) -> AsyncGenerator[RoborockMqttClientV1, None]: + user_data = UserData.from_dict(USER_DATA) + home_data = HomeData.from_dict(HOME_DATA_RAW) + device_info = DeviceData( + device=home_data.devices[0], + model=home_data.products[0].model, + ) + client = RoborockMqttClientV1(user_data, device_info, queue_timeout=QUEUE_TIMEOUT) + try: + yield client + finally: + if not client.is_connected(): + try: + await client.async_release() + except Exception: + pass def test_can_create_prepared_request(): @@ -143,10 +170,10 @@ async def test_get_prop(): @pytest.fixture(name="connected_mqtt_client") async def connected_mqtt_client_fixture( - response_queue: Queue, mqtt_client: RoborockMqttClientV1 + mqtt_response_queue: Queue, mqtt_client: RoborockMqttClientV1 ) -> AsyncGenerator[RoborockMqttClientV1, None]: - response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) - response_queue.put(mqtt_packet.gen_suback(1, 0)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) + mqtt_response_queue.put(mqtt_packet.gen_suback(1, 0)) await mqtt_client.async_connect() yield mqtt_client if mqtt_client.is_connected(): @@ -156,7 +183,7 @@ async def connected_mqtt_client_fixture( pass -async def test_async_connect(received_requests: Queue, connected_mqtt_client: RoborockMqttClientV1) -> None: +async def test_async_connect(mqtt_received_requests: Queue, connected_mqtt_client: RoborockMqttClientV1) -> None: """Test connecting to the MQTT broker.""" assert connected_mqtt_client.is_connected() @@ -169,20 +196,20 @@ async def test_async_connect(received_requests: Queue, connected_mqtt_client: Ro # Broker received a connect and subscribe. Disconnect packet is not # guaranteed to be captured by the time the async_disconnect returns - assert received_requests.qsize() >= 2 # Connect and Subscribe + assert mqtt_received_requests.qsize() >= 2 # Connect and Subscribe async def test_connect_failure_response( - received_requests: Queue, response_queue: Queue, mqtt_client: RoborockMqttClientV1 + mqtt_received_requests: Queue, mqtt_response_queue: Queue, mqtt_client: RoborockMqttClientV1 ) -> None: """Test the broker responding with a connect failure.""" - response_queue.put(mqtt_packet.gen_connack(rc=1)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=1)) with pytest.raises(RoborockException, match="Failed to connect"): await mqtt_client.async_connect() assert not mqtt_client.is_connected() - assert received_requests.qsize() == 1 # Connect attempt + assert mqtt_received_requests.qsize() == 1 # Connect attempt async def test_disconnect_already_disconnected(connected_mqtt_client: RoborockMqttClientV1) -> None: @@ -209,8 +236,8 @@ async def test_disconnect_failure(connected_mqtt_client: RoborockMqttClientV1) - async def test_disconnect_failure_response( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_mqtt_client: RoborockMqttClientV1, caplog: pytest.LogCaptureFixture, ) -> None: @@ -218,7 +245,7 @@ async def test_disconnect_failure_response( # Enqueue a failed message -- however, the client does not process any # further messages and there is no parsing error, and no failed log messages. - response_queue.put(mqtt_packet.gen_disconnect(reason_code=1)) + mqtt_response_queue.put(mqtt_packet.gen_disconnect(reason_code=1)) assert connected_mqtt_client.is_connected() with caplog.at_level(logging.ERROR): await connected_mqtt_client.async_disconnect() @@ -233,19 +260,18 @@ async def test_async_release(connected_mqtt_client: RoborockMqttClientV1) -> Non async def test_subscribe_failure( - received_requests: Queue, response_queue: Queue, mqtt_client: RoborockMqttClientV1 + mqtt_received_requests: Queue, mqtt_response_queue: Queue, mqtt_client: RoborockMqttClientV1 ) -> None: """Test the broker responding with the wrong message type on subscribe.""" - response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) - + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) with ( patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)), pytest.raises(RoborockException, match="Failed to subscribe"), ): await mqtt_client.async_connect() - assert received_requests.qsize() == 1 # Connect attempt + assert mqtt_received_requests.qsize() == 1 # Connect attempt # NOTE: The client is "connected" but not "subscribed" and cannot recover # from this state without disconnecting first. This can likely be improved. @@ -254,7 +280,7 @@ async def test_subscribe_failure( # Attempting to reconnect is a no-op since the client already thinks it is connected await mqtt_client.async_connect() assert mqtt_client.is_connected() - assert received_requests.qsize() == 1 + assert mqtt_received_requests.qsize() == 1 def build_rpc_response(message: dict[str, Any]) -> bytes: @@ -276,8 +302,8 @@ def build_rpc_response(message: dict[str, Any]) -> bytes: async def test_get_room_mapping( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_mqtt_client: RoborockMqttClientV1, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, @@ -291,7 +317,7 @@ async def test_get_room_mapping( "result": [[16, "2362048"], [17, "2362044"]], } ) - response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) + mqtt_response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) with patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id): room_mapping = await connected_mqtt_client.get_room_mapping() diff --git a/tests/test_local_api_v1.py b/tests/test_local_api_v1.py index 281f7f9a..ad264554 100644 --- a/tests/test_local_api_v1.py +++ b/tests/test_local_api_v1.py @@ -1,23 +1,122 @@ """Tests for the Roborock Local Client V1.""" +import asyncio import json -from collections.abc import AsyncGenerator +import logging +from asyncio import Protocol +from collections.abc import AsyncGenerator, Callable, Generator from queue import Queue from typing import Any -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest import syrupy -from roborock.data import RoomMapping +from roborock import HomeData +from roborock.data import DeviceData, RoomMapping from roborock.exceptions import RoborockException from roborock.protocol import MessageParser from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol -from roborock.version_1_apis import RoborockLocalClientV1 +from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 +from tests.fixtures.logging import CapturedRequestLog +from tests.mock_data import HOME_DATA_RAW, TEST_LOCAL_API_HOST -from .conftest import CapturedRequestLog from .mock_data import LOCAL_KEY +_LOGGER = logging.getLogger(__name__) + +pytest_plugins = [ + "tests.fixtures.logging_fixtures", +] + +QUEUE_TIMEOUT = 10 + +LocalRequestHandler = Callable[[bytes], bytes | None] + + +@pytest.fixture(name="local_received_requests") +def received_requests_fixture() -> Queue[bytes]: + """Fixture that provides access to the received requests.""" + return Queue() + + +@pytest.fixture(name="local_response_queue") +def response_queue_fixture() -> Generator[Queue[bytes], None, None]: + """Fixture that provides access to the received requests.""" + response_queue: Queue[bytes] = Queue() + yield response_queue + assert response_queue.empty(), "Not all fake responses were consumed" + + +@pytest.fixture(name="local_request_handler") +def local_request_handler_fixture( + local_received_requests: Queue[bytes], local_response_queue: Queue[bytes] +) -> LocalRequestHandler: + """Fixture records incoming requests and replies with responses from the queue.""" + + def handle_request(client_request: bytes) -> bytes | None: + """Handle an incoming request from the client.""" + local_received_requests.put(client_request) + + # Insert a prepared response into the response buffer + if not local_response_queue.empty(): + return local_response_queue.get() + return None + + return handle_request + + +@pytest.fixture(name="mock_create_local_connection") +def create_local_connection_fixture( + local_request_handler: LocalRequestHandler, log: CapturedRequestLog +) -> Generator[None, None, None]: + """Fixture that overrides the transport creation to wire it up to the mock socket.""" + + async def create_connection(protocol_factory: Callable[[], Protocol], *args) -> tuple[Any, Any]: + protocol = protocol_factory() + + def handle_write(data: bytes) -> None: + _LOGGER.debug("Received: %s", data) + response = local_request_handler(data) + log.add_log_entry("[local >]", data) + if response is not None: + _LOGGER.debug("Replying with %s", response) + log.add_log_entry("[local <]", response) + loop = asyncio.get_running_loop() + loop.call_soon(protocol.data_received, response) + + closed = asyncio.Event() + + mock_transport = Mock() + mock_transport.write = handle_write + mock_transport.close = closed.set + mock_transport.is_reading = lambda: not closed.is_set() + + return (mock_transport, "proto") + + with patch("roborock.version_1_apis.roborock_local_client_v1.get_running_loop") as mock_loop: + mock_loop.return_value.create_connection.side_effect = create_connection + yield + + +@pytest.fixture(name="local_client") +async def local_client_fixture(mock_create_local_connection: None) -> AsyncGenerator[RoborockLocalClientV1, None]: + home_data = HomeData.from_dict(HOME_DATA_RAW) + device_info = DeviceData( + device=home_data.devices[0], + model=home_data.products[0].model, + host=TEST_LOCAL_API_HOST, + ) + client = RoborockLocalClientV1(device_info, queue_timeout=QUEUE_TIMEOUT) + try: + yield client + finally: + if not client.is_connected(): + try: + await client.async_release() + except Exception: + pass + def build_rpc_response(seq: int, message: dict[str, Any]) -> bytes: """Build an encoded RPC response message.""" @@ -45,18 +144,18 @@ def build_raw_response(protocol: RoborockMessageProtocol, seq: int, payload: byt async def test_async_connect( local_client: RoborockLocalClientV1, - received_requests: Queue, - response_queue: Queue, + local_received_requests: Queue, + local_response_queue: Queue, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, ) -> None: """Test that we can connect to the Roborock device.""" - response_queue.put(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, b"ignored")) - response_queue.put(build_raw_response(RoborockMessageProtocol.PING_RESPONSE, 2, b"ignored")) + local_response_queue.put(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, b"ignored")) + local_response_queue.put(build_raw_response(RoborockMessageProtocol.PING_RESPONSE, 2, b"ignored")) await local_client.async_connect() assert local_client.is_connected() - assert received_requests.qsize() == 2 + assert local_received_requests.qsize() == 2 await local_client.async_disconnect() assert not local_client.is_connected() @@ -66,18 +165,18 @@ async def test_async_connect( @pytest.fixture(name="connected_local_client") async def connected_local_client_fixture( - response_queue: Queue, + local_response_queue: Queue, local_client: RoborockLocalClientV1, ) -> AsyncGenerator[RoborockLocalClientV1, None]: - response_queue.put(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, b"ignored")) - response_queue.put(build_raw_response(RoborockMessageProtocol.PING_RESPONSE, 2, b"ignored")) + local_response_queue.put(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, b"ignored")) + local_response_queue.put(build_raw_response(RoborockMessageProtocol.PING_RESPONSE, 2, b"ignored")) await local_client.async_connect() yield local_client async def test_get_room_mapping( - received_requests: Queue, - response_queue: Queue, + local_received_requests: Queue, + local_response_queue: Queue, connected_local_client: RoborockLocalClientV1, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, @@ -93,7 +192,7 @@ async def test_get_room_mapping( "result": [[16, "2362048"], [17, "2362044"]], }, ) - response_queue.put(message) + local_response_queue.put(message) with patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id): room_mapping = await connected_local_client.get_room_mapping() @@ -107,8 +206,8 @@ async def test_get_room_mapping( async def test_retry_request( - received_requests: Queue, - response_queue: Queue, + local_received_requests: Queue, + local_response_queue: Queue, connected_local_client: RoborockLocalClientV1, ) -> None: """Test sending an arbitrary MQTT message and parsing the response.""" @@ -122,7 +221,7 @@ async def test_retry_request( "result": "retry", }, ) - response_queue.put(retry_message) + local_response_queue.put(retry_message) with ( patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id), diff --git a/tests/test_web_api.py b/tests/test_web_api.py index 796fd008..b280277c 100644 --- a/tests/test_web_api.py +++ b/tests/test_web_api.py @@ -1,12 +1,24 @@ import re +from typing import Any import aiohttp +import pytest from aioresponses.compat import normalize_url from roborock import HomeData, HomeDataScene, UserData from roborock.web_api import IotLoginInfo, RoborockApiClient from tests.mock_data import HOME_DATA_RAW, USER_DATA +pytest_plugins = [ + "tests.fixtures.web_api_fixtures", +] + + +@pytest.fixture(autouse=True) +def auto_mock_rest_fixture(mock_rest: Any) -> None: + """Auto use the mock rest fixture for all tests in this module.""" + pass + async def test_pass_login_flow() -> None: """Test that we can login with a password and we get back the correct userdata object.""" From 6f6665eed2392c2d4235777346c195aae2e0de0e Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 22 Dec 2025 08:27:58 -0800 Subject: [PATCH 2/4] chore: Address co-pilot review feedback --- tests/fixtures/local_async_fixtures.py | 9 +++++++-- tests/fixtures/pahomqtt_fixtures.py | 2 +- tests/test_local_api_v1.py | 9 +++++---- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/fixtures/local_async_fixtures.py b/tests/fixtures/local_async_fixtures.py index 7be1f7e5..394dbca5 100644 --- a/tests/fixtures/local_async_fixtures.py +++ b/tests/fixtures/local_async_fixtures.py @@ -1,10 +1,14 @@ import asyncio +import logging +import warnings from collections.abc import Awaitable, Callable, Generator from typing import Any from unittest.mock import Mock, patch import pytest +_LOGGER = logging.getLogger(__name__) + AsyncLocalRequestHandler = Callable[[bytes], Awaitable[bytes | None]] @@ -16,10 +20,11 @@ def received_requests_fixture() -> asyncio.Queue[bytes]: @pytest.fixture(name="local_response_queue") def response_queue_fixture() -> Generator[asyncio.Queue[bytes], None, None]: - """Fixture that provides access to the received requests.""" + """Fixture that provides a queue of responses to be sent to the client.""" response_queue: asyncio.Queue[bytes] = asyncio.Queue() yield response_queue - # assert response_queue.empty(), "Not all fake responses were consumed" + if not response_queue.empty(): + warnings.warn("Some enqueued local device responses were not consumed during the test") @pytest.fixture(name="local_async_request_handler") diff --git a/tests/fixtures/pahomqtt_fixtures.py b/tests/fixtures/pahomqtt_fixtures.py index b0a50038..97655f3d 100644 --- a/tests/fixtures/pahomqtt_fixtures.py +++ b/tests/fixtures/pahomqtt_fixtures.py @@ -73,7 +73,7 @@ def received_requests_fixture() -> Queue[bytes]: @pytest.fixture(name="mqtt_response_queue") def response_queue_fixture() -> Generator[Queue[bytes], None, None]: - """Fixture that provides access to the received requests.""" + """Fixture that provides a queue for enqueueing responses to be sent to the client under test.""" response_queue: Queue[bytes] = Queue() yield response_queue assert response_queue.empty(), "Not all fake responses were consumed" diff --git a/tests/test_local_api_v1.py b/tests/test_local_api_v1.py index ad264554..11e91492 100644 --- a/tests/test_local_api_v1.py +++ b/tests/test_local_api_v1.py @@ -3,11 +3,11 @@ import asyncio import json import logging -from asyncio import Protocol from collections.abc import AsyncGenerator, Callable, Generator from queue import Queue from typing import Any from unittest.mock import Mock, patch +import warnings import pytest import syrupy @@ -42,10 +42,11 @@ def received_requests_fixture() -> Queue[bytes]: @pytest.fixture(name="local_response_queue") def response_queue_fixture() -> Generator[Queue[bytes], None, None]: - """Fixture that provides access to the received requests.""" + """Fixture that provides a queue for enqueueing responses to be sent back to the client under test.""" response_queue: Queue[bytes] = Queue() yield response_queue - assert response_queue.empty(), "Not all fake responses were consumed" + if not response_queue.empty(): + warnings.warn("Not all fake responses were consumed") @pytest.fixture(name="local_request_handler") @@ -72,7 +73,7 @@ def create_local_connection_fixture( ) -> Generator[None, None, None]: """Fixture that overrides the transport creation to wire it up to the mock socket.""" - async def create_connection(protocol_factory: Callable[[], Protocol], *args) -> tuple[Any, Any]: + async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args) -> tuple[Any, Any]: protocol = protocol_factory() def handle_write(data: bytes) -> None: From db84b17a3ab79721cba62e5b96d48bca970e82c9 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 22 Dec 2025 08:29:09 -0800 Subject: [PATCH 3/4] chore: fix lint errors --- tests/test_local_api_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_local_api_v1.py b/tests/test_local_api_v1.py index 11e91492..c5209423 100644 --- a/tests/test_local_api_v1.py +++ b/tests/test_local_api_v1.py @@ -3,11 +3,11 @@ import asyncio import json import logging +import warnings from collections.abc import AsyncGenerator, Callable, Generator from queue import Queue from typing import Any from unittest.mock import Mock, patch -import warnings import pytest import syrupy From ae3e9f7c7f379bb61932438adeeba6c40971714c Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Mon, 22 Dec 2025 08:31:03 -0800 Subject: [PATCH 4/4] chore: Remove duplicate captured request log --- tests/fixtures/logging_fixtures.py | 61 +----------------------------- 1 file changed, 2 insertions(+), 59 deletions(-) diff --git a/tests/fixtures/logging_fixtures.py b/tests/fixtures/logging_fixtures.py index 0a6ff112..feb9e82f 100644 --- a/tests/fixtures/logging_fixtures.py +++ b/tests/fixtures/logging_fixtures.py @@ -3,69 +3,12 @@ import pytest +from .logging import CapturedRequestLog + # Fixed timestamp for deterministic tests for asserting on message contents FAKE_TIMESTAMP = 1755750946.721395 -class CapturedRequestLog: - """Log of requests and responses for snapshot assertions. - - The log captures the raw bytes of each request and response along with - a label indicating the direction of the message. - """ - - def __init__(self) -> None: - """Initialize the request log.""" - self.entries: list[tuple[str, bytes]] = [] - - def add_log_entry(self, label: str, data: bytes) -> None: - """Add a request entry.""" - self.entries.append((label, data)) - - def __repr__(self): - """Return a string representation of the log entries. - - This assumes that the client will behave in a request-response manner, - so each request is followed by a response. If a test uses non-deterministic - message order, this may not be accurate and the test would need to decode - the raw messages and remove any ordering assumptions. - """ - lines = [] - for label, data in self.entries: - lines.append(label) - lines.extend(self._hexdump(data)) - return "\n".join(lines) - - def _hexdump(self, data: bytes, bytes_per_line: int = 16) -> list[str]: - """Print a hexdump of the given bytes object in a tcpdump/hexdump -C style. - - This makes the packets easier to read and compare in test snapshots. - - Args: - data: The bytes object to print. - bytes_per_line: The number of bytes to display per line (default is 16). - """ - - # Use '.' for non-printable characters (ASCII < 32 or > 126) - def to_printable_ascii(byte_val): - return chr(byte_val) if 32 <= byte_val <= 126 else "." - - offset = 0 - lines = [] - while offset < len(data): - chunk = data[offset : offset + bytes_per_line] - # Format the hex values, space-padded to ensure alignment - hex_values = " ".join(f"{byte:02x}" for byte in chunk) - # Pad hex string to a fixed width so ASCII column lines up - # 3 chars per byte ('xx ') for a full line of 16 bytes - padded_hex = f"{hex_values:<{bytes_per_line * 3}}" - # Format the ASCII values - ascii_values = "".join(to_printable_ascii(byte) for byte in chunk) - lines.append(f"{offset:08x} {padded_hex} |{ascii_values}|") - offset += bytes_per_line - return lines - - @pytest.fixture def deterministic_message_fixtures() -> Generator[None, None, None]: """Fixture to use predictable get_next_int and timestamp values for each test.