Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
532 changes: 0 additions & 532 deletions tests/conftest.py

This file was deleted.

3 changes: 1 addition & 2 deletions tests/devices/test_a01_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
RoborockMessage,
RoborockMessageProtocol,
)

from ..conftest import FakeChannel
from tests.fixtures.channel_fixtures import FakeChannel


@pytest.fixture
Expand Down
5 changes: 2 additions & 3 deletions tests/devices/test_v1_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
from roborock.protocols.v1_protocol import MapResponse, SecurityData, V1RpcChannel
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from roborock.roborock_typing import RoborockCommand

from .. import mock_data
from ..conftest import FakeChannel
from tests import mock_data
from tests.fixtures.channel_fixtures import FakeChannel

USER_DATA = UserData.from_dict(mock_data.USER_DATA)
TEST_DEVICE_UID = "abc123"
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/traits/a01/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from roborock.devices.traits.a01 import DyadApi, ZeoApi
from roborock.roborock_message import RoborockDyadDataProtocol, RoborockMessageProtocol, RoborockZeoProtocol
from tests.conftest import FakeChannel
from tests.fixtures.channel_fixtures import FakeChannel
from tests.protocols.common import build_a01_message


Expand Down
2 changes: 1 addition & 1 deletion tests/devices/traits/b01/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from roborock.exceptions import RoborockException
from roborock.protocols.b01_protocol import B01_VERSION
from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol
from tests.conftest import FakeChannel
from tests.fixtures.channel_fixtures import FakeChannel


def build_b01_message(message: dict[Any, Any], msg_id: str = "123456789", seq: int = 2020) -> RoborockMessage:
Expand Down
7 changes: 7 additions & 0 deletions tests/e2e/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
"""End-to-end tests package."""

pytest_plugins = [
"tests.fixtures.logging_fixtures",
"tests.fixtures.local_async_fixtures",
"tests.fixtures.pahomqtt_fixtures",
"tests.fixtures.aiomqtt_fixtures",
]
64 changes: 21 additions & 43 deletions tests/e2e/test_local_session.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
"""End-to-end tests for LocalChannel using fake sockets."""

import asyncio
from collections.abc import AsyncGenerator, Callable, Generator
from queue import Queue
from typing import Any
from unittest.mock import Mock, patch
from collections.abc import AsyncGenerator
from unittest.mock import patch

import pytest

from roborock.devices.local_channel import LocalChannel
from roborock.protocol import create_local_decoder, create_local_encoder
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from tests.conftest import RequestHandler
from tests.mock_data import LOCAL_KEY

TEST_HOST = "192.168.1.100"
Expand All @@ -21,35 +18,8 @@
TEST_RANDOM = 13579


@pytest.fixture(name="mock_create_local_connection")
def create_local_connection_fixture(request_handler: RequestHandler) -> Generator[None, None, None]:
"""Fixture that overrides the transport creation to wire it up to the mock socket."""

async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]:
protocol = protocol_factory()

def handle_write(data: bytes) -> None:
response = request_handler(data)
if response is not None:
# Call data_received directly to avoid loop scheduling issues in test
protocol.data_received(response)

closed = asyncio.Event()

mock_transport = Mock()
mock_transport.write = handle_write
mock_transport.close = closed.set
mock_transport.is_reading = lambda: not closed.is_set()

return (mock_transport, protocol)

with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop:
mock_loop.return_value.create_connection.side_effect = create_connection
yield


@pytest.fixture(name="local_channel")
async def local_channel_fixture(mock_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
async def local_channel_fixture(mock_async_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
with patch(
"roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE, device_uid=TEST_DEVICE_UID
):
Expand Down Expand Up @@ -80,19 +50,23 @@ def build_response(


async def test_connect(
local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes]
local_channel: LocalChannel,
local_response_queue: asyncio.Queue[bytes],
local_received_requests: asyncio.Queue[bytes],
) -> None:
"""Test connecting to the device."""
# Queue HELLO response with payload to ensure it can be parsed
response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM))
local_response_queue.put_nowait(
build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)
)

await local_channel.connect()

