-
Notifications
You must be signed in to change notification settings - Fork 3.3k
feat: add RobustMicrophone to handle hardware disconnects (#6076) #6183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| import asyncio | ||
| import time | ||
| from typing import Any | ||
|
|
||
| from livekit import rtc | ||
| from livekit.agents.log import logger | ||
| from livekit.agents.utils import aio | ||
|
|
||
|
|
||
| class RobustMicrophone: | ||
| """ | ||
| A robust microphone capture utility that wraps rtc.MediaDevices().open_input() | ||
| and automatically restarts the audio stream if it stalls (e.g. if the microphone | ||
| cable becomes loose). | ||
|
|
||
| This class exposes an `rtc.AudioSource` as `.source` which you can use to | ||
| create your LocalAudioTrack. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| sample_rate: int = 48000, | ||
| num_channels: int = 1, | ||
| stall_timeout: float = 2.0, | ||
| **open_input_kwargs: Any, | ||
| ) -> None: | ||
| self._sample_rate = sample_rate | ||
| self._num_channels = num_channels | ||
| self._stall_timeout = stall_timeout | ||
| self._kwargs = open_input_kwargs | ||
|
|
||
| self._devices = rtc.MediaDevices() | ||
| self._source = rtc.AudioSource(self._sample_rate, self._num_channels) | ||
|
|
||
| self._mic_obj: Any = None | ||
| self._mic_track: rtc.LocalAudioTrack | None = None | ||
| self._mic_stream: rtc.AudioStream | None = None | ||
|
|
||
| self._running = False | ||
| self._monitor_task: asyncio.Task[None] | None = None | ||
| self._last_frame_time: float = time.monotonic() | ||
|
|
||
| @property | ||
| def source(self) -> rtc.AudioSource: | ||
| return self._source | ||
|
|
||
| def start(self) -> None: | ||
| if self._running: | ||
| return | ||
| self._running = True | ||
| self._monitor_task = asyncio.create_task(self._monitor_loop()) | ||
|
|
||
| async def aclose(self) -> None: | ||
| if not self._running: | ||
| return | ||
| self._running = False | ||
| if self._monitor_task: | ||
| await aio.cancel_and_wait(self._monitor_task) | ||
| await self._close_internal() | ||
|
|
||
| async def _close_internal(self) -> None: | ||
| if self._mic_stream: | ||
| await self._mic_stream.aclose() | ||
| self._mic_stream = None | ||
| self._mic_track = None | ||
| self._mic_obj = None | ||
|
|
||
| def _start_internal(self) -> None: | ||
| # Provide defaults for sample_rate and num_channels if not provided by user | ||
| kwargs = dict(self._kwargs) | ||
| if "sample_rate" not in kwargs: | ||
| kwargs["sample_rate"] = self._sample_rate | ||
| if "num_channels" not in kwargs: | ||
| kwargs["num_channels"] = self._num_channels | ||
|
|
||
| self._mic_obj = self._devices.open_input(**kwargs) | ||
| self._mic_track = rtc.LocalAudioTrack.create_audio_track("robust-mic-internal", self._mic_obj.source) | ||
| self._mic_stream = rtc.AudioStream.from_track(self._mic_track) | ||
| self._last_frame_time = time.monotonic() | ||
|
|
||
| async def _monitor_loop(self) -> None: | ||
| self._start_internal() | ||
|
|
||
| while self._running: | ||
| try: | ||
| assert self._mic_stream is not None | ||
|
|
||
| # Wait for the next audio event with a timeout | ||
| event = await asyncio.wait_for( | ||
| self._mic_stream.__anext__(), timeout=self._stall_timeout | ||
| ) | ||
| self._last_frame_time = time.monotonic() | ||
|
|
||
| # Forward the captured frame to our own source | ||
| await self._source.capture_frame(event.frame) | ||
|
|
||
| except asyncio.TimeoutError: | ||
| # Stall detected! | ||
| logger.warning( | ||
| f"RobustMicrophone: No audio frames received for {self._stall_timeout}s. Restarting microphone..." | ||
| ) | ||
| await self._close_internal() | ||
| await asyncio.sleep(0.5) # Brief pause before reconnecting | ||
| self._start_internal() | ||
| except Exception as e: | ||
| logger.error(f"RobustMicrophone error in monitor loop: {e}", exc_info=e) | ||
| await asyncio.sleep(1.0) | ||
| await self._close_internal() | ||
| self._start_internal() | ||
|
Comment on lines
+98
to
+110
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 Unprotected In Affected code pathsLine 105 (after timeout stall detection): except asyncio.TimeoutError:
await self._close_internal()
await asyncio.sleep(0.5)
self._start_internal() # unprotected - crashes loop on failureLine 110 (after generic error): except Exception as e:
await asyncio.sleep(1.0)
await self._close_internal()
self._start_internal() # unprotected - crashes loop on failureThe same issue exists at line 83 where the initial Prompt for agentsWas this helpful? React with 👍 or 👎 to provide feedback. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| import asyncio | ||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
|
|
||
| import pytest | ||
| from livekit import rtc | ||
| from livekit.agents.utils.robust_microphone import RobustMicrophone | ||
|
|
||
| pytestmark = pytest.mark.unit | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_robust_microphone_startup_and_shutdown(): | ||
| with patch("livekit.rtc.MediaDevices", autospec=True) as mock_media_devices, \ | ||
| patch("livekit.rtc.AudioSource", autospec=True) as mock_audio_source, \ | ||
|
Check failure on line 14 in tests/test_robust_microphone.py
|
||
| patch("livekit.rtc.LocalAudioTrack.create_audio_track", autospec=True) as mock_create_track, \ | ||
| patch("livekit.rtc.AudioStream.from_track", autospec=True) as mock_from_track: | ||
|
|
||
| # Mock the stream iterator to just hang (we'll shut down before it times out) | ||
| mock_stream = AsyncMock() | ||
| mock_stream.__anext__.side_effect = asyncio.TimeoutError | ||
| mock_from_track.return_value = mock_stream | ||
|
|
||
| mic = RobustMicrophone(stall_timeout=10.0) | ||
| mic.start() | ||
|
|
||
| # Give it a moment to start the internal stream | ||
| await asyncio.sleep(0.1) | ||
|
|
||
| assert mock_media_devices.called | ||
| assert mock_create_track.called | ||
| assert mock_from_track.called | ||
|
|
||
| await mic.aclose() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_robust_microphone_restarts_on_stall(): | ||
| with patch("livekit.rtc.MediaDevices", autospec=True) as mock_media_devices, \ | ||
|
Check failure on line 38 in tests/test_robust_microphone.py
|
||
| patch("livekit.rtc.AudioSource", autospec=True) as mock_audio_source, \ | ||
|
Check failure on line 39 in tests/test_robust_microphone.py
|
||
| patch("livekit.rtc.LocalAudioTrack.create_audio_track", autospec=True) as mock_create_track, \ | ||
|
Check failure on line 40 in tests/test_robust_microphone.py
|
||
| patch("livekit.rtc.AudioStream.from_track", autospec=True) as mock_from_track: | ||
|
|
||
| mock_stream = AsyncMock() | ||
| mock_stream.aclose = AsyncMock() | ||
|
|
||
| # First call times out immediately, causing a restart | ||
| # Second call hangs so we can shut down | ||
| call_count = 0 | ||
| async def mock_anext(): | ||
| nonlocal call_count | ||
| call_count += 1 | ||
| if call_count == 1: | ||
| raise asyncio.TimeoutError() | ||
| else: | ||
| await asyncio.sleep(10) | ||
| return MagicMock() | ||
|
|
||
| mock_stream.__anext__ = mock_anext | ||
| mock_from_track.return_value = mock_stream | ||
|
|
||
| mic = RobustMicrophone(stall_timeout=0.1) | ||
| mic.start() | ||
|
|
||
| # Wait for the timeout and restart to happen | ||
| await asyncio.sleep(0.5) | ||
|
|
||
| # MediaDevices.open_input should have been called twice (initial + restart) | ||
| assert mock_from_track.call_count >= 2 | ||
|
|
||
| await mic.aclose() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚩
AudioStream.from_trackcalled withoutsample_rate/num_channelsunlike codebase patternThe existing usage of
rtc.AudioStream.from_trackinlivekit-agents/livekit/agents/voice/room_io/_input.py:303-308passessample_rate,num_channels, andframe_size_msexplicitly. The new code at line 79 callsrtc.AudioStream.from_track(self._mic_track)without these parameters. If thefrom_trackdefaults differ fromself._sample_rate/self._num_channels, the frames forwarded toself._sourcecould have mismatched format. This might be benign if the local mic track already outputs at the configured rate, but it deviates from established patterns and could cause subtle audio issues.Was this helpful? React with 👍 or 👎 to provide feedback.