diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 6cee3264d..47473115c 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -2,9 +2,10 @@ from typing import Awaitable, Callable +from ._task_group import _TaskGroup from ._task_pool import _TaskPool -__all__ = ["_TaskPool"] +__all__ = ["_TaskGroup", "_TaskPool"] async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: @@ -16,14 +17,14 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: funcs: Stop functions to call in sequence. Raises: - ExceptionGroup: If any stop function raises an exception. + RuntimeError: If any stop function raises an exception. """ exceptions = [] for func in funcs: try: await func() except Exception as exception: - exceptions.append(exception) + exceptions.append({"func_name": func.__name__, "exception": repr(exception)}) if exceptions: - raise ExceptionGroup("failed stop sequence", exceptions) + raise RuntimeError(f"exceptions={exceptions} | failed stop sequence") diff --git a/src/strands/experimental/bidi/_async/_task_group.py b/src/strands/experimental/bidi/_async/_task_group.py new file mode 100644 index 000000000..26c67326d --- /dev/null +++ b/src/strands/experimental/bidi/_async/_task_group.py @@ -0,0 +1,61 @@ +"""Manage a group of async tasks. + +This is intended to mimic the behaviors of asyncio.TaskGroup released in Python 3.11. + +- Docs: https://docs.python.org/3/library/asyncio-task.html#task-groups +""" + +import asyncio +from typing import Any, Coroutine + + +class _TaskGroup: + """Shim of asyncio.TaskGroup for use in Python 3.10. + + Attributes: + _tasks: List of tasks in group. + """ + + _tasks: list[asyncio.Task] + + def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: + """Create an async task and add to group. + + Returns: + The created task. + """ + task = asyncio.create_task(coro) + self._tasks.append(task) + return task + + async def __aenter__(self) -> "_TaskGroup": + """Setup self managed task group context.""" + self._tasks = [] + return self + + async def __aexit__(self, *_: Any) -> None: + """Execute tasks in group. + + The following execution rules are enforced: + - The context stops executing all tasks if at least one task raises an Exception or the context is cancelled. + - The context re-raises Exceptions to the caller. + - The context re-raises CancelledErrors to the caller only if the context itself was cancelled. + """ + try: + await asyncio.gather(*self._tasks) + + except (Exception, asyncio.CancelledError) as error: + for task in self._tasks: + task.cancel() + + await asyncio.gather(*self._tasks, return_exceptions=True) + + if not isinstance(error, asyncio.CancelledError): + raise + + context_task = asyncio.current_task() + if context_task and context_task.cancelling() > 0: # context itself was cancelled + raise + + finally: + self._tasks = [] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 4012d5e2d..5ddb181ea 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -30,7 +30,7 @@ from ....types.tools import AgentTool from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ...tools import ToolProvider -from .._async import stop_all +from .._async import _TaskGroup, stop_all from ..models.model import BidiModel from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput @@ -390,7 +390,7 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: for start in [*input_starts, *output_starts]: await start(self) - async with asyncio.TaskGroup() as task_group: + async with _TaskGroup() as task_group: inputs_task = task_group.create_task(run_inputs()) task_group.create_task(run_outputs(inputs_task)) diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py index f8df25e14..a121ddecc 100644 --- a/tests/strands/experimental/bidi/_async/test__init__.py +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -10,17 +10,19 @@ async def test_stop_exception(): func1 = AsyncMock() func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) func3 = AsyncMock() + func4 = AsyncMock(side_effect=ValueError("stop 4 failed")) - with pytest.raises(ExceptionGroup) as exc_info: - await stop_all(func1, func2, func3) + with pytest.raises(Exception, match=r"failed stop sequence") as exc_info: + await stop_all(func1, func2, func3, func4) func1.assert_called_once() func2.assert_called_once() func3.assert_called_once() + func4.assert_called_once() - assert len(exc_info.value.exceptions) == 1 - with pytest.raises(ValueError, match=r"stop 2 failed"): - raise exc_info.value.exceptions[0] + tru_message = str(exc_info.value) + assert "ValueError('stop 2 failed')" in tru_message + assert "ValueError('stop 4 failed')" in tru_message @pytest.mark.asyncio diff --git a/tests/strands/experimental/bidi/_async/test_task_group.py b/tests/strands/experimental/bidi/_async/test_task_group.py new file mode 100644 index 000000000..23ff821f9 --- /dev/null +++ b/tests/strands/experimental/bidi/_async/test_task_group.py @@ -0,0 +1,59 @@ +import asyncio +import unittest.mock + +import pytest + +from strands.experimental.bidi._async._task_group import _TaskGroup + + +@pytest.mark.asyncio +async def test_task_group__aexit__(): + coro = unittest.mock.AsyncMock() + + async with _TaskGroup() as task_group: + task_group.create_task(coro()) + + coro.assert_called_once() + + +@pytest.mark.asyncio +async def test_task_group__aexit__exception(): + wait_event = asyncio.Event() + async def wait(): + await wait_event.wait() + + async def fail(): + raise ValueError("test error") + + with pytest.raises(ValueError, match=r"test error"): + async with _TaskGroup() as task_group: + wait_task = task_group.create_task(wait()) + fail_task = task_group.create_task(fail()) + + assert wait_task.cancelled() + assert not fail_task.cancelled() + + +@pytest.mark.asyncio +async def test_task_group__aexit__cancelled(): + wait_event = asyncio.Event() + async def wait(): + await wait_event.wait() + + tasks = [] + + run_event = asyncio.Event() + async def run(): + async with _TaskGroup() as task_group: + tasks.append(task_group.create_task(wait())) + run_event.set() + + run_task = asyncio.create_task(run()) + await run_event.wait() + run_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await run_task + + wait_task = tasks[0] + assert wait_task.cancelled() diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index da516d4a0..d6ffedb37 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,7 +13,7 @@ import pytest from google.genai import types as genai_types -from strands.experimental.bidi.models.model import BidiModelTimeoutError +from strands.experimental.bidi.models import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, @@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop() @@ -572,7 +572,6 @@ def test_tool_formatting(model, tool_spec): assert formatted_empty == [] - # Tool Result Content Tests @@ -601,7 +600,7 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key assert isinstance(audio_event, BidiAudioStreamEvent) # Should use configured rates, not constants assert audio_event.sample_rate == 48000 # Custom config - assert audio_event.channels == 2 # Custom config + assert audio_event.channels == 2 # Custom config assert audio_event.format == "pcm" await model.stop() @@ -631,7 +630,7 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke assert isinstance(audio_event, BidiAudioStreamEvent) # Should use default rates assert audio_event.sample_rate == 24000 # Default output rate - assert audio_event.channels == 1 # Default channels + assert audio_event.channels == 1 # Default channels assert audio_event.format == "pcm" await model.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 5c9c0900d..df6dff0f1 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -14,7 +14,7 @@ import pytest -from strands.experimental.bidi.models.model import BidiModelTimeoutError +from strands.experimental.bidi.models import BidiModelTimeoutError from strands.experimental.bidi.models.openai_realtime import BidiOpenAIRealtimeModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, @@ -131,7 +131,9 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -154,7 +156,9 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -349,7 +353,7 @@ async def async_connect(*args, **kwargs): model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(Exception, match=r"failed stop sequence"): await model4.stop() @@ -510,7 +514,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): format="pcm", sample_rate=24000, channels=1, - ) + ), ] assert tru_events == exp_events