From efe70172c2d9c1a108400912092e849471453f82 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 20 Dec 2025 17:12:14 -0800 Subject: [PATCH 1/2] fix: Add a hook for handling background rate limit errors When a rate limit error is received the background loop will jump to the maximum backoff (now 6 hours) and will also invoke a callback so that the caller can decide to re-authenticate or stop harder. The new backoff follows this trajectory, where the change in behavior is introduced after 15 attempts: - attempt 1: wait 10 seconds - attempt 5: waits 50 seconds - attempt 7: waits 2 minutes - attempt 10: waits 6 minutes - attempt 15: waits 32 minutes - attempt 17: waits 2 hours - attempt 20: waits 6 hours --- roborock/devices/device_manager.py | 4 +- roborock/mqtt/roborock_session.py | 19 +++- roborock/mqtt/session.py | 10 +++ tests/mqtt/test_roborock_session.py | 130 ++++++++++++++++++++++++---- 4 files changed, 141 insertions(+), 22 deletions(-) diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 1fb5ec40..c8129802 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -20,7 +20,7 @@ from roborock.exceptions import RoborockException from roborock.map.map_parser import MapParserConfig from roborock.mqtt.roborock_session import create_lazy_mqtt_session -from roborock.mqtt.session import MqttSession +from roborock.mqtt.session import MqttSession, SessionUnauthorizedHook from roborock.protocol import create_mqtt_params from roborock.web_api import RoborockApiClient, UserWebApiClient @@ -173,6 +173,7 @@ async def create_device_manager( map_parser_config: MapParserConfig | None = None, session: aiohttp.ClientSession | None = None, ready_callback: DeviceReadyCallback | None = None, + mqtt_session_unauthorized_hook: SessionUnauthorizedHook | None = None, ) -> DeviceManager: """Convenience function to create and initialize a DeviceManager. @@ -196,6 +197,7 @@ async def create_device_manager( mqtt_params = create_mqtt_params(user_data.rriot) mqtt_params.diagnostics = diagnostics.subkey("mqtt_session") + mqtt_params.unauthorized_hook = mqtt_session_unauthorized_hook mqtt_session = await create_lazy_mqtt_session(mqtt_params) def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice: diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py index e3a9af03..d8f88373 100644 --- a/roborock/mqtt/roborock_session.py +++ b/roborock/mqtt/roborock_session.py @@ -31,7 +31,7 @@ # Exponential backoff parameters MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10) -MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30) +MAX_BACKOFF_INTERVAL = datetime.timedelta(hours=6) BACKOFF_MULTIPLIER = 1.5 @@ -79,6 +79,7 @@ def __init__( self._idle_timers: dict[str, asyncio.Task[None]] = {} self._diagnostics = params.diagnostics self._health_manager = HealthManager(self.restart) + self._unauthorized_hook = params.unauthorized_hook @property def connected(self) -> bool: @@ -199,14 +200,28 @@ async def _run_connection(self, start_future: asyncio.Future[None] | None) -> No _LOGGER.debug("Received message: %s", message) with self._diagnostics.timer("dispatch_message"): self._listeners(message.topic.value, message.payload) + except MqttCodeError as err: + self._diagnostics.increment(f"connect_failure:{err.rc}") + if start_future and not start_future.done(): + _LOGGER.debug("MQTT error starting session: %s", err) + start_future.set_exception(err) + else: + _LOGGER.debug("MQTT error: %s", err) + if err.rc == MqttReasonCode.RC_ERROR_UNAUTHORIZED and self._unauthorized_hook: + _LOGGER.info("MQTT unauthorized/rate-limit error received, setting backoff to maximum") + self._unauthorized_hook() + self._backoff = MAX_BACKOFF_INTERVAL + raise except MqttError as err: + self._diagnostics.increment("connect_failure:unknown") if start_future and not start_future.done(): _LOGGER.info("MQTT error starting session: %s", err) start_future.set_exception(err) else: _LOGGER.info("MQTT error: %s", err) raise - except Exception as err: + except (RuntimeError, Exception) as err: + self._diagnostics.increment("connect_failure:uncaught") # This error is thrown when the MQTT loop is cancelled # and the generator is not stopped. if "generator didn't stop" in str(err) or "generator didn't yield" in str(err): diff --git a/roborock/mqtt/session.py b/roborock/mqtt/session.py index 20d54a42..ac040cde 100644 --- a/roborock/mqtt/session.py +++ b/roborock/mqtt/session.py @@ -10,6 +10,8 @@ DEFAULT_TIMEOUT = 30.0 +SessionUnauthorizedHook = Callable[[], None] + @dataclass class MqttParams: @@ -41,6 +43,14 @@ class MqttParams: shared MQTT session diagnostics are included in the overall diagnostics. """ + unauthorized_hook: SessionUnauthorizedHook | None = None + """Optional hook invoked when an unauthorized error is received. + + This may be invoked by the background reconnect logic when an + unauthorized error is received from the broker. The caller may use + this hook to refresh credentials or take other actions as needed. + """ + class MqttSession(ABC): """An MQTT session for sending and receiving messages.""" diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py index 24e6fbbd..15526b66 100644 --- a/tests/mqtt/test_roborock_session.py +++ b/tests/mqtt/test_roborock_session.py @@ -4,6 +4,7 @@ import copy import datetime from collections.abc import Callable, Generator +from typing import Any from unittest.mock import AsyncMock, Mock, patch import aiomqtt @@ -42,20 +43,65 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None: """Automatically use the fast backoff fixture.""" -@pytest.fixture(name="mqtt_client_lite") -def mqtt_client_lite_fixture() -> Generator[AsyncMock, None, None]: +class FakeAsyncIterator: + """Fake async iterator that waits for messages to arrive, but they never do. + + This is used for testing exceptions in other client functions. + """ + + def __init__(self) -> None: + self.loop = True + + def __aiter__(self): + return self + + async def __anext__(self) -> None: + """Iterator that does not generate any messages.""" + while self.loop: + await asyncio.sleep(0.01) + + +@pytest.fixture(name="message_iterator") +def message_iterator_fixture() -> FakeAsyncIterator: + """Fixture to provide a side effect for creating the MQTT client.""" + return FakeAsyncIterator() + + +@pytest.fixture(name="mock_client") +def mock_client_fixture(message_iterator: FakeAsyncIterator) -> Generator[AsyncMock, None, None]: """A fixture that provides a mocked aiomqtt Client. This is lighter weight that `mock_aiomqtt_client` that uses real sockets. """ mock_client = AsyncMock() - mock_client.messages = FakeAsyncIterator() + mock_client.messages = message_iterator + return mock_client + + +@pytest.fixture(name="create_client_side_effect") +def create_client_side_effect_fixture() -> Exception | None: + """Fixture to provide a side effect for creating the MQTT client.""" + return None + +@pytest.fixture(name="mock_aenter_client") +def mock_aenter_client_fixture(mock_client: AsyncMock, create_client_side_effect: Exception | None) -> AsyncMock: + """Fixture to provide a side effect for creating the MQTT client.""" mock_aenter = AsyncMock() mock_aenter.return_value = mock_client + mock_aenter.side_effect = create_client_side_effect + return mock_aenter + + +@pytest.fixture(name="mqtt_client_lite") +def mqtt_client_lite_fixture( + mock_client: AsyncMock, + mock_aenter_client: AsyncMock, +) -> Generator[AsyncMock, None, None]: + """Fixture to create a mock MQTT client with patched aiomqtt.Client.""" mock_shim = Mock() - mock_shim.return_value.__aenter__ = mock_aenter + mock_shim.return_value.__aenter__ = mock_aenter_client mock_shim.return_value.__aexit__ = AsyncMock() with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim): @@ -128,21 +174,6 @@ async def test_publish_command(push_mqtt_response: Callable[[bytes], None]) -> N assert not session.connected -class FakeAsyncIterator: - """Fake async iterator that waits for messages to arrive, but they never do. - - This is used for testing exceptions in other client functions. - """ - - def __aiter__(self): - return self - - async def __anext__(self) -> None: - """Iterator that does not generate any messages.""" - while True: - await asyncio.sleep(1) - - async def test_publish_failure(mqtt_client_lite: AsyncMock) -> None: """Test an MQTT error is received when publishing a message.""" @@ -446,3 +477,64 @@ async def test_diagnostics_data(push_mqtt_response: Callable[[bytes], None]) -> assert data.get("subscribe_count") == 2 assert data.get("dispatch_message_count") == 3 assert data.get("close") == 1 + + +@pytest.mark.parametrize( + ("create_client_side_effect"), + [ + # Unauthorized + aiomqtt.MqttCodeError(rc=135), + ], +) +async def test_session_unauthorized_hook(mqtt_client_lite: AsyncMock) -> None: + """Test the MQTT session.""" + + unauthorized = asyncio.Event() + + params = copy.deepcopy(FAKE_PARAMS) + params.unauthorized_hook = unauthorized.set + + with pytest.raises(MqttSessionUnauthorized): + await create_mqtt_session(params) + + assert unauthorized.is_set() + + +async def test_session_unauthorized_after_start( + mock_aenter_client: AsyncMock, + message_iterator: FakeAsyncIterator, + mqtt_client_lite: AsyncMock, + push_mqtt_response: Callable[[bytes], None], +) -> None: + """Test the MQTT session.""" + + # Configure a hook that is notified of unauthorized errors + unauthorized = asyncio.Event() + params = copy.deepcopy(FAKE_PARAMS) + params.unauthorized_hook = unauthorized.set + + # The client will succeed on first connection attempt, then fail with + # unauthorized messages on all future attempts. + request_count = 0 + + def succeed_then_fail_unauthorized() -> Any: + nonlocal request_count + request_count += 1 + if request_count == 1: + return mqtt_client_lite + raise aiomqtt.MqttCodeError(rc=135) + + mock_aenter_client.side_effect = succeed_then_fail_unauthorized + # Don't produce messages, just exit and restart to reconnect + message_iterator.loop = False + + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) + + session = await create_mqtt_session(params) + assert session.connected + + try: + async with asyncio.timeout(10): + assert await unauthorized.wait() + finally: + await session.close() From c2e70868ef94bc56ac9b404d2897a346133303f6 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 20 Dec 2025 17:39:26 -0800 Subject: [PATCH 2/2] chore: fix lint errors --- roborock/devices/device_manager.py | 3 +++ roborock/mqtt/roborock_session.py | 2 +- roborock/mqtt/session.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index c8129802..1717e94c 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -183,6 +183,9 @@ async def create_device_manager( map_parser_config: Optional configuration for parsing maps. session: Optional aiohttp ClientSession to use for HTTP requests. ready_callback: Optional callback to be notified when a device is ready. + mqtt_session_unauthorized_hook: Optional hook for MQTT session unauthorized + events which may indicate rate limiting or revoked credentials. The + caller may use this to refresh authentication tokens as needed. Returns: An initialized DeviceManager with discovered devices. diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py index d8f88373..70d3ea1b 100644 --- a/roborock/mqtt/roborock_session.py +++ b/roborock/mqtt/roborock_session.py @@ -220,7 +220,7 @@ async def _run_connection(self, start_future: asyncio.Future[None] | None) -> No else: _LOGGER.info("MQTT error: %s", err) raise - except (RuntimeError, Exception) as err: + except Exception as err: self._diagnostics.increment("connect_failure:uncaught") # This error is thrown when the MQTT loop is cancelled # and the generator is not stopped. diff --git a/roborock/mqtt/session.py b/roborock/mqtt/session.py index ac040cde..9d8b57ad 100644 --- a/roborock/mqtt/session.py +++ b/roborock/mqtt/session.py @@ -45,7 +45,7 @@ class MqttParams: unauthorized_hook: SessionUnauthorizedHook | None = None """Optional hook invoked when an unauthorized error is received. - + This may be invoked by the background reconnect logic when an unauthorized error is received from the broker. The caller may use this hook to refresh credentials or take other actions as needed.