Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/ai-test/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import asyncio
import logging
import re
from os import getenv

Expand All @@ -30,6 +31,8 @@

load_dotenv(find_dotenv(usecwd=True))

logging.basicConfig(level=getenv("LOG_LEVEL", "WARNING").upper())


def get_required_env(key: str) -> str:
value = getenv(key)
Expand Down
6 changes: 5 additions & 1 deletion packages/apps/src/microsoft_teams/apps/app_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from .activity_sender import ActivitySender
from .events import ActivityEvent, ActivityResponseEvent, ActivitySentEvent, ErrorEvent
from .plugins import PluginActivityEvent, PluginBase
from .plugins import PluginActivityEvent, PluginBase, StreamCancelledError
from .routing.activity_context import ActivityContext
from .routing.router import ActivityHandler, ActivityRouter
from .token_manager import TokenManager
Expand Down Expand Up @@ -224,6 +224,10 @@ async def route(ctx: ActivityContext[ActivityBase]) -> Optional[Any]:
),
plugins=plugins,
)
except StreamCancelledError:
logger.debug("Activity processing was cancelled (stream stopped)")
await activityCtx.stream.close()
response = InvokeResponse[Any](status=200)
except Exception as error:
await self.event_manager.on_error(ErrorEvent(error=error, activity=activity), plugins)
raise error
Expand Down
48 changes: 37 additions & 11 deletions packages/apps/src/microsoft_teams/apps/http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import deque
from typing import Awaitable, Callable, List, Optional, Union

from httpx import HTTPStatusError
from microsoft_teams.api import (
ApiClient,
Attachment,
Expand All @@ -20,7 +21,7 @@
)
from microsoft_teams.common import EventEmitter

from .plugins.streamer import StreamerEvent, StreamerProtocol
from .plugins.streamer import StreamCancelledError, StreamerEvent, StreamerProtocol
from .utils import RetryOptions, retry

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(self, client: ApiClient, ref: ConversationReference):
self._total_wait_timeout: float = 30.0
self._state_changed = asyncio.Event()

self._canceled = False
self._reset_state()

def _reset_state(self) -> None:
Expand All @@ -74,6 +76,14 @@ def _reset_state(self) -> None:
self._entities: List[Entity] = []
self._queue: deque[Union[MessageActivityInput, TypingActivityInput, str]] = deque()

@property
def canceled(self) -> bool:
"""
Whether the stream has been canceled.
For example when the user pressed the Stop button or the 2-minute timeout has exceeded.
"""
return self._canceled

@property
def closed(self) -> bool:
"""Whether the final stream message has been sent."""
Expand Down Expand Up @@ -103,6 +113,9 @@ def emit(self, activity: Union[MessageActivityInput, TypingActivityInput, str])
activity: The activity to emit.
"""

if self._canceled:
raise StreamCancelledError("Stream has been cancelled.")

if isinstance(activity, str):
activity = MessageActivityInput(text=activity, type="message")
self._queue.append(activity)
Expand All @@ -124,7 +137,7 @@ async def _wait_for_id_and_queue(self):
"""Wait until _id is set and the queue is empty, with a total timeout."""

async def _poll():
while self._queue or not self._id:
while (self._queue or not self._id) and not self._canceled:
await self._state_changed.wait()
self._state_changed.clear()

Expand All @@ -140,6 +153,10 @@ async def close(self) -> Optional[SentActivity]:
logger.debug("stream already closed with result")
return self._result

if self._canceled:
logger.debug("stream was cancelled, nothing to close")
return None

if self._index == 1 and not self._queue and not self._lock.locked():
logger.debug("stream has no content to send, returning None")
return None
Expand Down Expand Up @@ -229,13 +246,11 @@ async def _flush(self) -> None:
if self._queue and not self._timeout:
self._timeout = asyncio.get_running_loop().call_later(0.5, lambda: asyncio.create_task(self._flush()))

# Notify that queue state has changed
self._state_changed.set()

finally:
# Reset flushing flag so future emits can trigger another flush
self._pending = None
self._lock.release()
self._state_changed.set()

async def _send_activity(self, to_send: TypingActivityInput):
"""
Expand Down Expand Up @@ -265,12 +280,23 @@ async def _send(self, to_send: Union[TypingActivityInput, MessageActivityInput])
Args:
activity: The activity to send.
"""
if self._canceled:
logger.warning("Teams channel stopped the stream.")
raise StreamCancelledError("Teams channel stopped the stream.")

to_send.from_ = self._ref.bot
to_send.conversation = self._ref.conversation

if to_send.id and not any(e.type == "streaminfo" for e in (to_send.entities or [])):
res = await self._client.conversations.activities(self._ref.conversation.id).update(to_send.id, to_send)
else:
res = await self._client.conversations.activities(self._ref.conversation.id).create(to_send)

return SentActivity.merge(to_send, res)
try:
if to_send.id and not any(e.type == "streaminfo" for e in (to_send.entities or [])):
res = await self._client.conversations.activities(self._ref.conversation.id).update(to_send.id, to_send)
else:
res = await self._client.conversations.activities(self._ref.conversation.id).create(to_send)

return SentActivity.merge(to_send, res)
except HTTPStatusError as e:
if e.response.status_code == 403:
self._canceled = True
logger.warning("Teams channel stopped the stream.")
raise StreamCancelledError("Teams channel stopped the stream.") from e
raise
3 changes: 2 additions & 1 deletion packages/apps/src/microsoft_teams/apps/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from .plugin_base import PluginBase
from .plugin_error_event import PluginErrorEvent
from .plugin_start_event import PluginStartEvent
from .streamer import StreamerProtocol
from .streamer import StreamCancelledError, StreamerProtocol

__all__ = [
"PluginBase",
"StreamCancelledError",
"StreamerProtocol",
"PluginActivityEvent",
"PluginActivityResponseEvent",
Expand Down
15 changes: 15 additions & 0 deletions packages/apps/src/microsoft_teams/apps/plugins/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,31 @@
Licensed under the MIT License.
"""

