diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 1fb5ec40..1717e94c 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -20,7 +20,7 @@ from roborock.exceptions import RoborockException from roborock.map.map_parser import MapParserConfig from roborock.mqtt.roborock_session import create_lazy_mqtt_session -from roborock.mqtt.session import MqttSession +from roborock.mqtt.session import MqttSession, SessionUnauthorizedHook from roborock.protocol import create_mqtt_params from roborock.web_api import RoborockApiClient, UserWebApiClient @@ -173,6 +173,7 @@ async def create_device_manager( map_parser_config: MapParserConfig | None = None, session: aiohttp.ClientSession | None = None, ready_callback: DeviceReadyCallback | None = None, + mqtt_session_unauthorized_hook: SessionUnauthorizedHook | None = None, ) -> DeviceManager: """Convenience function to create and initialize a DeviceManager. @@ -182,6 +183,9 @@ async def create_device_manager( map_parser_config: Optional configuration for parsing maps. session: Optional aiohttp ClientSession to use for HTTP requests. ready_callback: Optional callback to be notified when a device is ready. + mqtt_session_unauthorized_hook: Optional hook for MQTT session unauthorized + events which may indicate rate limiting or revoked credentials. The + caller may use this to refresh authentication tokens as needed. Returns: An initialized DeviceManager with discovered devices. @@ -196,6 +200,7 @@ async def create_device_manager( mqtt_params = create_mqtt_params(user_data.rriot) mqtt_params.diagnostics = diagnostics.subkey("mqtt_session") + mqtt_params.unauthorized_hook = mqtt_session_unauthorized_hook mqtt_session = await create_lazy_mqtt_session(mqtt_params) def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice: diff --git a/roborock/mqtt/roborock_session.py b/roborock/mqtt/roborock_session.py index e3a9af03..70d3ea1b 100644 --- a/roborock/mqtt/roborock_session.py +++ b/roborock/mqtt/roborock_session.py @@ -31,7 +31,7 @@ # Exponential backoff parameters MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10) -MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30) +MAX_BACKOFF_INTERVAL = datetime.timedelta(hours=6) BACKOFF_MULTIPLIER = 1.5 @@ -79,6 +79,7 @@ def __init__( self._idle_timers: dict[str, asyncio.Task[None]] = {} self._diagnostics = params.diagnostics self._health_manager = HealthManager(self.restart) + self._unauthorized_hook = params.unauthorized_hook @property def connected(self) -> bool: @@ -199,7 +200,20 @@ async def _run_connection(self, start_future: asyncio.Future[None] | None) -> No _LOGGER.debug("Received message: %s", message) with self._diagnostics.timer("dispatch_message"): self._listeners(message.topic.value, message.payload) + except MqttCodeError as err: + self._diagnostics.increment(f"connect_failure:{err.rc}") + if start_future and not start_future.done(): + _LOGGER.debug("MQTT error starting session: %s", err) + start_future.set_exception(err) + else: + _LOGGER.debug("MQTT error: %s", err) + if err.rc == MqttReasonCode.RC_ERROR_UNAUTHORIZED and self._unauthorized_hook: + _LOGGER.info("MQTT unauthorized/rate-limit error received, setting backoff to maximum") + self._unauthorized_hook() + self._backoff = MAX_BACKOFF_INTERVAL + raise except MqttError as err: + self._diagnostics.increment("connect_failure:unknown") if start_future and not start_future.done(): _LOGGER.info("MQTT error starting session: %s", err) start_future.set_exception(err) @@ -207,6 +221,7 @@ async def _run_connection(self, start_future: asyncio.Future[None] | None) -> No _LOGGER.info("MQTT error: %s", err) raise except Exception as err: + self._diagnostics.increment("connect_failure:uncaught") # This error is thrown when the MQTT loop is cancelled # and the generator is not stopped. if "generator didn't stop" in str(err) or "generator didn't yield" in str(err): diff --git a/roborock/mqtt/session.py b/roborock/mqtt/session.py index 20d54a42..9d8b57ad 100644 --- a/roborock/mqtt/session.py +++ b/roborock/mqtt/session.py @@ -10,6 +10,8 @@ DEFAULT_TIMEOUT = 30.0 +SessionUnauthorizedHook = Callable[[], None] + @dataclass class MqttParams: @@ -41,6 +43,14 @@ class MqttParams: shared MQTT session diagnostics are included in the overall diagnostics. """ + unauthorized_hook: SessionUnauthorizedHook | None = None + """Optional hook invoked when an unauthorized error is received. + + This may be invoked by the background reconnect logic when an + unauthorized error is received from the broker. The caller may use + this hook to refresh credentials or take other actions as needed. + """ + class MqttSession(ABC): """An MQTT session for sending and receiving messages.""" diff --git a/tests/mqtt/test_roborock_session.py b/tests/mqtt/test_roborock_session.py index 24e6fbbd..15526b66 100644 --- a/tests/mqtt/test_roborock_session.py +++ b/tests/mqtt/test_roborock_session.py @@ -4,6 +4,7 @@ import copy import datetime from collections.abc import Callable, Generator +from typing import Any from unittest.mock import AsyncMock, Mock, patch import aiomqtt @@ -42,20 +43,65 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None: """Automatically use the fast backoff fixture.""" -@pytest.fixture(name="mqtt_client_lite") -def mqtt_client_lite_fixture() -> Generator[AsyncMock, None, None]: +class FakeAsyncIterator: + """Fake async iterator that waits for messages to arrive, but they never do. + + This is used for testing exceptions in other client functions. + """ + + def __init__(self) -> None: + self.loop = True + + def __aiter__(self): + return self + + async def __anext__(self) -> None: + """Iterator that does not generate any messages.""" + while self.loop: + await asyncio.sleep(0.01) + + +@pytest.fixture(name="message_iterator") +def message_iterator_fixture() -> FakeAsyncIterator: + """Fixture to provide a side effect for creating the MQTT client.""" + return FakeAsyncIterator() + + +@pytest.fixture(name="mock_client") +def mock_client_fixture(message_iterator: FakeAsyncIterator) -> Generator[AsyncMock, None, None]: """A fixture that provides a mocked aiomqtt Client. This is lighter weight that `mock_aiomqtt_client` that uses real sockets. """ mock_client = AsyncMock() - mock_client.messages = FakeAsyncIterator() + mock_client.messages = message_iterator + return mock_client + + +@pytest.fixture(name="create_client_side_effect") +def create_client_side_effect_fixture() -> Exception | None: + """Fixture to provide a side effect for creating the MQTT client.""" + return None + +@pytest.fixture(name="mock_aenter_client") +def mock_aenter_client_fixture(mock_client: AsyncMock, create_client_side_effect: Exception | None) -> AsyncMock: + """Fixture to provide a side effect for creating the MQTT client.""" mock_aenter = AsyncMock() mock_aenter.return_value = mock_client + mock_aenter.side_effect = create_client_side_effect + return mock_aenter + + +@pytest.fixture(name="mqtt_client_lite") +def mqtt_client_lite_fixture( + mock_client: AsyncMock, + mock_aenter_client: AsyncMock, +) -> Generator[AsyncMock, None, None]: + """Fixture to create a mock MQTT client with patched aiomqtt.Client.""" mock_shim = Mock() - mock_shim.return_value.__aenter__ = mock_aenter + mock_shim.return_value.__aenter__ = mock_aenter_client mock_shim.return_value.__aexit__ = AsyncMock() with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim): @@ -128,21 +174,6 @@ async def test_publish_command(push_mqtt_response: Callable[[bytes], None]) -> N assert not session.connected -class FakeAsyncIterator: - """Fake async iterator that waits for messages to arrive, but they never do. - - This is used for testing exceptions in other client functions. - """ - - def __aiter__(self): - return self - - async def __anext__(self) -> None: - """Iterator that does not generate any messages.""" - while True: - await asyncio.sleep(1) - - async def test_publish_failure(mqtt_client_lite: AsyncMock) -> None: """Test an MQTT error is received when publishing a message.""" @@ -446,3 +477,64 @@ async def test_diagnostics_data(push_mqtt_response: Callable[[bytes], None]) -> assert data.get("subscribe_count") == 2 assert data.get("dispatch_message_count") == 3 assert data.get("close") == 1 + + +@pytest.mark.parametrize( + ("create_client_side_effect"), + [ + # Unauthorized + aiomqtt.MqttCodeError(rc=135), + ], +) +async def test_session_unauthorized_hook(mqtt_client_lite: AsyncMock) -> None: + """Test the MQTT session.""" + + unauthorized = asyncio.Event() + + params = copy.deepcopy(FAKE_PARAMS) + params.unauthorized_hook = unauthorized.set + + with pytest.raises(MqttSessionUnauthorized): + await create_mqtt_session(params) + + assert unauthorized.is_set() + + +async def test_session_unauthorized_after_start( + mock_aenter_client: AsyncMock, + message_iterator: FakeAsyncIterator, + mqtt_client_lite: AsyncMock, + push_mqtt_response: Callable[[bytes], None], +) -> None: + """Test the MQTT session.""" + + # Configure a hook that is notified of unauthorized errors + unauthorized = asyncio.Event() + params = copy.deepcopy(FAKE_PARAMS) + params.unauthorized_hook = unauthorized.set + + # The client will succeed on first connection attempt, then fail with + # unauthorized messages on all future attempts. + request_count = 0 + + def succeed_then_fail_unauthorized() -> Any: + nonlocal request_count + request_count += 1 + if request_count == 1: + return mqtt_client_lite + raise aiomqtt.MqttCodeError(rc=135) + + mock_aenter_client.side_effect = succeed_then_fail_unauthorized + # Don't produce messages, just exit and restart to reconnect + message_iterator.loop = False + + push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2)) + + session = await create_mqtt_session(params) + assert session.connected + + try: + async with asyncio.timeout(10): + assert await unauthorized.wait() + finally: + await session.close()