Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from roborock.exceptions import RoborockException
from roborock.map.map_parser import MapParserConfig
from roborock.mqtt.roborock_session import create_lazy_mqtt_session
from roborock.mqtt.session import MqttSession
from roborock.mqtt.session import MqttSession, SessionUnauthorizedHook
from roborock.protocol import create_mqtt_params
from roborock.web_api import RoborockApiClient, UserWebApiClient

Expand Down Expand Up @@ -173,6 +173,7 @@ async def create_device_manager(
map_parser_config: MapParserConfig | None = None,
session: aiohttp.ClientSession | None = None,
ready_callback: DeviceReadyCallback | None = None,
mqtt_session_unauthorized_hook: SessionUnauthorizedHook | None = None,
) -> DeviceManager:
"""Convenience function to create and initialize a DeviceManager.

Expand All @@ -182,6 +183,9 @@ async def create_device_manager(
map_parser_config: Optional configuration for parsing maps.
session: Optional aiohttp ClientSession to use for HTTP requests.
ready_callback: Optional callback to be notified when a device is ready.
mqtt_session_unauthorized_hook: Optional hook for MQTT session unauthorized
events which may indicate rate limiting or revoked credentials. The
caller may use this to refresh authentication tokens as needed.

Returns:
An initialized DeviceManager with discovered devices.
Expand All @@ -196,6 +200,7 @@ async def create_device_manager(

mqtt_params = create_mqtt_params(user_data.rriot)
mqtt_params.diagnostics = diagnostics.subkey("mqtt_session")
mqtt_params.unauthorized_hook = mqtt_session_unauthorized_hook
mqtt_session = await create_lazy_mqtt_session(mqtt_params)

def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
Expand Down
17 changes: 16 additions & 1 deletion roborock/mqtt/roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# Exponential backoff parameters
MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10)
MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30)
MAX_BACKOFF_INTERVAL = datetime.timedelta(hours=6)
BACKOFF_MULTIPLIER = 1.5


Expand Down Expand Up @@ -79,6 +79,7 @@ def __init__(
self._idle_timers: dict[str, asyncio.Task[None]] = {}
self._diagnostics = params.diagnostics
self._health_manager = HealthManager(self.restart)
self._unauthorized_hook = params.unauthorized_hook

@property
def connected(self) -> bool:
Expand Down Expand Up @@ -199,14 +200,28 @@ async def _run_connection(self, start_future: asyncio.Future[None] | None) -> No
_LOGGER.debug("Received message: %s", message)
with self._diagnostics.timer("dispatch_message"):
self._listeners(message.topic.value, message.payload)
except MqttCodeError as err:
self._diagnostics.increment(f"connect_failure:{err.rc}")
if start_future and not start_future.done():
_LOGGER.debug("MQTT error starting session: %s", err)
start_future.set_exception(err)
else:
_LOGGER.debug("MQTT error: %s", err)
if err.rc == MqttReasonCode.RC_ERROR_UNAUTHORIZED and self._unauthorized_hook:
_LOGGER.info("MQTT unauthorized/rate-limit error received, setting backoff to maximum")
self._unauthorized_hook()
self._backoff = MAX_BACKOFF_INTERVAL
raise
except MqttError as err:
self._diagnostics.increment("connect_failure:unknown")
if start_future and not start_future.done():
_LOGGER.info("MQTT error starting session: %s", err)
start_future.set_exception(err)
else:
_LOGGER.info("MQTT error: %s", err)
raise
except Exception as err:
self._diagnostics.increment("connect_failure:uncaught")
# This error is thrown when the MQTT loop is cancelled
# and the generator is not stopped.
if "generator didn't stop" in str(err) or "generator didn't yield" in str(err):
Expand Down
10 changes: 10 additions & 0 deletions roborock/mqtt/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

DEFAULT_TIMEOUT = 30.0

SessionUnauthorizedHook = Callable[[], None]


@dataclass
class MqttParams:
Expand Down Expand Up @@ -41,6 +43,14 @@ class MqttParams:
shared MQTT session diagnostics are included in the overall diagnostics.
"""

unauthorized_hook: SessionUnauthorizedHook | None = None
"""Optional hook invoked when an unauthorized error is received.

This may be invoked by the background reconnect logic when an
unauthorized error is received from the broker. The caller may use
this hook to refresh credentials or take other actions as needed.
"""


class MqttSession(ABC):
"""An MQTT session for sending and receiving messages."""
Expand Down
130 changes: 111 additions & 19 deletions tests/mqtt/test_roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import datetime
from collections.abc import Callable, Generator
from typing import Any
from unittest.mock import AsyncMock, Mock, patch

import aiomqtt
Expand Down Expand Up @@ -42,20 +43,65 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None:
"""Automatically use the fast backoff fixture."""