assert local_channel.is_connected
assert received_requests.qsize() == 1
assert local_received_requests.qsize() == 1

# Verify HELLO request
request_bytes = received_requests.get()
request_bytes = await local_received_requests.get()
# Note: We cannot use create_local_decoder here because HELLO_REQUEST has payload=None
# which causes MessageParser to fail parsing. For now we verify the raw bytes.

Expand All @@ -104,17 +78,21 @@ async def test_connect(


async def test_send_command(
local_channel: LocalChannel, response_queue: Queue[bytes], received_requests: Queue[bytes]
local_channel: LocalChannel,
local_response_queue: asyncio.Queue[bytes],
local_received_requests: asyncio.Queue[bytes],
) -> None:
"""Test sending a command."""
# Queue HELLO response
response_queue.put(build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM))
local_response_queue.put_nowait(
build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)
)

await local_channel.connect()

# Clear requests from handshake
while not received_requests.empty():
received_requests.get()
while not local_received_requests.empty():
await local_received_requests.get()

# Send command
cmd_seq = 123
Expand All @@ -127,8 +105,8 @@ async def test_send_command(
await local_channel.publish(msg)

# Verify request
assert received_requests.qsize() == 1
request_bytes = received_requests.get()
request_bytes = await local_received_requests.get()
assert local_received_requests.empty()

# Decode request
decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)
Expand Down
25 changes: 13 additions & 12 deletions tests/e2e/test_mqtt_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from roborock.protocol import MessageParser
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from tests import mqtt_packet
from tests.fixtures.mqtt import FAKE_PARAMS, Subscriber
from tests.mock_data import LOCAL_KEY
from tests.mqtt_fixtures import FAKE_PARAMS, Subscriber


@pytest.fixture(autouse=True)
def auto_mock_mqtt_client(mock_mqtt_client_fixture: None) -> None:
def auto_mock_mqtt_client(mock_aiomqtt_client: None) -> None:
"""Automatically use the mock mqtt client fixture."""


Expand All @@ -33,7 +33,7 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None:


@pytest.fixture(autouse=True)
def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None:
def mqtt_server_fixture(mock_paho_mqtt_create_connection: None, mock_paho_mqtt_select: None) -> None:
"""Fixture to mock the MQTT connection.

This is here to pull in the mock socket pixtures into all tests used here.
Expand All @@ -42,10 +42,10 @@ def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None

@pytest.fixture(name="session")
async def session_fixture(
push_response: Callable[[bytes], None],
push_mqtt_response: Callable[[bytes], None],
) -> AsyncGenerator[MqttSession, None]:
"""Fixture to create a new connected MQTT session."""
push_response(mqtt_packet.gen_connack(rc=0, flags=2))
push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
session = await create_mqtt_session(FAKE_PARAMS)
assert session.connected
try:
Expand All @@ -54,12 +54,12 @@ async def session_fixture(
await session.close()


async def test_session_e2e_receive_message(push_response: Callable[[bytes], None], session: MqttSession) -> None:
async def test_session_e2e_receive_message(push_mqtt_response: Callable[[bytes], None], session: MqttSession) -> None:
"""Test receiving a real Roborock message through the session."""
assert session.connected

# Subscribe to the topic. We'll next construct and push a message.
push_response(mqtt_packet.gen_suback(mid=1))
push_mqtt_response(mqtt_packet.gen_suback(mid=1))
subscriber = Subscriber()
await session.subscribe("topic-1", subscriber.append)

Expand All @@ -71,12 +71,13 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None
payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False)

# Simulate receiving the message from the broker
push_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload))
push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload))

# Verify it was dispatched to the subscriber
await subscriber.wait()
assert len(subscriber.messages) == 1
received_payload = subscriber.messages[0]
assert isinstance(received_payload, bytes)
assert received_payload == payload

# Verify the message payload contents
Expand All @@ -90,8 +91,8 @@ async def test_session_e2e_receive_message(push_response: Callable[[bytes], None


async def test_session_e2e_publish_message(
push_response: Callable[[bytes], None],
received_requests: Queue,
push_mqtt_response: Callable[[bytes], None],
mqtt_received_requests: Queue,
session: MqttSession,
) -> None:
"""Test publishing a real Roborock message."""
Expand All @@ -109,8 +110,8 @@ async def test_session_e2e_publish_message(
# Verify what was sent to the broker
# We expect the payload to be present in the sent bytes
found = False
while not received_requests.empty():
request = received_requests.get()
while not mqtt_received_requests.empty():
request = mqtt_received_requests.get()
if payload in request:
found = True
break
Expand Down
Empty file added tests/fixtures/__init__.py
Empty file.
43 changes: 8 additions & 35 deletions tests/mqtt_fixtures.py → tests/fixtures/aiomqtt_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,11 @@
import paho.mqtt.client as mqtt
import pytest

from roborock.mqtt.session import MqttParams
from tests.conftest import FakeSocketHandler
from .mqtt import FakeMqttSocketHandler

FAKE_PARAMS = MqttParams(
host="localhost",
port=1883,
tls=False,
username="username",
password="password",
timeout=10.0,
)


class Subscriber:
"""Mock subscriber class.

