Skip to content

Commit 9752caa

Browse files
Address PR review feedback
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent ba3399b commit 9752caa

6 files changed

Lines changed: 81 additions & 5 deletions

File tree

python/packages/core/agent_framework/_mcp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
1414
from datetime import timedelta
1515
from functools import partial
16-
from typing import TYPE_CHECKING, Any, Literal, TypedDict
16+
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast
1717

1818
import httpx
1919
from anyio import ClosedResourceError
@@ -733,9 +733,9 @@ async def sampling_callback(
733733
for msg in params.messages:
734734
messages.append(_parse_message_from_mcp(msg))
735735
try:
736-
chat_client: Any = self.client
737-
response: Any = await chat_client.get_response(
738-
messages,
736+
chat_client = cast("SupportsChatGetResponse[Any]", self.client)
737+
response = await chat_client.get_response(
738+
messages=messages,
739739
options={
740740
"temperature": params.temperature,
741741
"max_tokens": params.maxTokens,

python/packages/core/agent_framework/_middleware.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1072,12 +1072,12 @@ def get_response(
10721072
"""Execute the chat pipeline if middleware is configured."""
10731073
super_get_response = super().get_response # type: ignore[misc]
10741074
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
1075+
call_middleware = effective_client_kwargs.pop("middleware", [])
10751076
context_kwargs = dict(effective_client_kwargs)
10761077
if compaction_strategy is not None:
10771078
context_kwargs["compaction_strategy"] = compaction_strategy
10781079
if tokenizer is not None:
10791080
context_kwargs["tokenizer"] = tokenizer
1080-
call_middleware = effective_client_kwargs.pop("middleware", [])
10811081
pipeline = self._get_chat_middleware_pipeline(call_middleware) # type: ignore[reportUnknownArgumentType]
10821082
if not pipeline.has_middlewares:
10831083
return super_get_response( # type: ignore[no-any-return]

python/packages/core/agent_framework/_tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2159,6 +2159,7 @@ def get_response(
21592159
options=mutable_options,
21602160
compaction_strategy=compaction_strategy,
21612161
tokenizer=tokenizer,
2162+
function_invocation_kwargs=function_invocation_kwargs,
21622163
client_kwargs=filtered_kwargs,
21632164
)
21642165
if not stream:

python/packages/core/tests/core/test_function_invocation_logic.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Content,
1414
Message,
1515
SupportsChatGetResponse,
16+
chat_middleware,
1617
tool,
1718
)
1819
from agent_framework._compaction import (
@@ -1429,6 +1430,36 @@ def ai_func(arg1: str) -> str:
14291430
assert len(response.messages) > 0
14301431

14311432

1433+
async def test_function_invocation_config_enabled_false_preserves_invocation_kwargs(
1434+
chat_client_base: SupportsChatGetResponse,
1435+
):
1436+
"""Test disabled function invocation still forwards invocation kwargs downstream."""
1437+
captured_kwargs: dict[str, Any] = {}
1438+
1439+
@tool(name="test_function")
1440+
def ai_func(arg1: str) -> str:
1441+
return f"Processed {arg1}"
1442+
1443+
@chat_middleware
1444+
async def capture_middleware(context, call_next):
1445+
captured_kwargs.update(context.function_invocation_kwargs or {})
1446+
await call_next()
1447+
1448+
chat_client_base.chat_middleware = [capture_middleware]
1449+
chat_client_base.run_responses = [
1450+
ChatResponse(messages=Message(role="assistant", text="response without function calling")),
1451+
]
1452+
chat_client_base.function_invocation_configuration["enabled"] = False
1453+
1454+
await chat_client_base.get_response(
1455+
[Message(role="user", text="hello")],
1456+
options={"tool_choice": "auto", "tools": [ai_func]},
1457+
function_invocation_kwargs={"tool_request_id": "tool-123"},
1458+
)
1459+
1460+
assert captured_kwargs == {"tool_request_id": "tool-123"}
1461+
1462+
14321463
@pytest.mark.skip(reason="Error handling and failsafe behavior needs investigation in unified API")
14331464
async def test_function_invocation_config_max_consecutive_errors(chat_client_base: SupportsChatGetResponse):
14341465
"""Test that max_consecutive_errors_per_request limits error retries."""

python/packages/core/tests/core/test_mcp.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,13 @@ async def test_mcp_tool_sampling_callback_chat_client_exception():
15731573
assert isinstance(result, types.ErrorData)
15741574
assert result.code == types.INTERNAL_ERROR
15751575
assert "Failed to get chat message content: Chat client error" in result.message
1576+
mock_chat_client.get_response.assert_awaited_once()
1577+
_, kwargs = mock_chat_client.get_response.await_args
1578+
assert kwargs["options"] == {
1579+
"temperature": None,
1580+
"max_tokens": None,
1581+
"stop": None,
1582+
}
15761583

15771584

15781585
async def test_mcp_tool_sampling_callback_no_valid_content():
@@ -1616,6 +1623,13 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
16161623
assert isinstance(result, types.ErrorData)
16171624
assert result.code == types.INTERNAL_ERROR
16181625
assert "Failed to get right content types from the response." in result.message
1626+
mock_chat_client.get_response.assert_awaited_once()
1627+
_, kwargs = mock_chat_client.get_response.await_args
1628+
assert kwargs["options"] == {
1629+
"temperature": None,
1630+
"max_tokens": None,
1631+
"stop": None,
1632+
}
16191633

16201634

16211635
# Test error handling in connect() method

python/packages/core/tests/core/test_middleware_with_chat.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections.abc import Awaitable, Callable
44
from typing import Any
5+
from unittest.mock import patch
56

67
from agent_framework import (
78
Agent,
@@ -296,6 +297,35 @@ async def counting_middleware(context: ChatContext, call_next: Callable[[], Awai
296297
assert response3 is not None
297298
assert execution_count["count"] == 2 # Should be 2 now
298299

300+
async def test_run_level_middleware_is_not_forwarded_to_inner_client(
301+
self, chat_client_base: "MockBaseChatClient"
302+
) -> None:
303+
"""Test that run-level middleware stays in the middleware pipeline only."""
304+
observed_context_kwargs: dict[str, Any] = {}
305+
306+
@chat_middleware
307+
async def inspecting_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
308+
observed_context_kwargs.update(context.kwargs)
309+
await call_next()
310+
311+
async def fake_inner_get_response(**kwargs: Any) -> ChatResponse:
312+
assert "middleware" not in kwargs
313+
return ChatResponse(messages=[Message(role="assistant", text="ok")])
314+
315+
with patch.object(
316+
chat_client_base,
317+
"_inner_get_response",
318+
side_effect=fake_inner_get_response,
319+
) as mock_inner_get_response:
320+
response = await chat_client_base.get_response(
321+
[Message(role="user", text="hello")],
322+
client_kwargs={"middleware": [inspecting_middleware], "trace_id": "trace-123"},
323+
)
324+
325+
assert response.messages[0].text == "ok"
326+
assert observed_context_kwargs == {"trace_id": "trace-123"}
327+
mock_inner_get_response.assert_called_once()
328+
299329
async def test_chat_client_middleware_can_access_and_override_options(
300330
self, chat_client_base: "MockBaseChatClient"
301331
) -> None:

0 commit comments

Comments
 (0)