diff --git a/roborock/protocol.py b/roborock/protocol.py index 20bb6829..828a432a 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -163,12 +163,12 @@ def _l01_iv(timestamp: int, nonce: int, sequence: int) -> bytes: return digest[:12] @staticmethod - def _l01_aad(timestamp: int, nonce: int, sequence: int, connect_nonce: int, ack_nonce: int) -> bytes: + def _l01_aad(timestamp: int, nonce: int, sequence: int, connect_nonce: int, ack_nonce: int | None = None) -> bytes: """Derive AAD for L01 protocol.""" return ( sequence.to_bytes(4, "big") + connect_nonce.to_bytes(4, "big") - + ack_nonce.to_bytes(4, "big") + + (ack_nonce.to_bytes(4, "big") if ack_nonce is not None else b"") + nonce.to_bytes(4, "big") + timestamp.to_bytes(4, "big") ) @@ -181,7 +181,7 @@ def encrypt_gcm_l01( sequence: int, nonce: int, connect_nonce: int, - ack_nonce: int, + ack_nonce: int | None = None, ) -> bytes: """Encrypt plaintext for L01 protocol using AES-256-GCM.""" if not isinstance(plaintext, bytes): diff --git a/tests/e2e/__snapshots__/test_local_session.ambr b/tests/e2e/__snapshots__/test_local_session.ambr new file mode 100644 index 00000000..65658492 --- /dev/null +++ b/tests/e2e/__snapshots__/test_local_session.ambr @@ -0,0 +1,54 @@ +# serializer version: 1 +# name: test_connect + [local >] + 00000000 00 00 00 15 31 2e 30 00 00 00 01 00 00 23 82 68 |....1.0......#.h| + 00000010 a6 a2 24 00 00 e6 b9 24 63 |..$....$c| + [local <] + 00000000 00 00 00 27 31 2e 30 00 00 00 01 00 00 00 17 68 |...'1.0........h| + 00000010 a6 a2 23 00 01 00 10 cb 93 c7 39 b9 21 53 43 48 |..#.......9.!SCH| + 00000020 83 b3 c2 af 0f 51 2c da 9e ea 3b |.....Q,...;| +# --- +# name: test_l01_session + [local >] + 00000000 00 00 00 15 31 2e 30 00 00 00 01 00 00 23 82 68 |....1.0......#.h| + 00000010 a6 a2 24 00 00 e6 b9 24 63 |..$....$c| + [local <] + 00000000 00 |.| + [local >] + 00000000 00 00 00 15 4c 30 31 00 00 00 01 00 00 23 82 68 |....L01......#.h| + 00000010 a6 a2 25 00 00 ee 2f 30 e8 |..%.../0.| + [local <] + 00000000 00 00 00 29 4c 30 31 00 00 00 01 00 00 00 17 68 |...)L01........h| + 00000010 a6 a2 23 00 01 00 12 a0 4a ec 75 88 03 75 0f d2 |..#.....J.u..u..| + 00000020 40 33 69 02 f4 71 50 72 f3 81 56 80 f4 |@3i..qPr..V..| + [local >] + 00000000 00 00 00 3e 4c 30 31 00 00 00 7b 00 00 23 83 68 |...>L01...{..#.h| + 00000010 a6 a2 26 00 65 00 27 9e fd c2 42 b7 01 b4 eb 9c |..&.e.'...B.....| + 00000020 00 84 4f fd 51 1f bc a5 65 12 c2 dc 45 0e 21 cb |..O.Q...e...E.!.| + 00000030 45 dc bb 0a ba 16 84 28 a7 33 e5 e2 fa a8 f1 f2 |E......(.3......| + 00000040 ec f4 |..| + [local <] + 00000000 00 00 00 37 4c 30 31 00 00 00 7b 00 00 00 17 68 |...7L01...{....h| + 00000010 a6 a2 27 00 66 00 20 b7 72 49 8a 64 eb 16 a5 71 |..'.f. .rI.d...q| + 00000020 73 eb 9e 7e 37 64 3e 75 c0 70 ea 39 4e de 82 1f |s..~7d>u.p.9N...| + 00000030 e2 29 86 de 4a 7b 38 20 55 12 8a |.)..J{8 U..| +# --- +# name: test_send_command + [local >] + 00000000 00 00 00 15 31 2e 30 00 00 00 01 00 00 23 82 68 |....1.0......#.h| + 00000010 a6 a2 24 00 00 e6 b9 24 63 |..$....$c| + [local <] + 00000000 00 00 00 27 31 2e 30 00 00 00 01 00 00 00 17 68 |...'1.0........h| + 00000010 a6 a2 23 00 01 00 10 cb 93 c7 39 b9 21 53 43 48 |..#.......9.!SCH| + 00000020 83 b3 c2 af 0f 51 2c da 9e ea 3b |.....Q,...;| + [local >] + 00000000 00 00 00 37 31 2e 30 00 00 00 7b 00 00 23 83 68 |...71.0...{..#.h| + 00000010 a6 a2 25 00 65 00 20 91 5b 1f 43 34 d5 22 47 9f |..%.e. .[.C4."G.| + 00000020 59 4e 45 53 85 f9 c6 6e f2 eb 27 eb 6d 03 d8 92 |YNES...n..'.m...| + 00000030 5b 30 83 b4 a4 ea f5 85 be 38 57 |[0.......8W| + [local <] + 00000000 00 00 00 37 31 2e 30 00 00 00 7b 00 00 00 17 68 |...71.0...{....h| + 00000010 a6 a2 26 00 66 00 20 07 8b 28 60 a8 08 18 12 47 |..&.f. ..(`....G| + 00000020 05 20 3e f5 53 e3 fd 4a cc 03 72 7b b4 2c d9 84 |. >.S..J..r{.,..| + 00000030 7f 4b 18 d8 76 7d 5c 65 87 7c 2d |.K..v}\e.|-| +# --- diff --git a/tests/e2e/__snapshots__/test_mqtt_session.ambr b/tests/e2e/__snapshots__/test_mqtt_session.ambr new file mode 100644 index 00000000..aa23b772 --- /dev/null +++ b/tests/e2e/__snapshots__/test_mqtt_session.ambr @@ -0,0 +1,32 @@ +# serializer version: 1 +# name: test_session_e2e_publish_message + [mqtt <] + 00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..| + [mqtt >] + 00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....| + 00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw| + 00000020 6f 72 64 |ord| + [mqtt >] + 00000000 30 41 00 07 74 6f 70 69 63 2d 31 00 31 2e 30 00 |0A..topic-1.1.0.| + 00000010 00 01 c8 00 00 23 82 68 a6 a2 23 00 65 00 20 91 |.....#.h..#.e. .| + 00000020 22 f1 91 1a 6e 89 71 ca ec 2d 44 2a 16 57 e7 5b |"...n.q..-D*.W.[| + 00000030 4a 9a c8 97 4b 13 37 3b f5 81 13 45 7c e7 48 03 |J...K.7;...E|.H.| + 00000040 99 71 bf |.q.| +# --- +# name: test_session_e2e_receive_message + [mqtt <] + 00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..| + [mqtt >] + 00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....| + 00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw| + 00000020 6f 72 64 |ord| + [mqtt <] + 00000000 90 04 00 01 00 00 |......| + [mqtt >] + 00000000 82 0d 00 01 00 00 07 74 6f 70 69 63 2d 31 00 |.......topic-1.| + [mqtt <] + 00000000 30 31 00 07 74 6f 70 69 63 2d 31 00 31 2e 30 00 |01..topic-1.1.0.| + 00000010 00 00 7b 00 00 23 82 68 a6 a2 23 00 66 00 10 45 |..{..#.h..#.f..E| + 00000020 3b c3 2b 12 a6 77 d9 55 f6 e0 89 f5 93 a5 30 5d |;.+..w.U......0]| + 00000030 a0 72 fa |.r.| +# --- diff --git a/tests/e2e/test_local_session.py b/tests/e2e/test_local_session.py index 347923da..6d8458b1 100644 --- a/tests/e2e/test_local_session.py +++ b/tests/e2e/test_local_session.py @@ -2,63 +2,59 @@ import asyncio from collections.abc import AsyncGenerator -from unittest.mock import patch import pytest +import syrupy from roborock.devices.local_channel import LocalChannel -from roborock.protocol import create_local_decoder, create_local_encoder +from roborock.protocol import MessageParser, create_local_decoder +from roborock.protocols.v1_protocol import LocalProtocolVersion from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from tests.fixtures.logging import CapturedRequestLog +from tests.fixtures.mqtt import Subscriber from tests.mock_data import LOCAL_KEY TEST_HOST = "192.168.1.100" TEST_DEVICE_UID = "test_device_uid" -TEST_CONNECT_NONCE = 12345 -TEST_ACK_NONCE = 67890 -TEST_RANDOM = 13579 +TEST_RANDOM = 23 @pytest.fixture(name="local_channel") async def local_channel_fixture(mock_async_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]: - with patch( - "roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE, device_uid=TEST_DEVICE_UID - ): - channel = LocalChannel(host=TEST_HOST, local_key=LOCAL_KEY, device_uid=TEST_DEVICE_UID) - yield channel - channel.close() + channel = LocalChannel(host=TEST_HOST, local_key=LOCAL_KEY, device_uid=TEST_DEVICE_UID) + yield channel + channel.close() -def build_response( +def build_raw_response( protocol: RoborockMessageProtocol, seq: int, payload: bytes, - random: int, + version: LocalProtocolVersion = LocalProtocolVersion.V1, + connect_nonce: int | None = None, + ack_nonce: int | None = None, ) -> bytes: """Build an encoded response message.""" - if protocol == RoborockMessageProtocol.HELLO_RESPONSE: - encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=None) - else: - encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE) - - msg = RoborockMessage( + message = RoborockMessage( protocol=protocol, - random=random, + random=23, seq=seq, payload=payload, + version=version.value.encode(), ) - return encoder(msg) + return MessageParser.build(message, local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=ack_nonce) async def test_connect( local_channel: LocalChannel, local_response_queue: asyncio.Queue[bytes], local_received_requests: asyncio.Queue[bytes], + log: CapturedRequestLog, + snapshot: syrupy.SnapshotAssertion, ) -> None: """Test connecting to the device.""" # Queue HELLO response with payload to ensure it can be parsed - local_response_queue.put_nowait( - build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM) - ) + local_response_queue.put_nowait(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok")) await local_channel.connect() @@ -76,17 +72,19 @@ async def test_connect( protocol_bytes = request_bytes[19:21] assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST + assert snapshot == log + async def test_send_command( local_channel: LocalChannel, local_response_queue: asyncio.Queue[bytes], local_received_requests: asyncio.Queue[bytes], + log: CapturedRequestLog, + snapshot: syrupy.SnapshotAssertion, ) -> None: """Test sending a command.""" # Queue HELLO response - local_response_queue.put_nowait( - build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM) - ) + local_response_queue.put_nowait(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok")) await local_channel.connect() @@ -101,16 +99,145 @@ async def test_send_command( seq=cmd_seq, payload=b'{"method":"get_status"}', ) + # Prepare a fake response to the command. + local_response_queue.put_nowait( + build_raw_response(RoborockMessageProtocol.RPC_RESPONSE, cmd_seq, payload=b'{"status": "ok"}') + ) + + subscriber = Subscriber() + unsub = await local_channel.subscribe(subscriber.append) await local_channel.publish(msg) - # Verify request + # Verify request received by the server request_bytes = await local_received_requests.get() assert local_received_requests.empty() # Decode request - decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE) + decoder = create_local_decoder(local_key=LOCAL_KEY) + msgs = list(decoder(request_bytes)) + assert len(msgs) == 1 + assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST + assert msgs[0].payload == b'{"method":"get_status"}' + assert msgs[0].version == LocalProtocolVersion.V1.value.encode() + + # Verify response received by subscriber + await subscriber.wait() + assert len(subscriber.messages) == 1 + response_message = subscriber.messages[0] + assert isinstance(response_message, RoborockMessage) + assert response_message.protocol == RoborockMessageProtocol.RPC_RESPONSE + assert response_message.payload == b'{"status": "ok"}' + + unsub() + + assert snapshot == log + + +async def test_l01_session( + local_channel: LocalChannel, + local_response_queue: asyncio.Queue[bytes], + local_received_requests: asyncio.Queue[bytes], + log: CapturedRequestLog, + snapshot: syrupy.SnapshotAssertion, +) -> None: + """Test connecting to a device that speaks the L01 protocol. + + Note that this test currently has a delay because the actual local client + will delay before retrying with L01 after a failed 1.0 attempt. This should + also be improved in the actual client itself, but likely requires a closer + look at the actual device response in that scenario or moving to a serial + request/response behavior rather than publish/subscribe. + """ + # Client first attempts 1.0 and we reply with a fake invalid response. The + # response is arbitrary, and this could be improved by capturing a real L01 + # device response to a 1.0 message. + local_response_queue.put_nowait(b"\x00") + # The client attempts L01 protocol as a followup. The connect nonce uses + # a deterministic number from deterministic_message_fixtures. + connect_nonce = 9090 + local_response_queue.put_nowait( + build_raw_response( + RoborockMessageProtocol.HELLO_RESPONSE, + 1, + payload=b"ok", + version=LocalProtocolVersion.L01, + connect_nonce=connect_nonce, + ack_nonce=None, + ) + ) + + await local_channel.connect() + + assert local_channel.is_connected + + # Verify 1.0 HELLO request + request_bytes = await local_received_requests.get() + # Protocol is at offset 19 (2 bytes) + # Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19 + assert len(request_bytes) >= 21 + protocol_bytes = request_bytes[19:21] + assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST + + # Verify L01 HELLO request + request_bytes = await local_received_requests.get() + # Protocol is at offset 19 (2 bytes) + # Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19 + assert len(request_bytes) >= 21 + protocol_bytes = request_bytes[19:21] + assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST + + assert local_received_requests.empty() + + # Verify the channel switched to L01 protocol + assert local_channel.protocol_version == LocalProtocolVersion.L01.value + + # We have established a connection. Now send some messages. + # Publish an L01 command. Currently the caller of the local channel needs to + # determine the protocol version to use, but this could be pushed inside of + # the channel in the future. + cmd_seq = 123 + msg = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_REQUEST, + seq=cmd_seq, + payload=b'{"method":"get_status"}', + version=b"L01", + ) + # Prepare a fake response to the command. + local_response_queue.put_nowait( + build_raw_response( + RoborockMessageProtocol.RPC_RESPONSE, + cmd_seq, + payload=b'{"status": "ok"}', + version=LocalProtocolVersion.L01, + connect_nonce=connect_nonce, + ack_nonce=TEST_RANDOM, + ) + ) + + # Set up a subscriber to listen for the response then publish the message. + subscriber = Subscriber() + unsub = await local_channel.subscribe(subscriber.append) + await local_channel.publish(msg) + + # Verify request received by the server + request_bytes = await local_received_requests.get() + decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=TEST_RANDOM) msgs = list(decoder(request_bytes)) assert len(msgs) == 1 assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST assert msgs[0].payload == b'{"method":"get_status"}' + assert msgs[0].version == LocalProtocolVersion.L01.value.encode() + + # Verify fake response published by the server, received by subscriber + await subscriber.wait() + assert len(subscriber.messages) == 1 + response_message = subscriber.messages[0] + assert isinstance(response_message, RoborockMessage) + assert response_message.protocol == RoborockMessageProtocol.RPC_RESPONSE + assert response_message.payload == b'{"status": "ok"}' + assert response_message.version == LocalProtocolVersion.L01.value.encode() + + unsub() + + assert snapshot == log diff --git a/tests/e2e/test_mqtt_session.py b/tests/e2e/test_mqtt_session.py index 294bf5e8..de33e0b0 100644 --- a/tests/e2e/test_mqtt_session.py +++ b/tests/e2e/test_mqtt_session.py @@ -12,12 +12,14 @@ from queue import Queue import pytest +import syrupy from roborock.mqtt.roborock_session import create_mqtt_session from roborock.mqtt.session import MqttSession from roborock.protocol import MessageParser from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from tests import mqtt_packet +from tests.fixtures.logging import CapturedRequestLog from tests.fixtures.mqtt import FAKE_PARAMS, Subscriber from tests.mock_data import LOCAL_KEY @@ -54,7 +56,12 @@ async def session_fixture( await session.close() -async def test_session_e2e_receive_message(push_mqtt_response: Callable[[bytes], None], session: MqttSession) -> None: +async def test_session_e2e_receive_message( + push_mqtt_response: Callable[[bytes], None], + session: MqttSession, + log: CapturedRequestLog, + snapshot: syrupy.SnapshotAssertion, +) -> None: """Test receiving a real Roborock message through the session.""" assert session.connected @@ -89,11 +96,15 @@ async def test_session_e2e_receive_message(push_mqtt_response: Callable[[bytes], # The payload in parsed_msg should be the decrypted bytes assert parsed_msg.payload == b'{"result":"ok"}' + assert snapshot == log + async def test_session_e2e_publish_message( push_mqtt_response: Callable[[bytes], None], mqtt_received_requests: Queue, session: MqttSession, + log: CapturedRequestLog, + snapshot: syrupy.SnapshotAssertion, ) -> None: """Test publishing a real Roborock message.""" @@ -117,3 +128,5 @@ async def test_session_e2e_publish_message( break assert found, "Published payload not found in sent requests" + + assert snapshot == log diff --git a/tests/fixtures/local_async_fixtures.py b/tests/fixtures/local_async_fixtures.py index 394dbca5..d804df82 100644 --- a/tests/fixtures/local_async_fixtures.py +++ b/tests/fixtures/local_async_fixtures.py @@ -7,6 +7,8 @@ import pytest +from .logging import CapturedRequestLog + _LOGGER = logging.getLogger(__name__) AsyncLocalRequestHandler = Callable[[bytes], Awaitable[bytes | None]] @@ -48,6 +50,7 @@ async def handle_request(client_request: bytes) -> bytes | None: @pytest.fixture(name="mock_async_create_local_connection") def create_local_connection_fixture( local_async_request_handler: AsyncLocalRequestHandler, + log: CapturedRequestLog, ) -> Generator[None, None, None]: """Fixture that overrides the transport creation to wire it up to the mock socket.""" @@ -57,9 +60,11 @@ async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *a protocol = protocol_factory() async def handle_write(data: bytes) -> None: + log.add_log_entry("[local >]", data) response = await local_async_request_handler(data) if response is not None: # Call data_received directly to avoid loop scheduling issues in test + log.add_log_entry("[local <]", response) protocol.data_received(response) def start_handle_write(data: bytes) -> None: diff --git a/tests/fixtures/mqtt.py b/tests/fixtures/mqtt.py index 793bb3f4..08489de3 100644 --- a/tests/fixtures/mqtt.py +++ b/tests/fixtures/mqtt.py @@ -69,6 +69,7 @@ def push_response(self) -> None: # Enqueue a response to be sent back to the client in the buffer. # The buffer will be emptied when the client calls recv() on the socket _LOGGER.debug("Queued: 0x%s", response.hex()) + self.log.add_log_entry("[mqtt <]", response) self.response_buf.write(response)