We use this to hold on to received messages for verification.
"""

def __init__(self) -> None:
self.messages: list[bytes] = []
self._event = asyncio.Event()

def append(self, message: bytes) -> None:
self.messages.append(message)
self._event.set()

async def wait(self) -> None:
await asyncio.wait_for(self._event.wait(), timeout=1.0)
self._event.clear()


@pytest.fixture
async def mock_mqtt_client_fixture() -> AsyncGenerator[None, None]:
@pytest.fixture(name="mock_aiomqtt_client")
async def mock_aiomqtt_client_fixture() -> AsyncGenerator[None, None]:
"""Fixture to patch the MQTT underlying sync client.

The tests use fake sockets, so this ensures that the async mqtt client does not
Expand Down Expand Up @@ -94,11 +65,13 @@ def fast_backoff_fixture() -> Generator[None, None, None]:


@pytest.fixture
def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]:
def push_mqtt_response(
mqtt_response_queue: Queue, fake_mqtt_socket_handler: FakeMqttSocketHandler
) -> Callable[[bytes], None]:
"""Fixture to push a response to the client."""

def _push(data: bytes) -> None:
response_queue.put(data)
fake_socket_handler.push_response()
mqtt_response_queue.put(data)
fake_mqtt_socket_handler.push_response()

return _push
53 changes: 53 additions & 0 deletions tests/fixtures/channel_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from collections.abc import Callable
from unittest.mock import AsyncMock, MagicMock

from roborock.mqtt.health_manager import HealthManager
from roborock.protocols.v1_protocol import LocalProtocolVersion
from roborock.roborock_message import RoborockMessage


class FakeChannel:
"""A fake channel that handles publish and subscribe calls."""

def __init__(self):
"""Initialize the fake channel."""
self.subscribers: list[Callable[[RoborockMessage], None]] = []
self.published_messages: list[RoborockMessage] = []
self.response_queue: list[RoborockMessage] = []
self._is_connected = False
self.publish_side_effect: Exception | None = None
self.publish = AsyncMock(side_effect=self._publish)
self.subscribe = AsyncMock(side_effect=self._subscribe)
self.connect = AsyncMock(side_effect=self._connect)
self.close = MagicMock(side_effect=self._close)
self.protocol_version = LocalProtocolVersion.V1
self.restart = AsyncMock()
self.health_manager = HealthManager(self.restart)

async def _connect(self) -> None:
self._is_connected = True

def _close(self) -> None:
self._is_connected = False

@property
def is_connected(self) -> bool:
"""Return true if connected."""
return self._is_connected

async def _publish(self, message: RoborockMessage) -> None:
"""Simulate publishing a message and triggering a response."""
self.published_messages.append(message)
if self.publish_side_effect:
raise self.publish_side_effect
# When a message is published, simulate a response
if self.response_queue:
response = self.response_queue.pop(0)
# Give a chance for the subscriber to be registered
for subscriber in list(self.subscribers):
subscriber(response)

async def _subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
"""Simulate subscribing to messages."""
self.subscribers.append(callback)
return lambda: self.subscribers.remove(callback)
Loading