diff --git a/examples/ai-test/src/main.py b/examples/ai-test/src/main.py index 6e6e6039..7b5d67d7 100644 --- a/examples/ai-test/src/main.py +++ b/examples/ai-test/src/main.py @@ -4,6 +4,7 @@ """ import asyncio +import logging import re from os import getenv @@ -30,6 +31,8 @@ load_dotenv(find_dotenv(usecwd=True)) +logging.basicConfig(level=getenv("LOG_LEVEL", "WARNING").upper()) + def get_required_env(key: str) -> str: value = getenv(key) diff --git a/packages/apps/src/microsoft_teams/apps/app_process.py b/packages/apps/src/microsoft_teams/apps/app_process.py index 9b048113..fdea6a84 100644 --- a/packages/apps/src/microsoft_teams/apps/app_process.py +++ b/packages/apps/src/microsoft_teams/apps/app_process.py @@ -27,7 +27,7 @@ from .activity_sender import ActivitySender from .events import ActivityEvent, ActivityResponseEvent, ActivitySentEvent, ErrorEvent -from .plugins import PluginActivityEvent, PluginBase +from .plugins import PluginActivityEvent, PluginBase, StreamCancelledError from .routing.activity_context import ActivityContext from .routing.router import ActivityHandler, ActivityRouter from .token_manager import TokenManager @@ -224,6 +224,10 @@ async def route(ctx: ActivityContext[ActivityBase]) -> Optional[Any]: ), plugins=plugins, ) + except StreamCancelledError: + logger.debug("Activity processing was cancelled (stream stopped)") + await activityCtx.stream.close() + response = InvokeResponse[Any](status=200) except Exception as error: await self.event_manager.on_error(ErrorEvent(error=error, activity=activity), plugins) raise error diff --git a/packages/apps/src/microsoft_teams/apps/http_stream.py b/packages/apps/src/microsoft_teams/apps/http_stream.py index 645025c8..65b2b93c 100644 --- a/packages/apps/src/microsoft_teams/apps/http_stream.py +++ b/packages/apps/src/microsoft_teams/apps/http_stream.py @@ -8,6 +8,7 @@ from collections import deque from typing import Awaitable, Callable, List, Optional, Union +from httpx import HTTPStatusError from microsoft_teams.api import ( ApiClient, Attachment, @@ -20,7 +21,7 @@ ) from microsoft_teams.common import EventEmitter -from .plugins.streamer import StreamerEvent, StreamerProtocol +from .plugins.streamer import StreamCancelledError, StreamerEvent, StreamerProtocol from .utils import RetryOptions, retry logger = logging.getLogger(__name__) @@ -62,6 +63,7 @@ def __init__(self, client: ApiClient, ref: ConversationReference): self._total_wait_timeout: float = 30.0 self._state_changed = asyncio.Event() + self._canceled = False self._reset_state() def _reset_state(self) -> None: @@ -74,6 +76,14 @@ def _reset_state(self) -> None: self._entities: List[Entity] = [] self._queue: deque[Union[MessageActivityInput, TypingActivityInput, str]] = deque() + @property + def canceled(self) -> bool: + """ + Whether the stream has been canceled. + For example when the user pressed the Stop button or the 2-minute timeout has exceeded. + """ + return self._canceled + @property def closed(self) -> bool: """Whether the final stream message has been sent.""" @@ -103,6 +113,9 @@ def emit(self, activity: Union[MessageActivityInput, TypingActivityInput, str]) activity: The activity to emit. """ + if self._canceled: + raise StreamCancelledError("Stream has been cancelled.") + if isinstance(activity, str): activity = MessageActivityInput(text=activity, type="message") self._queue.append(activity) @@ -124,7 +137,7 @@ async def _wait_for_id_and_queue(self): """Wait until _id is set and the queue is empty, with a total timeout.""" async def _poll(): - while self._queue or not self._id: + while (self._queue or not self._id) and not self._canceled: await self._state_changed.wait() self._state_changed.clear() @@ -140,6 +153,10 @@ async def close(self) -> Optional[SentActivity]: logger.debug("stream already closed with result") return self._result + if self._canceled: + logger.debug("stream was cancelled, nothing to close") + return None + if self._index == 1 and not self._queue and not self._lock.locked(): logger.debug("stream has no content to send, returning None") return None @@ -229,13 +246,11 @@ async def _flush(self) -> None: if self._queue and not self._timeout: self._timeout = asyncio.get_running_loop().call_later(0.5, lambda: asyncio.create_task(self._flush())) - # Notify that queue state has changed - self._state_changed.set() - finally: # Reset flushing flag so future emits can trigger another flush self._pending = None self._lock.release() + self._state_changed.set() async def _send_activity(self, to_send: TypingActivityInput): """ @@ -265,12 +280,23 @@ async def _send(self, to_send: Union[TypingActivityInput, MessageActivityInput]) Args: activity: The activity to send. """ + if self._canceled: + logger.warning("Teams channel stopped the stream.") + raise StreamCancelledError("Teams channel stopped the stream.") + to_send.from_ = self._ref.bot to_send.conversation = self._ref.conversation - if to_send.id and not any(e.type == "streaminfo" for e in (to_send.entities or [])): - res = await self._client.conversations.activities(self._ref.conversation.id).update(to_send.id, to_send) - else: - res = await self._client.conversations.activities(self._ref.conversation.id).create(to_send) - - return SentActivity.merge(to_send, res) + try: + if to_send.id and not any(e.type == "streaminfo" for e in (to_send.entities or [])): + res = await self._client.conversations.activities(self._ref.conversation.id).update(to_send.id, to_send) + else: + res = await self._client.conversations.activities(self._ref.conversation.id).create(to_send) + + return SentActivity.merge(to_send, res) + except HTTPStatusError as e: + if e.response.status_code == 403: + self._canceled = True + logger.warning("Teams channel stopped the stream.") + raise StreamCancelledError("Teams channel stopped the stream.") from e + raise diff --git a/packages/apps/src/microsoft_teams/apps/plugins/__init__.py b/packages/apps/src/microsoft_teams/apps/plugins/__init__.py index e78a2de0..b489ab44 100644 --- a/packages/apps/src/microsoft_teams/apps/plugins/__init__.py +++ b/packages/apps/src/microsoft_teams/apps/plugins/__init__.py @@ -10,10 +10,11 @@ from .plugin_base import PluginBase from .plugin_error_event import PluginErrorEvent from .plugin_start_event import PluginStartEvent -from .streamer import StreamerProtocol +from .streamer import StreamCancelledError, StreamerProtocol __all__ = [ "PluginBase", + "StreamCancelledError", "StreamerProtocol", "PluginActivityEvent", "PluginActivityResponseEvent", diff --git a/packages/apps/src/microsoft_teams/apps/plugins/streamer.py b/packages/apps/src/microsoft_teams/apps/plugins/streamer.py index c73b6fa9..6a3e7b7c 100644 --- a/packages/apps/src/microsoft_teams/apps/plugins/streamer.py +++ b/packages/apps/src/microsoft_teams/apps/plugins/streamer.py @@ -3,6 +3,7 @@ Licensed under the MIT License. """ +import asyncio from typing import Awaitable, Callable, Literal, Optional, Protocol, Union from microsoft_teams.api import MessageActivityInput, SentActivity, TypingActivityInput @@ -10,9 +11,23 @@ StreamerEvent = Literal["chunk", "close"] +class StreamCancelledError(asyncio.CancelledError): + """Raised when a stream operation is attempted after the stream has been cancelled.""" + + pass + + class StreamerProtocol(Protocol): """Component that can send streamed chunks of an activity.""" + @property + def canceled(self) -> bool: + """ + Whether the stream has been canceled. + For example when the user pressed the Stop button or the 2-minute timeout has exceeded. + """ + ... + @property def closed(self) -> bool: """Whether the final stream message has been sent.""" diff --git a/packages/apps/tests/test_http_stream.py b/packages/apps/tests/test_http_stream.py index f6c9905a..50cb4d2d 100644 --- a/packages/apps/tests/test_http_stream.py +++ b/packages/apps/tests/test_http_stream.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch import pytest +from httpx import HTTPStatusError, Request, Response from microsoft_teams.api import ( Account, ApiClient, @@ -17,6 +18,7 @@ TypingActivityInput, ) from microsoft_teams.apps import HttpStream +from microsoft_teams.apps.plugins import StreamCancelledError class TestHttpStream: @@ -219,3 +221,110 @@ async def emit_task(): await self._run_scheduled_flushes(scheduled) assert max_concurrent_entries == 1 + + @pytest.mark.asyncio + async def test_stream_canceled_on_403(self, mock_api_client, conversation_reference, patch_loop_call_later): + loop = asyncio.get_running_loop() + patcher, scheduled = patch_loop_call_later(loop) + with patcher: + + async def mock_send_403(activity): + raise HTTPStatusError( + "Forbidden", + request=Request("POST", "https://example.com"), + response=Response(403), + ) + + mock_api_client.conversations.activities().create = mock_send_403 + stream = HttpStream(mock_api_client, conversation_reference) + + stream.emit("Test message") + await asyncio.sleep(0) + await self._run_scheduled_flushes(scheduled) + + assert stream.canceled is True + + @pytest.mark.asyncio + async def test_emit_blocked_after_cancel(self, mock_api_client, conversation_reference, patch_loop_call_later): + loop = asyncio.get_running_loop() + patcher, scheduled = patch_loop_call_later(loop) + with patcher: + + async def mock_send_403(activity): + raise HTTPStatusError( + "Forbidden", + request=Request("POST", "https://example.com"), + response=Response(403), + ) + + mock_api_client.conversations.activities().create = mock_send_403 + stream = HttpStream(mock_api_client, conversation_reference) + + stream.emit("First message") + await asyncio.sleep(0) + await self._run_scheduled_flushes(scheduled) + + assert stream.canceled is True + + # Emit after cancel should raise + with pytest.raises(StreamCancelledError, match="Stream has been cancelled."): + stream.emit("Should be ignored") + + @pytest.mark.asyncio + async def test_send_blocked_after_cancel(self, mock_api_client, conversation_reference): + stream = HttpStream(mock_api_client, conversation_reference) + stream._canceled = True + + with pytest.raises(StreamCancelledError, match="Teams channel stopped the stream."): + await stream._send(TypingActivityInput(text="test")) + + @pytest.mark.asyncio + async def test_stream_canceled_after_successful_message( + self, mock_api_client, conversation_reference, patch_loop_call_later + ): + call_count = 0 + loop = asyncio.get_running_loop() + patcher, scheduled = patch_loop_call_later(loop) + with patcher: + + async def mock_send_then_403(activity): + nonlocal call_count + call_count += 1 + if call_count == 1: + return SentActivity(id="activity-1", activity_params=activity) + raise HTTPStatusError( + "Forbidden", + request=Request("POST", "https://example.com"), + response=Response(403), + ) + + mock_api_client.conversations.activities().create = mock_send_then_403 + stream = HttpStream(mock_api_client, conversation_reference) + + # First emit succeeds + stream.emit("First message") + await asyncio.sleep(0) + await self._run_scheduled_flushes(scheduled) + + assert stream.canceled is False + assert call_count == 1 + + # Second emit triggers 403 + stream.emit("Second message") + await asyncio.sleep(0) + await self._run_scheduled_flushes(scheduled) + + assert stream.canceled is True + assert call_count == 2 + + # Further emits raise + with pytest.raises(StreamCancelledError, match="Stream has been cancelled."): + stream.emit("Should be ignored") + + @pytest.mark.asyncio + async def test_close_returns_none_when_canceled(self, mock_api_client, conversation_reference): + stream = HttpStream(mock_api_client, conversation_reference) + stream._canceled = True + + result = await stream.close() + assert result is None