From 5421b56932d814e111f7c8a3854d98a45cbd1a2d Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 5 Dec 2025 16:52:35 -0500 Subject: [PATCH 1/8] bidi - remove python 3.11+ features --- src/strands/experimental/bidi/_async/__init__.py | 8 ++++++-- src/strands/experimental/bidi/agent/agent.py | 13 ++++++++++--- .../experimental/bidi/_async/test__init__.py | 13 ++++++++----- .../experimental/bidi/models/test_gemini_live.py | 9 ++++----- .../bidi/models/test_openai_realtime.py | 12 ++++++++---- 5 files changed, 36 insertions(+), 19 deletions(-) diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 6cee3264d..2b97ab1fc 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -16,7 +16,7 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: funcs: Stop functions to call in sequence. Raises: - ExceptionGroup: If any stop function raises an exception. + RuntimeError: If any stop function raises an exception. """ exceptions = [] for func in funcs: @@ -26,4 +26,8 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: exceptions.append(exception) if exceptions: - raise ExceptionGroup("failed stop sequence", exceptions) + exceptions.append(RuntimeError("failed stop sequence")) + for i in range(1, len(exceptions)): + exceptions[i].__cause__ = exceptions[i - 1] + + raise exceptions[-1] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 360dfe707..bc45920d7 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -387,9 +387,16 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: for start in [*input_starts, *output_starts]: await start(self) - async with asyncio.TaskGroup() as task_group: - inputs_task = task_group.create_task(run_inputs()) - task_group.create_task(run_outputs(inputs_task)) + inputs_task = asyncio.create_task(run_inputs()) + outputs_task = asyncio.create_task(run_outputs(inputs_task)) + + try: + await asyncio.gather(inputs_task, outputs_task) + except (Exception, asyncio.CancelledError): + inputs_task.cancel() + outputs_task.cancel() + await asyncio.gather(inputs_task, outputs_task, return_exceptions=True) + raise finally: input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)] diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py index f8df25e14..3c5339d08 100644 --- a/tests/strands/experimental/bidi/_async/test__init__.py +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -1,3 +1,4 @@ +import traceback from unittest.mock import AsyncMock import pytest @@ -10,17 +11,19 @@ async def test_stop_exception(): func1 = AsyncMock() func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) func3 = AsyncMock() + func4 = AsyncMock(side_effect=ValueError("stop 4 failed")) - with pytest.raises(ExceptionGroup) as exc_info: - await stop_all(func1, func2, func3) + with pytest.raises(RuntimeError, match=r"failed stop sequence") as exc_info: + await stop_all(func1, func2, func3, func4) func1.assert_called_once() func2.assert_called_once() func3.assert_called_once() + func4.assert_called_once() - assert len(exc_info.value.exceptions) == 1 - with pytest.raises(ValueError, match=r"stop 2 failed"): - raise exc_info.value.exceptions[0] + tru_tb = "".join(traceback.format_exception(RuntimeError, exc_info.value, exc_info.tb)) + assert "ValueError: stop 2 failed" in tru_tb + assert "ValueError: stop 4 failed" in tru_tb @pytest.mark.asyncio diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index da516d4a0..6543dc4f2 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,8 +13,8 @@ import pytest from google.genai import types as genai_types -from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(RuntimeError, match=r"failed stop sequence"): await model4.stop() @@ -572,7 +572,6 @@ def test_tool_formatting(model, tool_spec): assert formatted_empty == [] - # Tool Result Content Tests @@ -601,7 +600,7 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key assert isinstance(audio_event, BidiAudioStreamEvent) # Should use configured rates, not constants assert audio_event.sample_rate == 48000 # Custom config - assert audio_event.channels == 2 # Custom config + assert audio_event.channels == 2 # Custom config assert audio_event.format == "pcm" await model.stop() @@ -631,7 +630,7 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke assert isinstance(audio_event, BidiAudioStreamEvent) # Should use default rates assert audio_event.sample_rate == 24000 # Default output rate - assert audio_event.channels == 1 # Default channels + assert audio_event.channels == 1 # Default channels assert audio_event.format == "pcm" await model.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 5c9c0900d..5ab183da2 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -131,7 +131,9 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -154,7 +156,9 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -349,7 +353,7 @@ async def async_connect(*args, **kwargs): model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(RuntimeError, match=r"failed stop sequence"): await model4.stop() @@ -510,7 +514,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): format="pcm", sample_rate=24000, channels=1, - ) + ), ] assert tru_events == exp_events From 5b449b3770af6f4fd02a0098421f6dd631af1ac2 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 5 Dec 2025 17:29:52 -0500 Subject: [PATCH 2/8] run - raise on exception only --- src/strands/experimental/bidi/agent/agent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index b1b9491a8..efa05229b 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -395,11 +395,12 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: try: await asyncio.gather(inputs_task, outputs_task) - except (Exception, asyncio.CancelledError): + except (Exception, asyncio.CancelledError) as error: inputs_task.cancel() outputs_task.cancel() await asyncio.gather(inputs_task, outputs_task, return_exceptions=True) - raise + if not isinstance(error, asyncio.CancelledError): + raise finally: input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)] From db934acf088c8b3af5c5647bca02ad0350a910a5 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 5 Dec 2025 18:16:34 -0500 Subject: [PATCH 3/8] reraise external cancellations --- src/strands/experimental/bidi/agent/agent.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index efa05229b..7cd48c466 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -399,9 +399,14 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: inputs_task.cancel() outputs_task.cancel() await asyncio.gather(inputs_task, outputs_task, return_exceptions=True) + if not isinstance(error, asyncio.CancelledError): raise + run_task = asyncio.current_task() + if run_task and run_task.cancelling() > 0: # externally cancelled + raise + finally: input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)] output_stops = [output.stop for output in outputs if isinstance(output, BidiOutput)] From 75cc8af7210c970a24703d5492a13031f752ae76 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 8 Dec 2025 10:15:19 -0500 Subject: [PATCH 4/8] task group --- .../experimental/bidi/_async/__init__.py | 5 +- .../experimental/bidi/_async/_task_group.py | 61 +++++++++++++++++++ src/strands/experimental/bidi/agent/agent.py | 21 ++----- .../bidi/_async/test_task_group.py | 59 ++++++++++++++++++ 4 files changed, 127 insertions(+), 19 deletions(-) create mode 100644 src/strands/experimental/bidi/_async/_task_group.py create mode 100644 tests/strands/experimental/bidi/_async/test_task_group.py diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 2b97ab1fc..103172c13 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -2,9 +2,10 @@ from typing import Awaitable, Callable +from ._task_group import _TaskGroup from ._task_pool import _TaskPool -__all__ = ["_TaskPool"] +__all__ = ["_TaskGroup", "_TaskPool"] async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: @@ -28,6 +29,6 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: if exceptions: exceptions.append(RuntimeError("failed stop sequence")) for i in range(1, len(exceptions)): - exceptions[i].__cause__ = exceptions[i - 1] + exceptions[i].__context__ = exceptions[i - 1] raise exceptions[-1] diff --git a/src/strands/experimental/bidi/_async/_task_group.py b/src/strands/experimental/bidi/_async/_task_group.py new file mode 100644 index 000000000..31a75c667 --- /dev/null +++ b/src/strands/experimental/bidi/_async/_task_group.py @@ -0,0 +1,61 @@ +"""Manage a group of async tasks. + +This is intended to mimic the behaviors of asyncio.TaskGroup released in Python 3.11. + +- Docs: https://docs.python.org/3/library/asyncio-task.html#task-groups +""" + +import asyncio +from typing import Any, Coroutine + + +class _TaskGroup: + """Implementation of asyncio.TaskGroup for use in Python 3.10. + + Attributes: + _tasks: List of tasks in group. + """ + + _tasks: list[asyncio.Task] + + def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: + """Create an async task and add to group. + + Returns: + The created task. + """ + task = asyncio.create_task(coro) + self._tasks.append(task) + return task + + async def __aenter__(self) -> "_TaskGroup": + """Setup self managed task group context.""" + self._tasks = [] + return self + + async def __aexit__(self, *_: Any) -> None: + """Execute tasks in group. + + The following execution rules are enforced: + - The context stops executing all tasks if at least one task raises an Exception or the context is cancelled. + - The context re-raises Exceptions to the caller. + - The context re-raises CancelledErrors to the caller only if the context itself was cancelled. + """ + try: + await asyncio.gather(*self._tasks) + + except (Exception, asyncio.CancelledError) as error: + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) + + if not isinstance(error, asyncio.CancelledError): + raise + + context_task = asyncio.current_task() + if context_task and context_task.cancelling() > 0: # context itself was cancelled + raise + + finally: + self._tasks = [] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 7cd48c466..5ddb181ea 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -30,7 +30,7 @@ from ....types.tools import AgentTool from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ...tools import ToolProvider -from .._async import stop_all +from .._async import _TaskGroup, stop_all from ..models.model import BidiModel from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput @@ -390,22 +390,9 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: for start in [*input_starts, *output_starts]: await start(self) - inputs_task = asyncio.create_task(run_inputs()) - outputs_task = asyncio.create_task(run_outputs(inputs_task)) - - try: - await asyncio.gather(inputs_task, outputs_task) - except (Exception, asyncio.CancelledError) as error: - inputs_task.cancel() - outputs_task.cancel() - await asyncio.gather(inputs_task, outputs_task, return_exceptions=True) - - if not isinstance(error, asyncio.CancelledError): - raise - - run_task = asyncio.current_task() - if run_task and run_task.cancelling() > 0: # externally cancelled - raise + async with _TaskGroup() as task_group: + inputs_task = task_group.create_task(run_inputs()) + task_group.create_task(run_outputs(inputs_task)) finally: input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)] diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py new file mode 100644 index 000000000..c2b734534 --- /dev/null +++ b/tests/strands/experimental/bidi/_async/test_task_group.py @@ -0,0 +1,59 @@ +import asyncio +import unittest.mock + +import pytest + +from strands.experimental.bidi._async._task_group import _TaskGroup + + +@pytest.mark.asyncio +async def test_task_group__aexit__(): + coro = unittest.mock.AsyncMock() + + async with _TaskGroup() as task_group: + task_group.create_task(coro()) + + coro.assert_called_once() + + +@pytest.mark.asyncio +async def test_task_group__aexit__exception(): + wait_event = asyncio.Event() + async def wait(): + await wait_event.wait() + + async def fail(): + raise ValueError("test error") + + with pytest.raises(ValueError, match="test error"): + async with _TaskGroup() as task_group: + wait_task = task_group.create_task(wait()) + fail_task = task_group.create_task(fail()) + + assert wait_task.cancelled() + assert not fail_task.cancelled() + + +@pytest.mark.asyncio +async def test_task_group__aexit__cancelled(): + wait_event = asyncio.Event() + async def wait(): + await wait_event.wait() + + tasks = [] + + run_event = asyncio.Event() + async def run(): + async with _TaskGroup() as task_group: + tasks.append(task_group.create_task(wait())) + run_event.set() + + run_task = asyncio.create_task(run()) + await run_event.wait() + run_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await run_task + + wait_task = tasks[0] + assert wait_task.cancelled() From bfc32cce9fbf6dc5d85544962bdb0bbbc36e63b8 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 8 Dec 2025 10:18:57 -0500 Subject: [PATCH 5/8] wording --- src/strands/experimental/bidi/_async/_task_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidi/_async/_task_group.py b/src/strands/experimental/bidi/_async/_task_group.py index 31a75c667..26c67326d 100644 --- a/src/strands/experimental/bidi/_async/_task_group.py +++ b/src/strands/experimental/bidi/_async/_task_group.py @@ -10,7 +10,7 @@ class _TaskGroup: - """Implementation of asyncio.TaskGroup for use in Python 3.10. + """Shim of asyncio.TaskGroup for use in Python 3.10. Attributes: _tasks: List of tasks in group. From 93f1fa1f579b65434973d40054d9417cae0406d9 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 8 Dec 2025 10:20:33 -0500 Subject: [PATCH 6/8] test regex --- tests/strands/experimental/bidi/_async/test_task_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py index c2b734534..23ff821f9 100644 --- a/tests/strands/experimental/bidi/_async/test_task_group.py +++ b/tests/strands/experimental/bidi/_async/test_task_group.py @@ -25,7 +25,7 @@ async def wait(): async def fail(): raise ValueError("test error") - with pytest.raises(ValueError, match="test error"): + with pytest.raises(ValueError, match=r"test error"): async with _TaskGroup() as task_group: wait_task = task_group.create_task(wait()) fail_task = task_group.create_task(fail()) From 2ecbd22c4fbd203eb26b84e3816a91d360935d30 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 8 Dec 2025 10:29:21 -0500 Subject: [PATCH 7/8] bidi exception chain --- .../experimental/bidi/_async/__init__.py | 7 +-- src/strands/experimental/bidi/agent/loop.py | 2 +- src/strands/experimental/bidi/errors.py | 45 +++++++++++++++++++ .../experimental/bidi/models/__init__.py | 3 +- .../experimental/bidi/models/gemini_live.py | 3 +- src/strands/experimental/bidi/models/model.py | 20 --------- .../experimental/bidi/models/nova_sonic.py | 3 +- .../bidi/models/openai_realtime.py | 3 +- src/strands/experimental/bidi/types/events.py | 2 +- src/strands/experimental/hooks/events.py | 2 +- .../experimental/bidi/_async/test__init__.py | 4 +- .../experimental/bidi/agent/test_loop.py | 3 +- .../bidi/models/test_gemini_live.py | 4 +- .../bidi/models/test_nova_sonic.py | 2 +- .../bidi/models/test_openai_realtime.py | 4 +- 15 files changed, 65 insertions(+), 42 deletions(-) create mode 100644 src/strands/experimental/bidi/errors.py diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 103172c13..403960c46 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -2,6 +2,7 @@ from typing import Awaitable, Callable +from ..errors import BidiExceptionChain from ._task_group import _TaskGroup from ._task_pool import _TaskPool @@ -27,8 +28,4 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: exceptions.append(exception) if exceptions: - exceptions.append(RuntimeError("failed stop sequence")) - for i in range(1, len(exceptions)): - exceptions[i].__context__ = exceptions[i - 1] - - raise exceptions[-1] + raise BidiExceptionChain("failed stop sequence", exceptions) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 2b883cf73..7ea9a9a57 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -20,7 +20,7 @@ BidiInterruptionEvent as BidiInterruptionHookEvent, ) from .._async import _TaskPool, stop_all -from ..models import BidiModelTimeoutError +from ..errors import BidiModelTimeoutError from ..types.events import ( BidiConnectionCloseEvent, BidiConnectionRestartEvent, diff --git a/src/strands/experimental/bidi/errors.py b/src/strands/experimental/bidi/errors.py new file mode 100644 index 000000000..34b1618d7 --- /dev/null +++ b/src/strands/experimental/bidi/errors.py @@ -0,0 +1,45 @@ +"""Custom bidi exceptions.""" +from typing import Any + + +class BidiExceptionChain(Exception): + """Chain a list of exceptions together. + + Useful for chaining together exceptions raised across a multi-step workflow (e.g., the bidi `stop` methods). + Note, this exception is meant to mimic ExceptionGroup released in Python 3.11. + + - Docs: https://docs.python.org/3/library/exceptions.html#ExceptionGroup + """ + + def __init__(self, message: str, exceptions: list[Exception]) -> None: + """Chain exceptions. + + Args: + message: Top-level exception message. + exceptions: List of exceptions to chain. + """ + super().__init__(self, message) + + exceptions.append(self) + for i in range(1, len(exceptions)): + exceptions[i].__context__ = exceptions[i - 1] + + +class BidiModelTimeoutError(Exception): + """Model timeout error. + + Bidirectional models are often configured with a connection time limit. Nova sonic for example keeps the connection + open for 8 minutes max. Upon receiving a timeout, the agent loop is configured to restart the model connection so as + to create a seamless, uninterrupted experience for the user. + """ + + def __init__(self, message: str, **restart_config: Any) -> None: + """Initialize error. + + Args: + message: Timeout message from model. + **restart_config: Configure restart specific behaviors in the call to model start. + """ + super().__init__(self, message) + + self.restart_config = restart_config diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index cc62c9987..db661e282 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,10 +1,9 @@ """Bidirectional model interfaces and implementations.""" -from .model import BidiModel, BidiModelTimeoutError +from .model import BidiModel from .nova_sonic import BidiNovaSonicModel __all__ = [ "BidiModel", - "BidiModelTimeoutError", "BidiNovaSonicModel", ] diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 88d7f5a0c..5ae8ee7c1 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -25,6 +25,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all +from ..errors import BidiModelTimeoutError from ..types.events import ( AudioChannel, AudioSampleRate, @@ -41,7 +42,7 @@ ModalityUsage, ) from ..types.model import AudioConfig -from .model import BidiModel, BidiModelTimeoutError +from .model import BidiModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/model.py b/src/strands/experimental/bidi/models/model.py index f5e34aa50..e2584b1c5 100644 --- a/src/strands/experimental/bidi/models/model.py +++ b/src/strands/experimental/bidi/models/model.py @@ -112,23 +112,3 @@ async def send( ``` """ ... - - -class BidiModelTimeoutError(Exception): - """Model timeout error. - - Bidirectional models are often configured with a connection time limit. Nova sonic for example keeps the connection - open for 8 minutes max. Upon receiving a timeout, the agent loop is configured to restart the model connection so as - to create a seamless, uninterrupted experience for the user. - """ - - def __init__(self, message: str, **restart_config: Any) -> None: - """Initialize error. - - Args: - message: Timeout message from model. - **restart_config: Configure restart specific behaviors in the call to model start. - """ - super().__init__(self, message) - - self.restart_config = restart_config diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 6a2477e22..8e5959776 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -37,6 +37,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all +from ..errors import BidiModelTimeoutError from ..types.events import ( AudioChannel, AudioSampleRate, @@ -53,7 +54,7 @@ BidiUsageEvent, ) from ..types.model import AudioConfig -from .model import BidiModel, BidiModelTimeoutError +from .model import BidiModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/openai_realtime.py b/src/strands/experimental/bidi/models/openai_realtime.py index 9196a39d5..300825cdb 100644 --- a/src/strands/experimental/bidi/models/openai_realtime.py +++ b/src/strands/experimental/bidi/models/openai_realtime.py @@ -19,6 +19,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all +from ..errors import BidiModelTimeoutError from ..types.events import ( AudioSampleRate, BidiAudioInputEvent, @@ -37,7 +38,7 @@ StopReason, ) from ..types.model import AudioConfig -from .model import BidiModel, BidiModelTimeoutError +from .model import BidiModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 9d44fc660..81236e122 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -27,7 +27,7 @@ from ....types.streaming import ContentBlockDelta if TYPE_CHECKING: - from ..models.model import BidiModelTimeoutError + from ..errors import BidiModelTimeoutError AudioChannel = Literal[1, 2] """Number of audio channels. diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index 8a8d80629..40a90612f 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from ..bidi.agent.agent import BidiAgent - from ..bidi.models import BidiModelTimeoutError + from ..bidi.errors import BidiModelTimeoutError warnings.warn( "BeforeModelCallEvent, AfterModelCallEvent, BeforeToolCallEvent, and AfterToolCallEvent are no longer experimental." diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py index 3c5339d08..5af1790c2 100644 --- a/tests/strands/experimental/bidi/_async/test__init__.py +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -13,7 +13,7 @@ async def test_stop_exception(): func3 = AsyncMock() func4 = AsyncMock(side_effect=ValueError("stop 4 failed")) - with pytest.raises(RuntimeError, match=r"failed stop sequence") as exc_info: + with pytest.raises(Exception, match=r"failed stop sequence") as exc_info: await stop_all(func1, func2, func3, func4) func1.assert_called_once() @@ -21,7 +21,7 @@ async def test_stop_exception(): func3.assert_called_once() func4.assert_called_once() - tru_tb = "".join(traceback.format_exception(RuntimeError, exc_info.value, exc_info.tb)) + tru_tb = "".join(traceback.format_exception(Exception, exc_info.value, exc_info.tb)) assert "ValueError: stop 2 failed" in tru_tb assert "ValueError: stop 4 failed" in tru_tb diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index 0ce8d6658..9006dae13 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -5,8 +5,7 @@ from strands import tool from strands.experimental.bidi import BidiAgent -from strands.experimental.bidi.agent.loop import _BidiAgentLoop -from strands.experimental.bidi.models import BidiModelTimeoutError +from strands.experimental.bidi.errors import BidiModelTimeoutError from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index 6543dc4f2..62f077e04 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,8 +13,8 @@ import pytest from google.genai import types as genai_types +from strands.experimental.bidi.errors import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel -from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(RuntimeError, match=r"failed stop sequence"): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop() diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py index 04f8043be..002af28cf 100644 --- a/tests/strands/experimental/bidi/models/test_nova_sonic.py +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -16,7 +16,7 @@ from strands.experimental.bidi.models.nova_sonic import ( BidiNovaSonicModel, ) -from strands.experimental.bidi.models.model import BidiModelTimeoutError +from strands.experimental.bidi.errors import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 5ab183da2..d25340d6a 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -14,7 +14,7 @@ import pytest -from strands.experimental.bidi.models.model import BidiModelTimeoutError +from strands.experimental.bidi.errors import BidiModelTimeoutError from strands.experimental.bidi.models.openai_realtime import BidiOpenAIRealtimeModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, @@ -353,7 +353,7 @@ async def async_connect(*args, **kwargs): model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - with pytest.raises(RuntimeError, match=r"failed stop sequence"): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop() From 0279685955b68065e94637f92e348abfd6c3532e Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 8 Dec 2025 10:42:02 -0500 Subject: [PATCH 8/8] docs --- src/strands/experimental/bidi/_async/__init__.py | 2 +- src/strands/experimental/bidi/errors.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 403960c46..218407feb 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -18,7 +18,7 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: funcs: Stop functions to call in sequence. Raises: - RuntimeError: If any stop function raises an exception. + BidiExceptionChain: If any stop function raises an exception. """ exceptions = [] for func in funcs: diff --git a/src/strands/experimental/bidi/errors.py b/src/strands/experimental/bidi/errors.py index 34b1618d7..8deb0dbef 100644 --- a/src/strands/experimental/bidi/errors.py +++ b/src/strands/experimental/bidi/errors.py @@ -1,4 +1,4 @@ -"""Custom bidi exceptions.""" +"""Custom bidi error classes.""" from typing import Any