|
2 | 2 |
|
3 | 3 | from collections.abc import Awaitable, Callable |
4 | 4 | from typing import Any |
| 5 | +from unittest.mock import patch |
5 | 6 |
|
6 | 7 | from agent_framework import ( |
7 | 8 | Agent, |
@@ -296,6 +297,35 @@ async def counting_middleware(context: ChatContext, call_next: Callable[[], Awai |
296 | 297 | assert response3 is not None |
297 | 298 | assert execution_count["count"] == 2 # Should be 2 now |
298 | 299 |
|
| 300 | + async def test_run_level_middleware_is_not_forwarded_to_inner_client( |
| 301 | + self, chat_client_base: "MockBaseChatClient" |
| 302 | + ) -> None: |
| 303 | + """Test that run-level middleware stays in the middleware pipeline only.""" |
| 304 | + observed_context_kwargs: dict[str, Any] = {} |
| 305 | + |
| 306 | + @chat_middleware |
| 307 | + async def inspecting_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None: |
| 308 | + observed_context_kwargs.update(context.kwargs) |
| 309 | + await call_next() |
| 310 | + |
| 311 | + async def fake_inner_get_response(**kwargs: Any) -> ChatResponse: |
| 312 | + assert "middleware" not in kwargs |
| 313 | + return ChatResponse(messages=[Message(role="assistant", text="ok")]) |
| 314 | + |
| 315 | + with patch.object( |
| 316 | + chat_client_base, |
| 317 | + "_inner_get_response", |
| 318 | + side_effect=fake_inner_get_response, |
| 319 | + ) as mock_inner_get_response: |
| 320 | + response = await chat_client_base.get_response( |
| 321 | + [Message(role="user", text="hello")], |
| 322 | + client_kwargs={"middleware": [inspecting_middleware], "trace_id": "trace-123"}, |
| 323 | + ) |
| 324 | + |
| 325 | + assert response.messages[0].text == "ok" |
| 326 | + assert observed_context_kwargs == {"trace_id": "trace-123"} |
| 327 | + mock_inner_get_response.assert_called_once() |
| 328 | + |
299 | 329 | async def test_chat_client_middleware_can_access_and_override_options( |
300 | 330 | self, chat_client_base: "MockBaseChatClient" |
301 | 331 | ) -> None: |
|
0 commit comments