import asyncio
from typing import Awaitable, Callable, Literal, Optional, Protocol, Union

from microsoft_teams.api import MessageActivityInput, SentActivity, TypingActivityInput

StreamerEvent = Literal["chunk", "close"]


class StreamCancelledError(asyncio.CancelledError):
"""Raised when a stream operation is attempted after the stream has been cancelled."""

pass


class StreamerProtocol(Protocol):
"""Component that can send streamed chunks of an activity."""

@property
def canceled(self) -> bool:
"""
Whether the stream has been canceled.
For example when the user pressed the Stop button or the 2-minute timeout has exceeded.
"""
...

@property
def closed(self) -> bool:
"""Whether the final stream message has been sent."""
Expand Down
109 changes: 109 additions & 0 deletions packages/apps/tests/test_http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from unittest.mock import MagicMock, patch

import pytest
from httpx import HTTPStatusError, Request, Response
from microsoft_teams.api import (
Account,
ApiClient,
Expand All @@ -17,6 +18,7 @@
TypingActivityInput,
)
from microsoft_teams.apps import HttpStream
from microsoft_teams.apps.plugins import StreamCancelledError


class TestHttpStream:
Expand Down Expand Up @@ -219,3 +221,110 @@ async def emit_task():
await self._run_scheduled_flushes(scheduled)

assert max_concurrent_entries == 1

@pytest.mark.asyncio
async def test_stream_canceled_on_403(self, mock_api_client, conversation_reference, patch_loop_call_later):
loop = asyncio.get_running_loop()
patcher, scheduled = patch_loop_call_later(loop)
with patcher:

async def mock_send_403(activity):
raise HTTPStatusError(
"Forbidden",
request=Request("POST", "https://example.com"),
response=Response(403),
)

mock_api_client.conversations.activities().create = mock_send_403
stream = HttpStream(mock_api_client, conversation_reference)

stream.emit("Test message")
await asyncio.sleep(0)
await self._run_scheduled_flushes(scheduled)

assert stream.canceled is True

@pytest.mark.asyncio
async def test_emit_blocked_after_cancel(self, mock_api_client, conversation_reference, patch_loop_call_later):
loop = asyncio.get_running_loop()
patcher, scheduled = patch_loop_call_later(loop)
with patcher:

async def mock_send_403(activity):
raise HTTPStatusError(
"Forbidden",
request=Request("POST", "https://example.com"),
response=Response(403),
)

mock_api_client.conversations.activities().create = mock_send_403
stream = HttpStream(mock_api_client, conversation_reference)

stream.emit("First message")
await asyncio.sleep(0)
await self._run_scheduled_flushes(scheduled)

assert stream.canceled is True

# Emit after cancel should raise
with pytest.raises(StreamCancelledError, match="Stream has been cancelled."):
stream.emit("Should be ignored")

@pytest.mark.asyncio
async def test_send_blocked_after_cancel(self, mock_api_client, conversation_reference):
stream = HttpStream(mock_api_client, conversation_reference)
stream._canceled = True

with pytest.raises(StreamCancelledError, match="Teams channel stopped the stream."):
await stream._send(TypingActivityInput(text="test"))

@pytest.mark.asyncio
async def test_stream_canceled_after_successful_message(
self, mock_api_client, conversation_reference, patch_loop_call_later
):
call_count = 0
loop = asyncio.get_running_loop()
patcher, scheduled = patch_loop_call_later(loop)
with patcher:

async def mock_send_then_403(activity):
nonlocal call_count
call_count += 1
if call_count == 1:
return SentActivity(id="activity-1", activity_params=activity)
raise HTTPStatusError(
"Forbidden",
request=Request("POST", "https://example.com"),
response=Response(403),
)

mock_api_client.conversations.activities().create = mock_send_then_403
stream = HttpStream(mock_api_client, conversation_reference)

# First emit succeeds
stream.emit("First message")
await asyncio.sleep(0)
await self._run_scheduled_flushes(scheduled)

assert stream.canceled is False
assert call_count == 1

# Second emit triggers 403
stream.emit("Second message")
await asyncio.sleep(0)
await self._run_scheduled_flushes(scheduled)

assert stream.canceled is True
assert call_count == 2

# Further emits raise
with pytest.raises(StreamCancelledError, match="Stream has been cancelled."):
stream.emit("Should be ignored")

@pytest.mark.asyncio
async def test_close_returns_none_when_canceled(self, mock_api_client, conversation_reference):
stream = HttpStream(mock_api_client, conversation_reference)
stream._canceled = True

result = await stream.close()
assert result is None
Loading