diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index bbdc2fae..00000000 --- a/tests/conftest.py +++ /dev/null @@ -1,532 +0,0 @@ -import asyncio -import io -import logging -import re -from asyncio import Protocol -from collections.abc import AsyncGenerator, Callable, Generator -from queue import Queue -from typing import Any -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -import pytest -from aioresponses import aioresponses - -from roborock import HomeData, UserData -from roborock.data import DeviceData -from roborock.mqtt.health_manager import HealthManager -from roborock.protocols.v1_protocol import LocalProtocolVersion -from roborock.roborock_message import RoborockMessage -from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 -from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1 -from tests.mock_data import HOME_DATA_RAW, HOME_DATA_SCENES_RAW, TEST_LOCAL_API_HOST, USER_DATA - -# Fixtures for the newer APIs in subdirectories -pytest_plugins = [ - "tests.mqtt_fixtures", -] - -_LOGGER = logging.getLogger(__name__) - - -# Used by fixtures to handle incoming requests and prepare responses -RequestHandler = Callable[[bytes], bytes | None] -QUEUE_TIMEOUT = 10 - -# Fixed timestamp for deterministic tests for asserting on message contents -FAKE_TIMESTAMP = 1755750946.721395 - - -class CapturedRequestLog: - """Log of requests and responses for snapshot assertions. - - The log captures the raw bytes of each request and response along with - a label indicating the direction of the message. - """ - - def __init__(self) -> None: - """Initialize the request log.""" - self.entries: list[tuple[str, bytes]] = [] - - def add_log_entry(self, label: str, data: bytes) -> None: - """Add a request entry.""" - self.entries.append((label, data)) - - def __repr__(self): - """Return a string representation of the log entries. - - This assumes that the client will behave in a request-response manner, - so each request is followed by a response. If a test uses non-deterministic - message order, this may not be accurate and the test would need to decode - the raw messages and remove any ordering assumptions. - """ - lines = [] - for label, data in self.entries: - lines.append(label) - lines.extend(self._hexdump(data)) - return "\n".join(lines) - - def _hexdump(self, data: bytes, bytes_per_line: int = 16) -> list[str]: - """Print a hexdump of the given bytes object in a tcpdump/hexdump -C style. - - This makes the packets easier to read and compare in test snapshots. - - Args: - data: The bytes object to print. - bytes_per_line: The number of bytes to display per line (default is 16). - """ - - # Use '.' for non-printable characters (ASCII < 32 or > 126) - def to_printable_ascii(byte_val): - return chr(byte_val) if 32 <= byte_val <= 126 else "." - - offset = 0 - lines = [] - while offset < len(data): - chunk = data[offset : offset + bytes_per_line] - # Format the hex values, space-padded to ensure alignment - hex_values = " ".join(f"{byte:02x}" for byte in chunk) - # Pad hex string to a fixed width so ASCII column lines up - # 3 chars per byte ('xx ') for a full line of 16 bytes - padded_hex = f"{hex_values:<{bytes_per_line * 3}}" - # Format the ASCII values - ascii_values = "".join(to_printable_ascii(byte) for byte in chunk) - lines.append(f"{offset:08x} {padded_hex} |{ascii_values}|") - offset += bytes_per_line - return lines - - -@pytest.fixture -def deterministic_message_fixtures() -> Generator[None, None, None]: - """Fixture to use predictable get_next_int and timestamp values for each test. - - This test mocks out the functions used to generate requests that have some - entropy such as the nonces, timestamps, and request IDs. This makes the - generated messages deterministic so we can snapshot them in a test. - """ - - # Pick an arbitrary sequence number used for outgoing requests - next_int = 9090 - - def get_next_int(min_value: int, max_value: int) -> int: - nonlocal next_int - result = next_int - next_int += 1 - if next_int > max_value: - next_int = min_value - return result - - # Pick an arbitrary timestamp used for the message encryption - timestamp = FAKE_TIMESTAMP - - def get_timestamp() -> int: - """Get a monotonically increasing timestamp for testing.""" - nonlocal timestamp - timestamp += 1 - return int(timestamp) - - # Use predictable seeds for token_bytes - token_chr = "A" - - def get_token_bytes(n: int) -> bytes: - nonlocal token_chr - result = token_chr.encode() * n - # Cycle to the next character - token_chr = chr(ord(token_chr) + 1) - if token_chr > "Z": - token_chr = "A" - return result - - with ( - patch("roborock.api.get_next_int", side_effect=get_next_int), - patch("roborock.devices.local_channel.get_next_int", side_effect=get_next_int), - patch("roborock.protocols.v1_protocol.get_next_int", side_effect=get_next_int), - patch("roborock.protocols.v1_protocol.get_timestamp", side_effect=get_timestamp), - patch("roborock.protocols.v1_protocol.secrets.token_bytes", side_effect=get_token_bytes), - patch("roborock.version_1_apis.roborock_local_client_v1.get_next_int", side_effect=get_next_int), - patch("roborock.roborock_message.get_next_int", side_effect=get_next_int), - patch("roborock.roborock_message.get_timestamp", side_effect=get_timestamp), - ): - yield - - -@pytest.fixture(name="log") -def log_fixture(deterministic_message_fixtures: None) -> CapturedRequestLog: - """Fixture that creates a captured request log.""" - return CapturedRequestLog() - - -class FakeSocketHandler: - """Fake socket used by the test to simulate a connection to the broker. - - The socket handler is used to intercept the socket send and recv calls and - populate the response buffer with data to be sent back to the client. The - handle request callback handles the incoming requests and prepares the responses. - """ - - def __init__(self, handle_request: RequestHandler, response_queue: Queue[bytes], log: CapturedRequestLog) -> None: - self.response_buf = io.BytesIO() - self.handle_request = handle_request - self.response_queue = response_queue - self.log = log - - def pending(self) -> int: - """Return the number of bytes in the response buffer.""" - return len(self.response_buf.getvalue()) - - def handle_socket_recv(self, read_size: int) -> bytes: - """Intercept a client recv() and populate the buffer.""" - if self.pending() == 0: - raise BlockingIOError("No response queued") - - self.response_buf.seek(0) - data = self.response_buf.read(read_size) - _LOGGER.debug("Response: 0x%s", data.hex()) - # Consume the rest of the data in the buffer - remaining_data = self.response_buf.read() - self.response_buf = io.BytesIO(remaining_data) - return data - - def handle_socket_send(self, client_request: bytes) -> int: - """Receive an incoming request from the client.""" - _LOGGER.debug("Request: 0x%s", client_request.hex()) - self.log.add_log_entry("[mqtt >]", client_request) - if (response := self.handle_request(client_request)) is not None: - # Enqueue a response to be sent back to the client in the buffer. - # The buffer will be emptied when the client calls recv() on the socket - _LOGGER.debug("Queued: 0x%s", response.hex()) - self.log.add_log_entry("[mqtt <]", response) - self.response_buf.write(response) - return len(client_request) - - def push_response(self) -> None: - """Push a response to the client.""" - if not self.response_queue.empty(): - response = self.response_queue.get() - # Enqueue a response to be sent back to the client in the buffer. - # The buffer will be emptied when the client calls recv() on the socket - _LOGGER.debug("Queued: 0x%s", response.hex()) - self.response_buf.write(response) - - -@pytest.fixture(name="received_requests") -def received_requests_fixture() -> Queue[bytes]: - """Fixture that provides access to the received requests.""" - return Queue() - - -@pytest.fixture(name="response_queue") -def response_queue_fixture() -> Generator[Queue[bytes], None, None]: - """Fixture that provides access to the received requests.""" - response_queue: Queue[bytes] = Queue() - yield response_queue - assert response_queue.empty(), "Not all fake responses were consumed" - - -@pytest.fixture(name="request_handler") -def request_handler_fixture(received_requests: Queue[bytes], response_queue: Queue[bytes]) -> RequestHandler: - """Fixture records incoming requests and replies with responses from the queue.""" - - def handle_request(client_request: bytes) -> bytes | None: - """Handle an incoming request from the client.""" - received_requests.put(client_request) - - # Insert a prepared response into the response buffer - if not response_queue.empty(): - return response_queue.get() - return None - - return handle_request - - -@pytest.fixture(name="fake_socket_handler") -def fake_socket_handler_fixture( - request_handler: RequestHandler, response_queue: Queue[bytes], log: CapturedRequestLog -) -> FakeSocketHandler: - """Fixture that creates a fake MQTT broker.""" - return FakeSocketHandler(request_handler, response_queue, log) - - -@pytest.fixture(name="mock_sock") -def mock_sock_fixture(fake_socket_handler: FakeSocketHandler) -> Mock: - """Fixture that creates a mock socket connection and wires it to the handler.""" - mock_sock = Mock() - mock_sock.recv = fake_socket_handler.handle_socket_recv - mock_sock.send = fake_socket_handler.handle_socket_send - mock_sock.pending = fake_socket_handler.pending - return mock_sock - - -@pytest.fixture(name="mock_create_connection") -def create_connection_fixture(mock_sock: Mock) -> Generator[None, None, None]: - """Fixture that overrides the MQTT socket creation to wire it up to the mock socket.""" - with patch("paho.mqtt.client.socket.create_connection", return_value=mock_sock): - yield - - -@pytest.fixture(name="mock_select") -def select_fixture(mock_sock: Mock, fake_socket_handler: FakeSocketHandler) -> Generator[None, None, None]: - """Fixture that overrides the MQTT client select calls to make select work on the mock socket. - - This patch select to activate our mock socket when ready with data. Internal mqtt sockets are - always ready since they are used internally to wake the select loop. Ours is ready if there - is data in the buffer. - """ - - def is_ready(sock: Any) -> bool: - return sock is not mock_sock or (fake_socket_handler.pending() > 0) - - def handle_select(rlist: list, wlist: list, *args: Any) -> list: - return [list(filter(is_ready, rlist)), list(filter(is_ready, wlist))] - - with patch("paho.mqtt.client.select.select", side_effect=handle_select): - yield - - -@pytest.fixture(name="mqtt_client") -async def mqtt_client(mock_create_connection: None, mock_select: None) -> AsyncGenerator[RoborockMqttClientV1, None]: - user_data = UserData.from_dict(USER_DATA) - home_data = HomeData.from_dict(HOME_DATA_RAW) - device_info = DeviceData( - device=home_data.devices[0], - model=home_data.products[0].model, - ) - client = RoborockMqttClientV1(user_data, device_info, queue_timeout=QUEUE_TIMEOUT) - try: - yield client - finally: - if not client.is_connected(): - try: - await client.async_release() - except Exception: - pass - - -@pytest.fixture(name="mock_rest", autouse=True) -def mock_rest() -> aioresponses: - """Mock all rest endpoints so they won't hit real endpoints""" - with aioresponses() as mocked: - # Match the base URL and allow any query params - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v1/getUrlByEmail.*"), - status=200, - payload={ - "code": 200, - "data": {"country": "US", "countrycode": "1", "url": "https://usiot.roborock.com"}, - "msg": "success", - }, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v1/login.*"), - status=200, - payload={"code": 200, "data": USER_DATA, "msg": "success"}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v1/loginWithCode.*"), - status=200, - payload={"code": 200, "data": USER_DATA, "msg": "success"}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v1/sendEmailCode.*"), - status=200, - payload={"code": 200, "data": None, "msg": "success"}, - ) - mocked.get( - re.compile(r"https://.*iot\.roborock\.com/api/v1/getHomeDetail.*"), - status=200, - payload={ - "code": 200, - "data": {"deviceListOrder": None, "id": 123456, "name": "My Home", "rrHomeId": 123456, "tuyaHomeId": 0}, - "msg": "success", - }, - ) - mocked.get( - re.compile(r"https://api-.*\.roborock\.com/v2/user/homes*"), - status=200, - payload={"api": None, "code": 200, "result": HOME_DATA_RAW, "status": "ok", "success": True}, - ) - mocked.post( - re.compile(r"https://api-.*\.roborock\.com/nc/prepare"), - status=200, - payload={ - "api": None, - "result": {"r": "US", "s": "ffffff", "t": "eOf6d2BBBB"}, - "status": "ok", - "success": True, - }, - ) - - mocked.get( - re.compile(r"https://api-.*\.roborock\.com/user/devices/newadd/*"), - status=200, - payload={ - "api": "获取新增设备信息", - "result": { - "activeTime": 1737724598, - "attribute": None, - "cid": None, - "createTime": 0, - "deviceStatus": None, - "duid": "rand_duid", - "extra": "{}", - "f": False, - "featureSet": "0", - "fv": "02.16.12", - "iconUrl": "", - "lat": None, - "localKey": "random_lk", - "lon": None, - "name": "S7", - "newFeatureSet": "0000000000002000", - "online": True, - "productId": "rand_prod_id", - "pv": "1.0", - "roomId": None, - "runtimeEnv": None, - "setting": None, - "share": False, - "shareTime": None, - "silentOtaSwitch": False, - "sn": "Rand_sn", - "timeZoneId": "America/New_York", - "tuyaMigrated": False, - "tuyaUuid": None, - }, - "status": "ok", - "success": True, - }, - ) - mocked.get( - re.compile(r"https://api-.*\.roborock\.com/user/scene/device/.*"), - status=200, - payload={"api": None, "code": 200, "result": HOME_DATA_SCENES_RAW, "status": "ok", "success": True}, - ) - mocked.post( - re.compile(r"https://api-.*\.roborock\.com/user/scene/.*/execute"), - status=200, - payload={"api": None, "code": 200, "result": None, "status": "ok", "success": True}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v4/email/code/send.*"), - status=200, - payload={"code": 200, "data": None, "msg": "success"}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v3/key/sign.*"), - status=200, - payload={"code": 200, "data": {"k": "mock_k"}, "msg": "success"}, - ) - mocked.post( - re.compile(r"https://.*iot\.roborock\.com/api/v4/auth/email/login/code.*"), - status=200, - payload={"code": 200, "data": USER_DATA, "msg": "success"}, - ) - yield mocked - - -@pytest.fixture(autouse=True) -def skip_rate_limit(): - """Don't rate limit tests as they aren't actually hitting the api.""" - with ( - patch("roborock.web_api.RoborockApiClient._login_limiter.try_acquire"), - patch("roborock.web_api.RoborockApiClient._home_data_limiter.try_acquire"), - ): - yield - - -@pytest.fixture(name="mock_create_local_connection") -def create_local_connection_fixture( - request_handler: RequestHandler, log: CapturedRequestLog -) -> Generator[None, None, None]: - """Fixture that overrides the transport creation to wire it up to the mock socket.""" - - async def create_connection(protocol_factory: Callable[[], Protocol], *args) -> tuple[Any, Any]: - protocol = protocol_factory() - - def handle_write(data: bytes) -> None: - _LOGGER.debug("Received: %s", data) - response = request_handler(data) - log.add_log_entry("[local >]", data) - if response is not None: - _LOGGER.debug("Replying with %s", response) - log.add_log_entry("[local <]", response) - loop = asyncio.get_running_loop() - loop.call_soon(protocol.data_received, response) - - closed = asyncio.Event() - - mock_transport = Mock() - mock_transport.write = handle_write - mock_transport.close = closed.set - mock_transport.is_reading = lambda: not closed.is_set() - - return (mock_transport, "proto") - - with patch("roborock.version_1_apis.roborock_local_client_v1.get_running_loop") as mock_loop: - mock_loop.return_value.create_connection.side_effect = create_connection - yield - - -@pytest.fixture(name="local_client") -async def local_client_fixture(mock_create_local_connection: None) -> AsyncGenerator[RoborockLocalClientV1, None]: - home_data = HomeData.from_dict(HOME_DATA_RAW) - device_info = DeviceData( - device=home_data.devices[0], - model=home_data.products[0].model, - host=TEST_LOCAL_API_HOST, - ) - client = RoborockLocalClientV1(device_info, queue_timeout=QUEUE_TIMEOUT) - try: - yield client - finally: - if not client.is_connected(): - try: - await client.async_release() - except Exception: - pass - - -class FakeChannel: - """A fake channel that handles publish and subscribe calls.""" - - def __init__(self): - """Initialize the fake channel.""" - self.subscribers: list[Callable[[RoborockMessage], None]] = [] - self.published_messages: list[RoborockMessage] = [] - self.response_queue: list[RoborockMessage] = [] - self._is_connected = False - self.publish_side_effect: Exception | None = None - self.publish = AsyncMock(side_effect=self._publish) - self.subscribe = AsyncMock(side_effect=self._subscribe) - self.connect = AsyncMock(side_effect=self._connect) - self.close = MagicMock(side_effect=self._close) - self.protocol_version = LocalProtocolVersion.V1 - self.restart = AsyncMock() - self.health_manager = HealthManager(self.restart) - - async def _connect(self) -> None: - self._is_connected = True - - def _close(self) -> None: - self._is_connected = False - - @property - def is_connected(self) -> bool: - """Return true if connected.""" - return self._is_connected - - async def _publish(self, message: RoborockMessage) -> None: - """Simulate publishing a message and triggering a response.""" - self.published_messages.append(message) - if self.publish_side_effect: - raise self.publish_side_effect - # When a message is published, simulate a response - if self.response_queue: - response = self.response_queue.pop(0) - # Give a chance for the subscriber to be registered - for subscriber in list(self.subscribers): - subscriber(response) - - async def _subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: - """Simulate subscribing to messages.""" - self.subscribers.append(callback) - return lambda: self.subscribers.remove(callback) diff --git a/tests/devices/test_a01_channel.py b/tests/devices/test_a01_channel.py index e5bf112a..befa8951 100644 --- a/tests/devices/test_a01_channel.py +++ b/tests/devices/test_a01_channel.py @@ -11,8 +11,7 @@ RoborockMessage, RoborockMessageProtocol, ) - -from ..conftest import FakeChannel +from tests.fixtures.channel_fixtures import FakeChannel @pytest.fixture diff --git a/tests/devices/test_v1_channel.py b/tests/devices/test_v1_channel.py index 1152675f..c9c40cad 100644 --- a/tests/devices/test_v1_channel.py +++ b/tests/devices/test_v1_channel.py @@ -25,9 +25,8 @@ from roborock.protocols.v1_protocol import MapResponse, SecurityData, V1RpcChannel from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from roborock.roborock_typing import RoborockCommand - -from .. import mock_data -from ..conftest import FakeChannel +from tests import mock_data +from tests.fixtures.channel_fixtures import FakeChannel USER_DATA = UserData.from_dict(mock_data.USER_DATA) TEST_DEVICE_UID = "abc123" diff --git a/tests/devices/traits/a01/test_init.py b/tests/devices/traits/a01/test_init.py index 9b9c83b3..8e1cb7dd 100644 --- a/tests/devices/traits/a01/test_init.py +++ b/tests/devices/traits/a01/test_init.py @@ -8,7 +8,7 @@ from roborock.devices.traits.a01 import DyadApi, ZeoApi from roborock.roborock_message import RoborockDyadDataProtocol, RoborockMessageProtocol, RoborockZeoProtocol -from tests.conftest import FakeChannel +from tests.fixtures.channel_fixtures import FakeChannel from tests.protocols.common import build_a01_message diff --git a/tests/devices/traits/b01/test_init.py b/tests/devices/traits/b01/test_init.py index 7fa89f48..0eb0f5f2 100644 --- a/tests/devices/traits/b01/test_init.py +++ b/tests/devices/traits/b01/test_init.py @@ -12,7 +12,7 @@ from roborock.exceptions import RoborockException from roborock.protocols.b01_protocol import B01_VERSION from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol -from tests.conftest import FakeChannel +from tests.fixtures.channel_fixtures import FakeChannel def build_b01_message(message: dict[Any, Any], msg_id: str = "123456789", seq: int = 2020) -> RoborockMessage: diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py index c50cd2b3..38fd7fb6 100644 --- a/tests/e2e/__init__.py +++ b/tests/e2e/__init__.py @@ -1 +1,8 @@ """End-to-end tests package.""" + +pytest_plugins = [ + "tests.fixtures.logging_fixtures", + "tests.fixtures.local_async_fixtures", + "tests.fixtures.pahomqtt_fixtures", + "tests.fixtures.aiomqtt_fixtures", +] diff --git a/tests/e2e/test_local_session.py b/tests/e2e/test_local_session.py index 392279ff..347923da 100644 --- a/tests/e2e/test_local_session.py +++ b/tests/e2e/test_local_session.py @@ -1,17 +1,14 @@ """End-to-end tests for LocalChannel using fake sockets.""" import asyncio -from collections.abc import AsyncGenerator, Callable, Generator -from queue import Queue -from typing import Any -from unittest.mock import Mock, patch +from collections.abc import AsyncGenerator +from unittest.mock import patch import pytest from roborock.devices.local_channel import LocalChannel from roborock.protocol import create_local_decoder, create_local_encoder from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol -from tests.conftest import RequestHandler from tests.mock_data import LOCAL_KEY TEST_HOST = "192.168.1.100" @@ -21,35 +18,8 @@ TEST_RANDOM = 13579 -@pytest.fixture(name="mock_create_local_connection") -def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]: - """Fixture that overrides the transport creation to wire it up to the mock socket.""" - - async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]: - protocol = protocol_factory() - - def handle_write(data: bytes) -> None: - response = request_handler(data) - if response is not None: - # Call data_received directly to avoid loop scheduling issues in test - protocol.data_received(response) - - closed = asyncio.Event() - - mock_transport = Mock() - mock_transport.write = handle_write - mock_transport.close = closed.set - mock_transport.is_reading = lambda: not closed.is_set() - - return (mock_transport, protocol) - - with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop: - mock_loop.return_value.create_connection.side_effect = create_connection - yield - - @pytest.fixture(name="local_channel") -async def local_channel_fixture(mock_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]: +async def local_channel_fixture(mock_async_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]: with patch( "roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE, device_uid=TEST_DEVICE_UID ): @@ -80,19 +50,23 @@ def build_response( async def test_connect( - local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes] + local_channel: LocalChannel, + local_response_queue: asyncio.Queue[bytes], + local_received_requests: asyncio.Queue[bytes], ) -> None: """Test connecting to the device.""" # Queue HELLO response with payload to ensure it can be parsed - response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)) + local_response_queue.put_nowait( + build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM) + ) await local_channel.connect() assert local_channel.is_connected - assert received_requests.qsize() == 1 + assert local_received_requests.qsize() == 1 # Verify HELLO request - request_bytes = received_requests.get() + request_bytes = await local_received_requests.get() # Note: We cannot use create_local_decoder here because HELLO_REQUEST has payload=None # which causes MessageParser to fail parsing. For now we verify the raw bytes. @@ -104,17 +78,21 @@ async def test_connect( async def test_send_command( - local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes] + local_channel: LocalChannel, + local_response_queue: asyncio.Queue[bytes], + local_received_requests: asyncio.Queue[bytes], ) -> None: """Test sending a command.""" # Queue HELLO response - response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)) + local_response_queue.put_nowait( + build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM) + ) await local_channel.connect() # Clear requests from handshake - while not received_requests.empty(): - received_requests.get() + while not local_received_requests.empty(): + await local_received_requests.get() # Send command cmd_seq = 123 @@ -127,8 +105,8 @@ async def test_send_command( await local_channel.publish(msg) # Verify request - assert received_requests.qsize() == 1 - request_bytes = received_requests.get() + request_bytes = await local_received_requests.get() + assert local_received_requests.empty() # Decode request decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE) diff --git a/tests/e2e/test_mqtt_session.py b/tests/e2e/test_mqtt_session.py index e8a63843..294bf5e8 100644 --- a/tests/e2e/test_mqtt_session.py +++ b/tests/e2e/test_mqtt_session.py @@ -18,12 +18,12 @@ from roborock.protocol import MessageParser from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from tests import mqtt_packet +from tests.fixtures.mqtt import FAKE_PARAMS, Subscriber from tests.mock_data import LOCAL_KEY -from tests.mqtt_fixtures import FAKE_PARAMS, Subscriber @pytest.fixture(autouse=True) -def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None: +def auto_mock_mqtt_client(mock_aiomqtt_client: None) -> None: """Automatically use the mock mqtt client fixture.""" @@ -33,7 +33,7 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None: @pytest.fixture(autouse=True) -def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None: +def mqtt_server_fixture(mock_paho_mqtt_create_connection: None, mock_paho_mqtt_select: None) -> None: """Fixture to mock the MQTT connection. This is here to pull in the mock socket pixtures into all tests used here. @@ -42,10 +42,10 @@ def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None @pytest.fixture(name="session") async def session_fixture( - push_response: Callable[[bytes], None], + push_mqtt_response: Callable[[bytes], None], ) -> AsyncGenerator[MqttSession, None]: """Fixture to create a new connected MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) assert session.connected try: @@ -54,12 +54,12 @@ async def session_fixture( await session.close() -async def test_session_e2e_receive_message(push_response: Callable[[bytes], None], session: MqttSession) -> None: +async def test_session_e2e_receive_message(push_mqtt_response: Callable[[bytes], None], session: MqttSession) -> None: """Test receiving a real Roborock message through the session.""" assert session.connected # Subscribe to the topic. We'll next construct and push a message. - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) subscriber = Subscriber() await session.subscribe("topic-1", subscriber.append) @@ -71,12 +71,13 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False) # Simulate receiving the message from the broker - push_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload)) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload)) # Verify it was dispatched to the subscriber await subscriber.wait() assert len(subscriber.messages) == 1 received_payload = subscriber.messages[0] + assert isinstance(received_payload, bytes) assert received_payload == payload # Verify the message payload contents @@ -90,8 +91,8 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None async def test_session_e2e_publish_message( - push_response: Callable[[bytes], None], - received_requests: Queue, + push_mqtt_response: Callable[[bytes], None], + mqtt_received_requests: Queue, session: MqttSession, ) -> None: """Test publishing a real Roborock message.""" @@ -109,8 +110,8 @@ async def test_session_e2e_publish_message( # Verify what was sent to the broker # We expect the payload to be present in the sent bytes found = False - while not received_requests.empty(): - request = received_requests.get() + while not mqtt_received_requests.empty(): + request = mqtt_received_requests.get() if payload in request: found = True break diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mqtt_fixtures.py b/tests/fixtures/aiomqtt_fixtures.py similarity index 67% rename from tests/mqtt_fixtures.py rename to tests/fixtures/aiomqtt_fixtures.py index cb5ff17f..d9e10e74 100644 --- a/tests/mqtt_fixtures.py +++ b/tests/fixtures/aiomqtt_fixtures.py @@ -10,40 +10,11 @@ import paho.mqtt.client as mqtt import pytest -from roborock.mqtt.session import MqttParams -from tests.conftest import FakeSocketHandler +from .mqtt import FakeMqttSocketHandler -FAKE_PARAMS = MqttParams( - host="localhost", - port=1883, - tls=False, - username="username", - password="password", - timeout=10.0, -) - -class Subscriber: - """Mock subscriber class. - - We use this to hold on to received messages for verification. - """ - - def __init__(self) -> None: - self.messages: list[bytes] = [] - self._event = asyncio.Event() - - def append(self, message: bytes) -> None: - self.messages.append(message) - self._event.set() - - async def wait(self) -> None: - await asyncio.wait_for(self._event.wait(), timeout=1.0) - self._event.clear() - - -@pytest.fixture -async def mock_mqtt_client_fixture() -> AsyncGenerator[None, None]: +@pytest.fixture(name="mock_aiomqtt_client") +async def mock_aiomqtt_client_fixture() -> AsyncGenerator[None, None]: """Fixture to patch the MQTT underlying sync client. The tests use fake sockets, so this ensures that the async mqtt client does not @@ -94,11 +65,13 @@ def fast_backoff_fixture() -> Generator[None, None, None]: @pytest.fixture -def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]: +def push_mqtt_response( + mqtt_response_queue: Queue, fake_mqtt_socket_handler: FakeMqttSocketHandler +) -> Callable[[bytes], None]: """Fixture to push a response to the client.""" def _push(data: bytes) -> None: - response_queue.put(data) - fake_socket_handler.push_response() + mqtt_response_queue.put(data) + fake_mqtt_socket_handler.push_response() return _push diff --git a/tests/fixtures/channel_fixtures.py b/tests/fixtures/channel_fixtures.py new file mode 100644 index 00000000..1faae11c --- /dev/null +++ b/tests/fixtures/channel_fixtures.py @@ -0,0 +1,53 @@ +from collections.abc import Callable +from unittest.mock import AsyncMock, MagicMock + +from roborock.mqtt.health_manager import HealthManager +from roborock.protocols.v1_protocol import LocalProtocolVersion +from roborock.roborock_message import RoborockMessage + + +class FakeChannel: + """A fake channel that handles publish and subscribe calls.""" + + def __init__(self): + """Initialize the fake channel.""" + self.subscribers: list[Callable[[RoborockMessage], None]] = [] + self.published_messages: list[RoborockMessage] = [] + self.response_queue: list[RoborockMessage] = [] + self._is_connected = False + self.publish_side_effect: Exception | None = None + self.publish = AsyncMock(side_effect=self._publish) + self.subscribe = AsyncMock(side_effect=self._subscribe) + self.connect = AsyncMock(side_effect=self._connect) + self.close = MagicMock(side_effect=self._close) + self.protocol_version = LocalProtocolVersion.V1 + self.restart = AsyncMock() + self.health_manager = HealthManager(self.restart) + + async def _connect(self) -> None: + self._is_connected = True + + def _close(self) -> None: + self._is_connected = False + + @property + def is_connected(self) -> bool: + """Return true if connected.""" + return self._is_connected + + async def _publish(self, message: RoborockMessage) -> None: + """Simulate publishing a message and triggering a response.""" + self.published_messages.append(message) + if self.publish_side_effect: + raise self.publish_side_effect + # When a message is published, simulate a response + if self.response_queue: + response = self.response_queue.pop(0) + # Give a chance for the subscriber to be registered + for subscriber in list(self.subscribers): + subscriber(response) + + async def _subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: + """Simulate subscribing to messages.""" + self.subscribers.append(callback) + return lambda: self.subscribers.remove(callback) diff --git a/tests/fixtures/local_async_fixtures.py b/tests/fixtures/local_async_fixtures.py new file mode 100644 index 00000000..394dbca5 --- /dev/null +++ b/tests/fixtures/local_async_fixtures.py @@ -0,0 +1,82 @@ +import asyncio +import logging +import warnings +from collections.abc import Awaitable, Callable, Generator +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +_LOGGER = logging.getLogger(__name__) + +AsyncLocalRequestHandler = Callable[[bytes], Awaitable[bytes | None]] + + +@pytest.fixture(name="local_received_requests") +def received_requests_fixture() -> asyncio.Queue[bytes]: + """Fixture that provides access to the received requests.""" + return asyncio.Queue() + + +@pytest.fixture(name="local_response_queue") +def response_queue_fixture() -> Generator[asyncio.Queue[bytes], None, None]: + """Fixture that provides a queue of responses to be sent to the client.""" + response_queue: asyncio.Queue[bytes] = asyncio.Queue() + yield response_queue + if not response_queue.empty(): + warnings.warn("Some enqueued local device responses were not consumed during the test") + + +@pytest.fixture(name="local_async_request_handler") +def local_request_handler_fixture( + local_received_requests: asyncio.Queue[bytes], local_response_queue: asyncio.Queue[bytes] +) -> AsyncLocalRequestHandler: + """Fixture records incoming requests and replies with responses from the queue.""" + + async def handle_request(client_request: bytes) -> bytes | None: + """Handle an incoming request from the client.""" + local_received_requests.put_nowait(client_request) + + # Insert a prepared response into the response buffer + if not local_response_queue.empty(): + return await local_response_queue.get() + return None + + return handle_request + + +@pytest.fixture(name="mock_async_create_local_connection") +def create_local_connection_fixture( + local_async_request_handler: AsyncLocalRequestHandler, +) -> Generator[None, None, None]: + """Fixture that overrides the transport creation to wire it up to the mock socket.""" + + tasks = [] + + async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]: + protocol = protocol_factory() + + async def handle_write(data: bytes) -> None: + response = await local_async_request_handler(data) + if response is not None: + # Call data_received directly to avoid loop scheduling issues in test + protocol.data_received(response) + + def start_handle_write(data: bytes) -> None: + tasks.append(asyncio.create_task(handle_write(data))) + + closed = asyncio.Event() + + mock_transport = Mock() + mock_transport.write = start_handle_write + mock_transport.close = closed.set + mock_transport.is_reading = lambda: not closed.is_set() + + return (mock_transport, protocol) + + with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop: + mock_loop.return_value.create_connection.side_effect = create_connection + yield + + for task in tasks: + task.cancel() diff --git a/tests/fixtures/logging.py b/tests/fixtures/logging.py new file mode 100644 index 00000000..2d0279a6 --- /dev/null +++ b/tests/fixtures/logging.py @@ -0,0 +1,60 @@ +"""Logging utilities for tests.""" + + +class CapturedRequestLog: + """Log of requests and responses for snapshot assertions. + + The log captures the raw bytes of each request and response along with + a label indicating the direction of the message. + """ + + def __init__(self) -> None: + """Initialize the request log.""" + self.entries: list[tuple[str, bytes]] = [] + + def add_log_entry(self, label: str, data: bytes) -> None: + """Add a request entry.""" + self.entries.append((label, data)) + + def __repr__(self): + """Return a string representation of the log entries. + + This assumes that the client will behave in a request-response manner, + so each request is followed by a response. If a test uses non-deterministic + message order, this may not be accurate and the test would need to decode + the raw messages and remove any ordering assumptions. + """ + lines = [] + for label, data in self.entries: + lines.append(label) + lines.extend(self._hexdump(data)) + return "\n".join(lines) + + def _hexdump(self, data: bytes, bytes_per_line: int = 16) -> list[str]: + """Print a hexdump of the given bytes object in a tcpdump/hexdump -C style. + + This makes the packets easier to read and compare in test snapshots. + + Args: + data: The bytes object to print. + bytes_per_line: The number of bytes to display per line (default is 16). + """ + + # Use '.' for non-printable characters (ASCII < 32 or > 126) + def to_printable_ascii(byte_val): + return chr(byte_val) if 32 <= byte_val <= 126 else "." + + offset = 0 + lines = [] + while offset < len(data): + chunk = data[offset : offset + bytes_per_line] + # Format the hex values, space-padded to ensure alignment + hex_values = " ".join(f"{byte:02x}" for byte in chunk) + # Pad hex string to a fixed width so ASCII column lines up + # 3 chars per byte ('xx ') for a full line of 16 bytes + padded_hex = f"{hex_values:<{bytes_per_line * 3}}" + # Format the ASCII values + ascii_values = "".join(to_printable_ascii(byte) for byte in chunk) + lines.append(f"{offset:08x} {padded_hex} |{ascii_values}|") + offset += bytes_per_line + return lines diff --git a/tests/fixtures/logging_fixtures.py b/tests/fixtures/logging_fixtures.py new file mode 100644 index 00000000..feb9e82f --- /dev/null +++ b/tests/fixtures/logging_fixtures.py @@ -0,0 +1,69 @@ +from collections.abc import Generator +from unittest.mock import patch + +import pytest + +from .logging import CapturedRequestLog + +# Fixed timestamp for deterministic tests for asserting on message contents +FAKE_TIMESTAMP = 1755750946.721395 + + +@pytest.fixture +def deterministic_message_fixtures() -> Generator[None, None, None]: + """Fixture to use predictable get_next_int and timestamp values for each test. + + This test mocks out the functions used to generate requests that have some + entropy such as the nonces, timestamps, and request IDs. This makes the + generated messages deterministic so we can snapshot them in a test. + """ + + # Pick an arbitrary sequence number used for outgoing requests + next_int = 9090 + + def get_next_int(min_value: int, max_value: int) -> int: + nonlocal next_int + result = next_int + next_int += 1 + if next_int > max_value: + next_int = min_value + return result + + # Pick an arbitrary timestamp used for the message encryption + timestamp = FAKE_TIMESTAMP + + def get_timestamp() -> int: + """Get a monotonically increasing timestamp for testing.""" + nonlocal timestamp + timestamp += 1 + return int(timestamp) + + # Use predictable seeds for token_bytes + token_chr = "A" + + def get_token_bytes(n: int) -> bytes: + nonlocal token_chr + result = token_chr.encode() * n + # Cycle to the next character + token_chr = chr(ord(token_chr) + 1) + if token_chr > "Z": + token_chr = "A" + return result + + with ( + patch("roborock.api.get_next_int", side_effect=get_next_int), + patch("roborock.devices.local_channel.get_next_int", side_effect=get_next_int), + patch("roborock.protocols.v1_protocol.get_next_int", side_effect=get_next_int), + patch("roborock.protocols.v1_protocol.get_timestamp", side_effect=get_timestamp), + patch("roborock.protocols.v1_protocol.secrets.token_bytes", side_effect=get_token_bytes), + patch("roborock.version_1_apis.roborock_local_client_v1.get_next_int", side_effect=get_next_int), + patch("roborock.roborock_message.get_next_int", side_effect=get_next_int), + patch("roborock.roborock_message.get_timestamp", side_effect=get_timestamp), + ): + yield + + +@pytest.fixture(name="log") +def log_fixture(deterministic_message_fixtures: None) -> CapturedRequestLog: + """Fixture that creates a captured request log.""" + return CapturedRequestLog() diff --git a/tests/fixtures/mqtt.py b/tests/fixtures/mqtt.py new file mode 100644 index 00000000..793bb3f4 --- /dev/null +++ b/tests/fixtures/mqtt.py @@ -0,0 +1,101 @@ +"""Common code for MQTT tests.""" + +import asyncio +import io +import logging +from collections.abc import Callable +from queue import Queue + +from roborock.mqtt.session import MqttParams +from roborock.roborock_message import RoborockMessage + +from .logging import CapturedRequestLog + +_LOGGER = logging.getLogger(__name__) + +# Used by fixtures to handle incoming requests and prepare responses +MqttRequestHandler = Callable[[bytes], bytes | None] + + +class FakeMqttSocketHandler: + """Fake socket used by the test to simulate a connection to the broker. + + The socket handler is used to intercept the socket send and recv calls and + populate the response buffer with data to be sent back to the client. The + handle request callback handles the incoming requests and prepares the responses. + """ + + def __init__( + self, handle_request: MqttRequestHandler, response_queue: Queue[bytes], log: CapturedRequestLog + ) -> None: + self.response_buf = io.BytesIO() + self.handle_request = handle_request + self.response_queue = response_queue + self.log = log + + def pending(self) -> int: + """Return the number of bytes in the response buffer.""" + return len(self.response_buf.getvalue()) + + def handle_socket_recv(self, read_size: int) -> bytes: + """Intercept a client recv() and populate the buffer.""" + if self.pending() == 0: + raise BlockingIOError("No response queued") + + self.response_buf.seek(0) + data = self.response_buf.read(read_size) + _LOGGER.debug("Response: 0x%s", data.hex()) + # Consume the rest of the data in the buffer + remaining_data = self.response_buf.read() + self.response_buf = io.BytesIO(remaining_data) + return data + + def handle_socket_send(self, client_request: bytes) -> int: + """Receive an incoming request from the client.""" + _LOGGER.debug("Request: 0x%s", client_request.hex()) + self.log.add_log_entry("[mqtt >]", client_request) + if (response := self.handle_request(client_request)) is not None: + # Enqueue a response to be sent back to the client in the buffer. + # The buffer will be emptied when the client calls recv() on the socket + _LOGGER.debug("Queued: 0x%s", response.hex()) + self.log.add_log_entry("[mqtt <]", response) + self.response_buf.write(response) + return len(client_request) + + def push_response(self) -> None: + """Push a response to the client.""" + if not self.response_queue.empty(): + response = self.response_queue.get() + # Enqueue a response to be sent back to the client in the buffer. + # The buffer will be emptied when the client calls recv() on the socket + _LOGGER.debug("Queued: 0x%s", response.hex()) + self.response_buf.write(response) + + +FAKE_PARAMS = MqttParams( + host="localhost", + port=1883, + tls=False, + username="username", + password="password", + timeout=10.0, +) + + +class Subscriber: + """Mock subscriber class. + + We use this to hold on to received messages for verification. + """ + + def __init__(self) -> None: + self.messages: list[RoborockMessage | bytes] = [] + self._event = asyncio.Event() + + def append(self, message: RoborockMessage | bytes) -> None: + self.messages.append(message) + self._event.set() + + async def wait(self) -> None: + await asyncio.wait_for(self._event.wait(), timeout=1.0) + self._event.clear() diff --git a/tests/fixtures/pahomqtt_fixtures.py b/tests/fixtures/pahomqtt_fixtures.py new file mode 100644 index 00000000..97655f3d --- /dev/null +++ b/tests/fixtures/pahomqtt_fixtures.py @@ -0,0 +1,97 @@ +"""Common code for MQTT tests.""" + +import logging +from collections.abc import Callable, Generator +from queue import Queue +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from .logging import CapturedRequestLog +from .mqtt import FakeMqttSocketHandler + +pytest_plugins = [ + "tests.fixtures.logging_fixtures", +] + +_LOGGER = logging.getLogger(__name__) + +# Used by fixtures to handle incoming requests and prepare responses +MqttRequestHandler = Callable[[bytes], bytes | None] + + +@pytest.fixture(name="mock_paho_mqtt_create_connection") +def create_connection_fixture(mock_sock: Mock) -> Generator[None, None, None]: + """Fixture that overrides the MQTT socket creation to wire it up to the mock socket.""" + with patch("paho.mqtt.client.socket.create_connection", return_value=mock_sock): + yield + + +@pytest.fixture(name="mock_paho_mqtt_select") +def select_fixture(mock_sock: Mock, fake_mqtt_socket_handler: FakeMqttSocketHandler) -> Generator[None, None, None]: + """Fixture that overrides the MQTT client select calls to make select work on the mock socket. + + This patch select to activate our mock socket when ready with data. Internal mqtt sockets are + always ready since they are used internally to wake the select loop. Ours is ready if there + is data in the buffer. + """ + + def is_ready(sock: Any) -> bool: + return sock is not mock_sock or (fake_mqtt_socket_handler.pending() > 0) + + def handle_select(rlist: list, wlist: list, *args: Any) -> list: + return [list(filter(is_ready, rlist)), list(filter(is_ready, wlist))] + + with patch("paho.mqtt.client.select.select", side_effect=handle_select): + yield + + +@pytest.fixture(name="fake_mqtt_socket_handler") +def fake_mqtt_socket_handler_fixture( + mqtt_request_handler: MqttRequestHandler, mqtt_response_queue: Queue[bytes], log: CapturedRequestLog +) -> FakeMqttSocketHandler: + """Fixture that creates a fake MQTT broker.""" + return FakeMqttSocketHandler(mqtt_request_handler, mqtt_response_queue, log) + + +@pytest.fixture(name="mock_sock") +def mock_sock_fixture(fake_mqtt_socket_handler: FakeMqttSocketHandler) -> Mock: + """Fixture that creates a mock socket connection and wires it to the handler.""" + mock_sock = Mock() + mock_sock.recv = fake_mqtt_socket_handler.handle_socket_recv + mock_sock.send = fake_mqtt_socket_handler.handle_socket_send + mock_sock.pending = fake_mqtt_socket_handler.pending + return mock_sock + + +@pytest.fixture(name="mqtt_received_requests") +def received_requests_fixture() -> Queue[bytes]: + """Fixture that provides access to the received requests.""" + return Queue() + + +@pytest.fixture(name="mqtt_response_queue") +def response_queue_fixture() -> Generator[Queue[bytes], None, None]: + """Fixture that provides a queue for enqueueing responses to be sent to the client under test.""" + response_queue: Queue[bytes] = Queue() + yield response_queue + assert response_queue.empty(), "Not all fake responses were consumed" + + +@pytest.fixture(name="mqtt_request_handler") +def mqtt_request_handler_fixture( + mqtt_received_requests: Queue[bytes], mqtt_response_queue: Queue[bytes] +) -> MqttRequestHandler: + """Fixture records incoming requests and replies with responses from the queue.""" + + def handle_request(client_request: bytes) -> bytes | None: + """Handle an incoming request from the client.""" + mqtt_received_requests.put(client_request) + + # Insert a prepared response into the response buffer + if not mqtt_response_queue.empty(): + return mqtt_response_queue.get() + return None + + return handle_request diff --git a/tests/fixtures/web_api_fixtures.py b/tests/fixtures/web_api_fixtures.py new file mode 100644 index 00000000..0b071387 --- /dev/null +++ b/tests/fixtures/web_api_fixtures.py @@ -0,0 +1,141 @@ +import re +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +import pytest +from aioresponses import aioresponses + +from tests.mock_data import HOME_DATA_RAW, HOME_DATA_SCENES_RAW, USER_DATA + + +@pytest.fixture +def skip_rate_limit() -> Generator[None, None, None]: + """Don't rate limit tests as they aren't actually hitting the api.""" + with ( + patch("roborock.web_api.RoborockApiClient._login_limiter.try_acquire"), + patch("roborock.web_api.RoborockApiClient._home_data_limiter.try_acquire"), + ): + yield + + +@pytest.fixture(name="mock_rest") +def mock_rest_fixture(skip_rate_limit: Any) -> aioresponses: + """Mock all rest endpoints so they won't hit real endpoints""" + with aioresponses() as mocked: + # Match the base URL and allow any query params + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v1/getUrlByEmail.*"), + status=200, + payload={ + "code": 200, + "data": {"country": "US", "countrycode": "1", "url": "https://usiot.roborock.com"}, + "msg": "success", + }, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v1/login.*"), + status=200, + payload={"code": 200, "data": USER_DATA, "msg": "success"}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v1/loginWithCode.*"), + status=200, + payload={"code": 200, "data": USER_DATA, "msg": "success"}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v1/sendEmailCode.*"), + status=200, + payload={"code": 200, "data": None, "msg": "success"}, + ) + mocked.get( + re.compile(r"https://.*iot\.roborock\.com/api/v1/getHomeDetail.*"), + status=200, + payload={ + "code": 200, + "data": {"deviceListOrder": None, "id": 123456, "name": "My Home", "rrHomeId": 123456, "tuyaHomeId": 0}, + "msg": "success", + }, + ) + mocked.get( + re.compile(r"https://api-.*\.roborock\.com/v2/user/homes*"), + status=200, + payload={"api": None, "code": 200, "result": HOME_DATA_RAW, "status": "ok", "success": True}, + ) + mocked.post( + re.compile(r"https://api-.*\.roborock\.com/nc/prepare"), + status=200, + payload={ + "api": None, + "result": {"r": "US", "s": "ffffff", "t": "eOf6d2BBBB"}, + "status": "ok", + "success": True, + }, + ) + + mocked.get( + re.compile(r"https://api-.*\.roborock\.com/user/devices/newadd/*"), + status=200, + payload={ + "api": "获取新增设备信息", + "result": { + "activeTime": 1737724598, + "attribute": None, + "cid": None, + "createTime": 0, + "deviceStatus": None, + "duid": "rand_duid", + "extra": "{}", + "f": False, + "featureSet": "0", + "fv": "02.16.12", + "iconUrl": "", + "lat": None, + "localKey": "random_lk", + "lon": None, + "name": "S7", + "newFeatureSet": "0000000000002000", + "online": True, + "productId": "rand_prod_id", + "pv": "1.0", + "roomId": None, + "runtimeEnv": None, + "setting": None, + "share": False, + "shareTime": None, + "silentOtaSwitch": False, + "sn": "Rand_sn", + "timeZoneId": "America/New_York", + "tuyaMigrated": False, + "tuyaUuid": None, + }, + "status": "ok", + "success": True, + }, + ) + mocked.get( + re.compile(r"https://api-.*\.roborock\.com/user/scene/device/.*"), + status=200, + payload={"api": None, "code": 200, "result": HOME_DATA_SCENES_RAW, "status": "ok", "success": True}, + ) + mocked.post( + re.compile(r"https://api-.*\.roborock\.com/user/scene/.*/execute"), + status=200, + payload={"api": None, "code": 200, "result": None, "status": "ok", "success": True}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v4/email/code/send.*"), + status=200, + payload={"code": 200, "data": None, "msg": "success"}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v3/key/sign.*"), + status=200, + payload={"code": 200, "data": {"k": "mock_k"}, "msg": "success"}, + ) + mocked.post( + re.compile(r"https://.*iot\.roborock\.com/api/v4/auth/email/login/code.*"), + status=200, + payload={"code": 200, "data": USER_DATA, "msg": "success"}, + ) + yield mocked diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py index 3f8a9485..24e6fbbd 100644 --- a/tests/mqtt/test_roborock_session.py +++ b/tests/mqtt/test_roborock_session.py @@ -13,16 +13,27 @@ from roborock.mqtt.roborock_session import RoborockMqttSession, create_mqtt_session from roborock.mqtt.session import MqttSessionException, MqttSessionUnauthorized from tests import mqtt_packet -from tests.mqtt_fixtures import FAKE_PARAMS, Subscriber +from tests.fixtures.mqtt import FAKE_PARAMS, Subscriber + +pytest_plugins = [ + "tests.fixtures.logging_fixtures", + "tests.fixtures.pahomqtt_fixtures", + "tests.fixtures.aiomqtt_fixtures", +] @pytest.fixture(autouse=True) -def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None: +def mqtt_server_fixture( + mock_paho_mqtt_create_connection: None, + mock_paho_mqtt_select: None, +) -> None: """Fixture to prepare a fake MQTT server.""" @pytest.fixture(autouse=True) -def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None: +def auto_mock_aiomqtt_client( + mock_aiomqtt_client: None, +) -> None: """Automatically use the mock mqtt client fixture.""" @@ -31,9 +42,12 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None: """Automatically use the fast backoff fixture.""" -@pytest.fixture -def mock_mqtt_client() -> Generator[AsyncMock, None, None]: - """Fixture to create a mock MQTT client with patched aiomqtt.Client.""" +@pytest.fixture(name="mqtt_client_lite") +def mqtt_client_lite_fixture() -> Generator[AsyncMock, None, None]: + """A fixture that provides a mocked aiomqtt Client. + + This is lighter weight that `mock_aiomqtt_client` that uses real sockets. + """ mock_client = AsyncMock() mock_client.messages = FakeAsyncIterator() @@ -48,38 +62,38 @@ def mock_mqtt_client() -> Generator[AsyncMock, None, None]: yield mock_client -async def test_session(push_response: Callable[[bytes], None]) -> None: +async def test_session(push_mqtt_response: Callable[[bytes], None]) -> None: """Test the MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) assert session.connected - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) subscriber1 = Subscriber() unsub1 = await session.subscribe("topic-1", subscriber1.append) - push_response(mqtt_packet.gen_suback(mid=2)) + push_mqtt_response(mqtt_packet.gen_suback(mid=2)) subscriber2 = Subscriber() await session.subscribe("topic-2", subscriber2.append) - push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) await subscriber1.wait() assert subscriber1.messages == [b"12345"] assert not subscriber2.messages - push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) + push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) await subscriber2.wait() assert subscriber2.messages == [b"67890"] - push_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC")) await subscriber1.wait() assert subscriber1.messages == [b"12345", b"ABC"] assert subscriber2.messages == [b"67890"] # Messages are no longer received after unsubscribing unsub1() - push_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) assert subscriber1.messages == [b"12345", b"ABC"] assert session.connected @@ -87,12 +101,12 @@ async def test_session(push_response: Callable[[bytes], None]) -> None: assert not session.connected -async def test_session_no_subscribers(push_response: Callable[[bytes], None]) -> None: +async def test_session_no_subscribers(push_mqtt_response: Callable[[bytes], None]) -> None: """Test the MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) - push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) - push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) session = await create_mqtt_session(FAKE_PARAMS) assert session.connected @@ -100,13 +114,13 @@ async def test_session_no_subscribers(push_response: Callable[[bytes], None]) -> assert not session.connected -async def test_publish_command(push_response: Callable[[bytes], None]) -> None: +async def test_publish_command(push_mqtt_response: Callable[[bytes], None]) -> None: """Test publishing during an MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) - push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) await session.publish("topic-1", message=b"payload") assert session.connected @@ -129,13 +143,13 @@ async def __anext__(self) -> None: await asyncio.sleep(1) -async def test_publish_failure(mock_mqtt_client: AsyncMock) -> None: +async def test_publish_failure(mqtt_client_lite: AsyncMock) -> None: """Test an MQTT error is received when publishing a message.""" session = await create_mqtt_session(FAKE_PARAMS) assert session.connected - mock_mqtt_client.publish.side_effect = aiomqtt.MqttError + mqtt_client_lite.publish.side_effect = aiomqtt.MqttError with pytest.raises(MqttSessionException, match="Error publishing message"): await session.publish("topic-1", message=b"payload") @@ -143,13 +157,13 @@ async def test_publish_failure(mock_mqtt_client: AsyncMock) -> None: await session.close() -async def test_subscribe_failure(mock_mqtt_client: AsyncMock) -> None: +async def test_subscribe_failure(mqtt_client_lite: AsyncMock) -> None: """Test an MQTT error while subscribing.""" session = await create_mqtt_session(FAKE_PARAMS) assert session.connected - mock_mqtt_client.subscribe.side_effect = aiomqtt.MqttError + mqtt_client_lite.subscribe.side_effect = aiomqtt.MqttError subscriber1 = Subscriber() with pytest.raises(MqttSessionException, match="Error subscribing to topic"): @@ -159,20 +173,20 @@ async def test_subscribe_failure(mock_mqtt_client: AsyncMock) -> None: await session.close() -async def test_restart(push_response: Callable[[bytes], None]) -> None: +async def test_restart(push_mqtt_response: Callable[[bytes], None]) -> None: """Test restarting the MQTT session.""" - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(FAKE_PARAMS) assert session.connected # Subscribe to a topic - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) subscriber = Subscriber() await session.subscribe("topic-1", subscriber.append) # Verify we can receive messages - push_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=b"12345")) await subscriber.wait() assert subscriber.messages == [b"12345"] @@ -184,20 +198,20 @@ async def test_restart(push_response: Callable[[bytes], None]) -> None: await asyncio.sleep(0.01) # We need to queue up a new connack for the reconnection - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) # And a suback for the resubscription. Since we created a new client, # the message ID resets to 1. - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) - push_response(mqtt_packet.gen_publish("topic-1", mid=4, payload=b"67890")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=4, payload=b"67890")) await subscriber.wait() assert subscriber.messages == [b"12345", b"67890"] await session.close() -async def test_idle_timeout_resubscribe(mock_mqtt_client: AsyncMock) -> None: +async def test_idle_timeout_resubscribe(mqtt_client_lite: AsyncMock) -> None: """Test that resubscribing before idle timeout cancels the unsubscribe.""" # Create session with idle timeout @@ -220,12 +234,12 @@ async def test_idle_timeout_resubscribe(mock_mqtt_client: AsyncMock) -> None: await asyncio.sleep(0.01) # unsubscribe should NOT have been called because we resubscribed - mock_mqtt_client.unsubscribe.assert_not_called() + mqtt_client_lite.unsubscribe.assert_not_called() await session.close() -async def test_idle_timeout_unsubscribe(mock_mqtt_client: AsyncMock) -> None: +async def test_idle_timeout_unsubscribe(mqtt_client_lite: AsyncMock) -> None: """Test that unsubscribe happens after idle timeout expires.""" # Create session with very short idle timeout for fast test @@ -244,12 +258,12 @@ async def test_idle_timeout_unsubscribe(mock_mqtt_client: AsyncMock) -> None: await asyncio.sleep(0.1) # unsubscribe should have been called after idle timeout - mock_mqtt_client.unsubscribe.assert_called_once_with(topic) + mqtt_client_lite.unsubscribe.assert_called_once_with(topic) await session.close() -async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> None: +async def test_idle_timeout_multiple_callbacks(mqtt_client_lite: AsyncMock) -> None: """Test that unsubscribe is delayed when multiple subscribers exist.""" # Create session with very short idle timeout for fast test @@ -271,7 +285,7 @@ async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> N await asyncio.sleep(0.1) # unsubscribe should NOT have been called because subscriber2 is still active - mock_mqtt_client.unsubscribe.assert_not_called() + mqtt_client_lite.unsubscribe.assert_not_called() # Unsubscribe second callback (NOW timer should start) unsub2() @@ -280,12 +294,12 @@ async def test_idle_timeout_multiple_callbacks(mock_mqtt_client: AsyncMock) -> N await asyncio.sleep(0.1) # Now unsubscribe should have been called - mock_mqtt_client.unsubscribe.assert_called_once_with(topic) + mqtt_client_lite.unsubscribe.assert_called_once_with(topic) await session.close() -async def test_subscription_reuse(mock_mqtt_client: AsyncMock) -> None: +async def test_subscription_reuse(mqtt_client_lite: AsyncMock) -> None: """Test that subscriptions are reused and not duplicated.""" session = RoborockMqttSession(FAKE_PARAMS) await session.start() @@ -296,32 +310,32 @@ async def test_subscription_reuse(mock_mqtt_client: AsyncMock) -> None: unsub1 = await session.subscribe("topic1", cb1) # Verify subscribe called - mock_mqtt_client.subscribe.assert_called_with("topic1") - mock_mqtt_client.subscribe.reset_mock() + mqtt_client_lite.subscribe.assert_called_with("topic1") + mqtt_client_lite.subscribe.reset_mock() # 2. Second subscription (same topic) cb2 = Mock() unsub2 = await session.subscribe("topic1", cb2) # Verify subscribe NOT called - mock_mqtt_client.subscribe.assert_not_called() + mqtt_client_lite.subscribe.assert_not_called() # 3. Unsubscribe one unsub1() # Verify unsubscribe NOT called (still have cb2) - mock_mqtt_client.unsubscribe.assert_not_called() + mqtt_client_lite.unsubscribe.assert_not_called() # 4. Unsubscribe second (starts idle timer) unsub2() # Verify unsubscribe NOT called yet (idle) - mock_mqtt_client.unsubscribe.assert_not_called() + mqtt_client_lite.unsubscribe.assert_not_called() # 5. Resubscribe during idle cb3 = Mock() _ = await session.subscribe("topic1", cb3) # Verify subscribe NOT called (reused) - mock_mqtt_client.subscribe.assert_not_called() + mqtt_client_lite.subscribe.assert_not_called() await session.close() @@ -365,7 +379,7 @@ async def test_connect_failure( await create_mqtt_session(FAKE_PARAMS) -async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None: +async def test_diagnostics_data(push_mqtt_response: Callable[[bytes], None]) -> None: """Test the MQTT session.""" diagnostics = Diagnostics() @@ -373,7 +387,7 @@ async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None: params = copy.deepcopy(FAKE_PARAMS) params.diagnostics = diagnostics - push_response(mqtt_packet.gen_connack(rc=0, flags=2)) + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) session = await create_mqtt_session(params) assert session.connected @@ -386,24 +400,24 @@ async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None: assert data.get("dispatch_message_count") is None assert data.get("close") is None - push_response(mqtt_packet.gen_suback(mid=1)) + push_mqtt_response(mqtt_packet.gen_suback(mid=1)) subscriber1 = Subscriber() unsub1 = await session.subscribe("topic-1", subscriber1.append) - push_response(mqtt_packet.gen_suback(mid=2)) + push_mqtt_response(mqtt_packet.gen_suback(mid=2)) subscriber2 = Subscriber() await session.subscribe("topic-2", subscriber2.append) - push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")) await subscriber1.wait() assert subscriber1.messages == [b"12345"] assert not subscriber2.messages - push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) + push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890")) await subscriber2.wait() assert subscriber2.messages == [b"67890"] - push_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC")) await subscriber1.wait() assert subscriber1.messages == [b"12345", b"ABC"] assert subscriber2.messages == [b"67890"] @@ -418,7 +432,7 @@ async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None: # Messages are no longer received after unsubscribing unsub1() - push_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) + push_mqtt_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored")) assert subscriber1.messages == [b"12345", b"ABC"] assert session.connected diff --git a/tests/test_a01_api.py b/tests/test_a01_api.py index 6c1f7dac..911a1791 100644 --- a/tests/test_a01_api.py +++ b/tests/test_a01_api.py @@ -20,9 +20,10 @@ RoborockZeoProtocol, ) from roborock.version_a01_apis import RoborockMqttClientA01 +from tests.fixtures.logging import CapturedRequestLog +from tests.protocols.common import build_a01_message from . import mqtt_packet -from .conftest import QUEUE_TIMEOUT, CapturedRequestLog from .mock_data import ( HOME_DATA_RAW, LOCAL_KEY, @@ -31,11 +32,16 @@ WASHER_PRODUCT, ZEO_ONE_DEVICE, ) -from .protocols.common import build_a01_message +QUEUE_TIMEOUT = 10 RELEASE_TIMEOUT = 2 +pytest_plugins = [ + "tests.fixtures.pahomqtt_fixtures", +] + + @pytest.fixture(name="category") def category_fixture() -> RoborockCategory: return RoborockCategory.WASHING_MACHINE @@ -43,8 +49,8 @@ def category_fixture() -> RoborockCategory: @pytest.fixture(name="a01_mqtt_client") async def a01_mqtt_client_fixture( - mock_create_connection: None, - mock_select: None, + mock_paho_mqtt_create_connection: None, + mock_paho_mqtt_select: None, category: RoborockCategory, ) -> AsyncGenerator[RoborockMqttClientA01, None]: user_data = UserData.from_dict(USER_DATA) @@ -74,15 +80,15 @@ async def a01_mqtt_client_fixture( @pytest.fixture(name="connected_a01_mqtt_client") async def connected_a01_mqtt_client_fixture( - response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 + mqtt_response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 ) -> AsyncGenerator[RoborockMqttClientA01, None]: - response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) - response_queue.put(mqtt_packet.gen_suback(1, 0)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) + mqtt_response_queue.put(mqtt_packet.gen_suback(1, 0)) await a01_mqtt_client.async_connect() yield a01_mqtt_client -async def test_async_connect(received_requests: Queue, connected_a01_mqtt_client: RoborockMqttClientA01) -> None: +async def test_async_connect(mqtt_received_requests: Queue, connected_a01_mqtt_client: RoborockMqttClientA01) -> None: """Test connecting to the MQTT broker.""" assert connected_a01_mqtt_client.is_connected() @@ -95,20 +101,20 @@ async def test_async_connect(received_requests: Queue, connected_a01_mqtt_client # Broker received a connect and subscribe. Disconnect packet is not # guaranteed to be captured by the time the async_disconnect returns - assert received_requests.qsize() >= 2 # Connect and Subscribe + assert mqtt_received_requests.qsize() >= 2 # Connect and Subscribe async def test_connect_failure( - received_requests: Queue, response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 + mqtt_received_requests: Queue, mqtt_response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 ) -> None: """Test the broker responding with a connect failure.""" - response_queue.put(mqtt_packet.gen_connack(rc=1)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=1)) with pytest.raises(RoborockException, match="Failed to connect"): await a01_mqtt_client.async_connect() assert not a01_mqtt_client.is_connected() - assert received_requests.qsize() == 1 # Connect attempt + assert mqtt_received_requests.qsize() == 1 # Connect attempt async def test_disconnect_already_disconnected(connected_a01_mqtt_client: RoborockMqttClientA01) -> None: @@ -141,11 +147,11 @@ async def test_async_release(connected_a01_mqtt_client: RoborockMqttClientA01) - async def test_subscribe_failure( - received_requests: Queue, response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 + mqtt_received_requests: Queue, mqtt_response_queue: Queue, a01_mqtt_client: RoborockMqttClientA01 ) -> None: """Test the broker responding with the wrong message type on subscribe.""" - response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) with ( patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)), @@ -153,7 +159,7 @@ async def test_subscribe_failure( ): await a01_mqtt_client.async_connect() - assert received_requests.qsize() == 1 # Connect attempt + assert mqtt_received_requests.qsize() == 1 # Connect attempt # NOTE: The client is "connected" but not "subscribed" and cannot recover # from this state without disconnecting first. This can likely be improved. @@ -162,7 +168,7 @@ async def test_subscribe_failure( # Attempting to reconnect is a no-op since the client already thinks it is connected await a01_mqtt_client.async_connect() assert a01_mqtt_client.is_connected() - assert received_requests.qsize() == 1 + assert mqtt_received_requests.qsize() == 1 def build_rpc_response(message: dict[Any, Any]) -> bytes: @@ -171,8 +177,8 @@ def build_rpc_response(message: dict[Any, Any]) -> bytes: async def test_update_zeo_values( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_a01_mqtt_client: RoborockMqttClientA01, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, @@ -189,7 +195,7 @@ async def test_update_zeo_values( 218: 0, # Washing left. Testing zero int value } ) - response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) + mqtt_response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) data = await connected_a01_mqtt_client.update_values( [ @@ -214,8 +220,8 @@ async def test_update_zeo_values( @pytest.mark.parametrize("category", [RoborockCategory.WET_DRY_VAC]) async def test_update_dyad_values( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_a01_mqtt_client: RoborockMqttClientA01, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, @@ -231,7 +237,7 @@ async def test_update_dyad_values( 224: 0, # AUTO_DRY_MODE off } ) - response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) + mqtt_response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) data = await connected_a01_mqtt_client.update_values( [ @@ -253,25 +259,25 @@ async def test_update_dyad_values( async def test_set_value( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_a01_mqtt_client: RoborockMqttClientA01, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, ) -> None: """Test sending an arbitrary MQTT message and parsing the response.""" # Clear existing messages received during setup - assert received_requests.qsize() == 2 - assert received_requests.get(block=True, timeout=QUEUE_TIMEOUT) - assert received_requests.get(block=True, timeout=QUEUE_TIMEOUT) - assert received_requests.empty() + assert mqtt_received_requests.qsize() == 2 + assert mqtt_received_requests.get(block=True, timeout=QUEUE_TIMEOUT) + assert mqtt_received_requests.get(block=True, timeout=QUEUE_TIMEOUT) + assert mqtt_received_requests.empty() # Prepare the response message message = build_rpc_response({}) - response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) + mqtt_response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) await connected_a01_mqtt_client.set_value(RoborockZeoProtocol.STATE, "spinning") - assert received_requests.get(block=True) + assert mqtt_received_requests.get(block=True) assert snapshot == log diff --git a/tests/test_api.py b/tests/test_api.py index b428d2fd..2f1b3c92 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -23,6 +23,7 @@ from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from roborock.version_1_apis import RoborockMqttClientV1 from roborock.web_api import PreparedRequest, RoborockApiClient +from tests.fixtures.logging import CapturedRequestLog from tests.mock_data import ( BASE_URL_REQUEST, GET_CODE_RESPONSE, @@ -34,7 +35,33 @@ ) from . import mqtt_packet -from .conftest import CapturedRequestLog + +QUEUE_TIMEOUT = 10 + +pytest_plugins = [ + "tests.fixtures.pahomqtt_fixtures", +] + + +@pytest.fixture(name="mqtt_client") +async def mqtt_client( + mock_paho_mqtt_create_connection: None, mock_paho_mqtt_select: None +) -> AsyncGenerator[RoborockMqttClientV1, None]: + user_data = UserData.from_dict(USER_DATA) + home_data = HomeData.from_dict(HOME_DATA_RAW) + device_info = DeviceData( + device=home_data.devices[0], + model=home_data.products[0].model, + ) + client = RoborockMqttClientV1(user_data, device_info, queue_timeout=QUEUE_TIMEOUT) + try: + yield client + finally: + if not client.is_connected(): + try: + await client.async_release() + except Exception: + pass def test_can_create_prepared_request(): @@ -143,10 +170,10 @@ async def test_get_prop(): @pytest.fixture(name="connected_mqtt_client") async def connected_mqtt_client_fixture( - response_queue: Queue, mqtt_client: RoborockMqttClientV1 + mqtt_response_queue: Queue, mqtt_client: RoborockMqttClientV1 ) -> AsyncGenerator[RoborockMqttClientV1, None]: - response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) - response_queue.put(mqtt_packet.gen_suback(1, 0)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) + mqtt_response_queue.put(mqtt_packet.gen_suback(1, 0)) await mqtt_client.async_connect() yield mqtt_client if mqtt_client.is_connected(): @@ -156,7 +183,7 @@ async def connected_mqtt_client_fixture( pass -async def test_async_connect(received_requests: Queue, connected_mqtt_client: RoborockMqttClientV1) -> None: +async def test_async_connect(mqtt_received_requests: Queue, connected_mqtt_client: RoborockMqttClientV1) -> None: """Test connecting to the MQTT broker.""" assert connected_mqtt_client.is_connected() @@ -169,20 +196,20 @@ async def test_async_connect(received_requests: Queue, connected_mqtt_client: Ro # Broker received a connect and subscribe. Disconnect packet is not # guaranteed to be captured by the time the async_disconnect returns - assert received_requests.qsize() >= 2 # Connect and Subscribe + assert mqtt_received_requests.qsize() >= 2 # Connect and Subscribe async def test_connect_failure_response( - received_requests: Queue, response_queue: Queue, mqtt_client: RoborockMqttClientV1 + mqtt_received_requests: Queue, mqtt_response_queue: Queue, mqtt_client: RoborockMqttClientV1 ) -> None: """Test the broker responding with a connect failure.""" - response_queue.put(mqtt_packet.gen_connack(rc=1)) + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=1)) with pytest.raises(RoborockException, match="Failed to connect"): await mqtt_client.async_connect() assert not mqtt_client.is_connected() - assert received_requests.qsize() == 1 # Connect attempt + assert mqtt_received_requests.qsize() == 1 # Connect attempt async def test_disconnect_already_disconnected(connected_mqtt_client: RoborockMqttClientV1) -> None: @@ -209,8 +236,8 @@ async def test_disconnect_failure(connected_mqtt_client: RoborockMqttClientV1) - async def test_disconnect_failure_response( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_mqtt_client: RoborockMqttClientV1, caplog: pytest.LogCaptureFixture, ) -> None: @@ -218,7 +245,7 @@ async def test_disconnect_failure_response( # Enqueue a failed message -- however, the client does not process any # further messages and there is no parsing error, and no failed log messages. - response_queue.put(mqtt_packet.gen_disconnect(reason_code=1)) + mqtt_response_queue.put(mqtt_packet.gen_disconnect(reason_code=1)) assert connected_mqtt_client.is_connected() with caplog.at_level(logging.ERROR): await connected_mqtt_client.async_disconnect() @@ -233,19 +260,18 @@ async def test_async_release(connected_mqtt_client: RoborockMqttClientV1) -> Non async def test_subscribe_failure( - received_requests: Queue, response_queue: Queue, mqtt_client: RoborockMqttClientV1 + mqtt_received_requests: Queue, mqtt_response_queue: Queue, mqtt_client: RoborockMqttClientV1 ) -> None: """Test the broker responding with the wrong message type on subscribe.""" - response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) - + mqtt_response_queue.put(mqtt_packet.gen_connack(rc=0, flags=2)) with ( patch("roborock.cloud_api.mqtt.Client.subscribe", return_value=(mqtt.MQTT_ERR_NO_CONN, None)), pytest.raises(RoborockException, match="Failed to subscribe"), ): await mqtt_client.async_connect() - assert received_requests.qsize() == 1 # Connect attempt + assert mqtt_received_requests.qsize() == 1 # Connect attempt # NOTE: The client is "connected" but not "subscribed" and cannot recover # from this state without disconnecting first. This can likely be improved. @@ -254,7 +280,7 @@ async def test_subscribe_failure( # Attempting to reconnect is a no-op since the client already thinks it is connected await mqtt_client.async_connect() assert mqtt_client.is_connected() - assert received_requests.qsize() == 1 + assert mqtt_received_requests.qsize() == 1 def build_rpc_response(message: dict[str, Any]) -> bytes: @@ -276,8 +302,8 @@ def build_rpc_response(message: dict[str, Any]) -> bytes: async def test_get_room_mapping( - received_requests: Queue, - response_queue: Queue, + mqtt_received_requests: Queue, + mqtt_response_queue: Queue, connected_mqtt_client: RoborockMqttClientV1, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, @@ -291,7 +317,7 @@ async def test_get_room_mapping( "result": [[16, "2362048"], [17, "2362044"]], } ) - response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) + mqtt_response_queue.put(mqtt_packet.gen_publish(MQTT_PUBLISH_TOPIC, payload=message)) with patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id): room_mapping = await connected_mqtt_client.get_room_mapping() diff --git a/tests/test_local_api_v1.py b/tests/test_local_api_v1.py index 281f7f9a..c5209423 100644 --- a/tests/test_local_api_v1.py +++ b/tests/test_local_api_v1.py @@ -1,23 +1,123 @@ """Tests for the Roborock Local Client V1.""" +import asyncio import json -from collections.abc import AsyncGenerator +import logging +import warnings +from collections.abc import AsyncGenerator, Callable, Generator from queue import Queue from typing import Any -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest import syrupy -from roborock.data import RoomMapping +from roborock import HomeData +from roborock.data import DeviceData, RoomMapping from roborock.exceptions import RoborockException from roborock.protocol import MessageParser from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol -from roborock.version_1_apis import RoborockLocalClientV1 +from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1 +from tests.fixtures.logging import CapturedRequestLog +from tests.mock_data import HOME_DATA_RAW, TEST_LOCAL_API_HOST -from .conftest import CapturedRequestLog from .mock_data import LOCAL_KEY +_LOGGER = logging.getLogger(__name__) + +pytest_plugins = [ + "tests.fixtures.logging_fixtures", +] + +QUEUE_TIMEOUT = 10 + +LocalRequestHandler = Callable[[bytes], bytes | None] + + +@pytest.fixture(name="local_received_requests") +def received_requests_fixture() -> Queue[bytes]: + """Fixture that provides access to the received requests.""" + return Queue() + + +@pytest.fixture(name="local_response_queue") +def response_queue_fixture() -> Generator[Queue[bytes], None, None]: + """Fixture that provides a queue for enqueueing responses to be sent back to the client under test.""" + response_queue: Queue[bytes] = Queue() + yield response_queue + if not response_queue.empty(): + warnings.warn("Not all fake responses were consumed") + + +@pytest.fixture(name="local_request_handler") +def local_request_handler_fixture( + local_received_requests: Queue[bytes], local_response_queue: Queue[bytes] +) -> LocalRequestHandler: + """Fixture records incoming requests and replies with responses from the queue.""" + + def handle_request(client_request: bytes) -> bytes | None: + """Handle an incoming request from the client.""" + local_received_requests.put(client_request) + + # Insert a prepared response into the response buffer + if not local_response_queue.empty(): + return local_response_queue.get() + return None + + return handle_request + + +@pytest.fixture(name="mock_create_local_connection") +def create_local_connection_fixture( + local_request_handler: LocalRequestHandler, log: CapturedRequestLog +) -> Generator[None, None, None]: + """Fixture that overrides the transport creation to wire it up to the mock socket.""" + + async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args) -> tuple[Any, Any]: + protocol = protocol_factory() + + def handle_write(data: bytes) -> None: + _LOGGER.debug("Received: %s", data) + response = local_request_handler(data) + log.add_log_entry("[local >]", data) + if response is not None: + _LOGGER.debug("Replying with %s", response) + log.add_log_entry("[local <]", response) + loop = asyncio.get_running_loop() + loop.call_soon(protocol.data_received, response) + + closed = asyncio.Event() + + mock_transport = Mock() + mock_transport.write = handle_write + mock_transport.close = closed.set + mock_transport.is_reading = lambda: not closed.is_set() + + return (mock_transport, "proto") + + with patch("roborock.version_1_apis.roborock_local_client_v1.get_running_loop") as mock_loop: + mock_loop.return_value.create_connection.side_effect = create_connection + yield + + +@pytest.fixture(name="local_client") +async def local_client_fixture(mock_create_local_connection: None) -> AsyncGenerator[RoborockLocalClientV1, None]: + home_data = HomeData.from_dict(HOME_DATA_RAW) + device_info = DeviceData( + device=home_data.devices[0], + model=home_data.products[0].model, + host=TEST_LOCAL_API_HOST, + ) + client = RoborockLocalClientV1(device_info, queue_timeout=QUEUE_TIMEOUT) + try: + yield client + finally: + if not client.is_connected(): + try: + await client.async_release() + except Exception: + pass + def build_rpc_response(seq: int, message: dict[str, Any]) -> bytes: """Build an encoded RPC response message.""" @@ -45,18 +145,18 @@ def build_raw_response(protocol: RoborockMessageProtocol, seq: int, payload: byt async def test_async_connect( local_client: RoborockLocalClientV1, - received_requests: Queue, - response_queue: Queue, + local_received_requests: Queue, + local_response_queue: Queue, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, ) -> None: """Test that we can connect to the Roborock device.""" - response_queue.put(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, b"ignored")) - response_queue.put(build_raw_response(RoborockMessageProtocol.PING_RESPONSE, 2, b"ignored")) + local_response_queue.put(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, b"ignored")) + local_response_queue.put(build_raw_response(RoborockMessageProtocol.PING_RESPONSE, 2, b"ignored")) await local_client.async_connect() assert local_client.is_connected() - assert received_requests.qsize() == 2 + assert local_received_requests.qsize() == 2 await local_client.async_disconnect() assert not local_client.is_connected() @@ -66,18 +166,18 @@ async def test_async_connect( @pytest.fixture(name="connected_local_client") async def connected_local_client_fixture( - response_queue: Queue, + local_response_queue: Queue, local_client: RoborockLocalClientV1, ) -> AsyncGenerator[RoborockLocalClientV1, None]: - response_queue.put(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, b"ignored")) - response_queue.put(build_raw_response(RoborockMessageProtocol.PING_RESPONSE, 2, b"ignored")) + local_response_queue.put(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, b"ignored")) + local_response_queue.put(build_raw_response(RoborockMessageProtocol.PING_RESPONSE, 2, b"ignored")) await local_client.async_connect() yield local_client async def test_get_room_mapping( - received_requests: Queue, - response_queue: Queue, + local_received_requests: Queue, + local_response_queue: Queue, connected_local_client: RoborockLocalClientV1, snapshot: syrupy.SnapshotAssertion, log: CapturedRequestLog, @@ -93,7 +193,7 @@ async def test_get_room_mapping( "result": [[16, "2362048"], [17, "2362044"]], }, ) - response_queue.put(message) + local_response_queue.put(message) with patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id): room_mapping = await connected_local_client.get_room_mapping() @@ -107,8 +207,8 @@ async def test_get_room_mapping( async def test_retry_request( - received_requests: Queue, - response_queue: Queue, + local_received_requests: Queue, + local_response_queue: Queue, connected_local_client: RoborockLocalClientV1, ) -> None: """Test sending an arbitrary MQTT message and parsing the response.""" @@ -122,7 +222,7 @@ async def test_retry_request( "result": "retry", }, ) - response_queue.put(retry_message) + local_response_queue.put(retry_message) with ( patch("roborock.protocols.v1_protocol.get_next_int", return_value=test_request_id), diff --git a/tests/test_web_api.py b/tests/test_web_api.py index 796fd008..b280277c 100644 --- a/tests/test_web_api.py +++ b/tests/test_web_api.py @@ -1,12 +1,24 @@ import re +from typing import Any import aiohttp +import pytest from aioresponses.compat import normalize_url from roborock import HomeData, HomeDataScene, UserData from roborock.web_api import IotLoginInfo, RoborockApiClient from tests.mock_data import HOME_DATA_RAW, USER_DATA +pytest_plugins = [ + "tests.fixtures.web_api_fixtures", +] + + +@pytest.fixture(autouse=True) +def auto_mock_rest_fixture(mock_rest: Any) -> None: + """Auto use the mock rest fixture for all tests in this module.""" + pass + async def test_pass_login_flow() -> None: """Test that we can login with a password and we get back the correct userdata object."""