@pytest.fixture(name="mqtt_client_lite")
def mqtt_client_lite_fixture() -> Generator[AsyncMock, None, None]:
class FakeAsyncIterator:
"""Fake async iterator that waits for messages to arrive, but they never do.

This is used for testing exceptions in other client functions.
"""

def __init__(self) -> None:
self.loop = True

def __aiter__(self):
return self

async def __anext__(self) -> None:
"""Iterator that does not generate any messages."""
while self.loop:
await asyncio.sleep(0.01)


@pytest.fixture(name="message_iterator")
def message_iterator_fixture() -> FakeAsyncIterator:
"""Fixture to provide a side effect for creating the MQTT client."""
return FakeAsyncIterator()


@pytest.fixture(name="mock_client")
def mock_client_fixture(message_iterator: FakeAsyncIterator) -> Generator[AsyncMock, None, None]:
"""A fixture that provides a mocked aiomqtt Client.

This is lighter weight that `mock_aiomqtt_client` that uses real sockets.
"""
mock_client = AsyncMock()
mock_client.messages = FakeAsyncIterator()
mock_client.messages = message_iterator
return mock_client


@pytest.fixture(name="create_client_side_effect")
def create_client_side_effect_fixture() -> Exception | None:
"""Fixture to provide a side effect for creating the MQTT client."""
return None


@pytest.fixture(name="mock_aenter_client")
def mock_aenter_client_fixture(mock_client: AsyncMock, create_client_side_effect: Exception | None) -> AsyncMock:
"""Fixture to provide a side effect for creating the MQTT client."""
mock_aenter = AsyncMock()
mock_aenter.return_value = mock_client
mock_aenter.side_effect = create_client_side_effect
return mock_aenter


@pytest.fixture(name="mqtt_client_lite")
def mqtt_client_lite_fixture(
mock_client: AsyncMock,
mock_aenter_client: AsyncMock,
) -> Generator[AsyncMock, None, None]:
"""Fixture to create a mock MQTT client with patched aiomqtt.Client."""

mock_shim = Mock()
mock_shim.return_value.__aenter__ = mock_aenter
mock_shim.return_value.__aenter__ = mock_aenter_client
mock_shim.return_value.__aexit__ = AsyncMock()

with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim):
Expand Down Expand Up @@ -128,21 +174,6 @@ async def test_publish_command(push_mqtt_response: Callable[[bytes], None]) -> N
assert not session.connected


class FakeAsyncIterator:
"""Fake async iterator that waits for messages to arrive, but they never do.

This is used for testing exceptions in other client functions.
"""

def __aiter__(self):
return self

async def __anext__(self) -> None:
"""Iterator that does not generate any messages."""
while True:
await asyncio.sleep(1)


async def test_publish_failure(mqtt_client_lite: AsyncMock) -> None:
"""Test an MQTT error is received when publishing a message."""

Expand Down Expand Up @@ -446,3 +477,64 @@ async def test_diagnostics_data(push_mqtt_response: Callable[[bytes], None]) ->
assert data.get("subscribe_count") == 2
assert data.get("dispatch_message_count") == 3
assert data.get("close") == 1


@pytest.mark.parametrize(
("create_client_side_effect"),
[
# Unauthorized
aiomqtt.MqttCodeError(rc=135),
],
)
async def test_session_unauthorized_hook(mqtt_client_lite: AsyncMock) -> None:
"""Test the MQTT session."""

unauthorized = asyncio.Event()

params = copy.deepcopy(FAKE_PARAMS)
params.unauthorized_hook = unauthorized.set

with pytest.raises(MqttSessionUnauthorized):
await create_mqtt_session(params)

assert unauthorized.is_set()


async def test_session_unauthorized_after_start(
mock_aenter_client: AsyncMock,
message_iterator: FakeAsyncIterator,
mqtt_client_lite: AsyncMock,
push_mqtt_response: Callable[[bytes], None],
) -> None:
"""Test the MQTT session."""

# Configure a hook that is notified of unauthorized errors
unauthorized = asyncio.Event()
params = copy.deepcopy(FAKE_PARAMS)
params.unauthorized_hook = unauthorized.set

# The client will succeed on first connection attempt, then fail with
# unauthorized messages on all future attempts.
request_count = 0

def succeed_then_fail_unauthorized() -> Any:
nonlocal request_count
request_count += 1
if request_count == 1:
return mqtt_client_lite
raise aiomqtt.MqttCodeError(rc=135)

mock_aenter_client.side_effect = succeed_then_fail_unauthorized
# Don't produce messages, just exit and restart to reconnect
message_iterator.loop = False

push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))

session = await create_mqtt_session(params)
assert session.connected

try:
async with asyncio.timeout(10):
assert await unauthorized.wait()
finally:
await session.close()