Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import json
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -213,9 +214,10 @@ async def run_async(
],
)
else:
request_text = args.get('request') or json.dumps(args, ensure_ascii=False)
content = types.Content(
role='user',
parts=[types.Part.from_text(text=args['request'])],
parts=[types.Part.from_text(text=request_text)],
)
invocation_context = tool_context._invocation_context
parent_app_name = (
Expand Down
164 changes: 164 additions & 0 deletions tests/unittests/tools/test_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,3 +1269,167 @@ def test_empty_sequential_agent_falls_back_to_request(self):
}
else:
assert declaration.parameters.properties['request'].type == 'STRING'


@mark.asyncio
async def test_no_schema_non_request_args_serialized_as_json(monkeypatch):
"""AgentTool.run_async with no input_schema serializes non-'request' args as JSON.

Regression test for KeyError: 'request' — when the orchestrating LLM passes
kwargs other than 'request', the fallback path must not crash.
"""
captured = {}

async def _empty_async_generator():
if False:
yield None

class StubRunner:

def __init__(
self,
*,
app_name: str,
agent,
artifact_service,
session_service,
memory_service,
credential_service,
plugins,
):
del artifact_service, memory_service, credential_service
self.agent = agent
self.session_service = session_service
self.plugin_manager = PluginManager(plugins=plugins)
self.app_name = app_name

def run_async(
self,
*,
user_id: str,
session_id: str,
invocation_id=None,
new_message=None,
state_delta=None,
run_config=None,
):
captured['new_message'] = new_message
return _empty_async_generator()

async def close(self):
pass

monkeypatch.setattr('google.adk.runners.Runner', StubRunner)

tool_agent = Agent(name='tool_agent', model='test-model')
agent_tool = AgentTool(agent=tool_agent)
root_agent = Agent(name='root_agent', model='test-model', tools=[agent_tool])

session_service = InMemorySessionService()
session = await session_service.create_session(
app_name='test_app', user_id='user'
)
invocation_context = InvocationContext(
artifact_service=InMemoryArtifactService(),
session_service=session_service,
memory_service=InMemoryMemoryService(),
plugin_manager=PluginManager(),
invocation_id='test-invocation',
agent=root_agent,
session=session,
run_config=RunConfig(),
)
tool_context = ToolContext(invocation_context)

# LLM passed structured kwargs instead of the 'request' key — must not crash
await agent_tool.run_async(
args={'brand': 'Nike', 'product': 'running shoes'},
tool_context=tool_context,
)

import json

assert captured['new_message'] is not None
text = captured['new_message'].parts[0].text
parsed = json.loads(text)
assert parsed == {'brand': 'Nike', 'product': 'running shoes'}


@mark.asyncio
async def test_no_schema_request_key_backward_compat(monkeypatch):
"""AgentTool.run_async with no input_schema keeps 'request' value as plain text.

Ensures the fix for non-'request' args does not break the existing contract
when the LLM correctly passes args={'request': '...'}.
"""
captured = {}

async def _empty_async_generator():
if False:
yield None

class StubRunner:

def __init__(
self,
*,
app_name: str,
agent,
artifact_service,
session_service,
memory_service,
credential_service,
plugins,
):
del artifact_service, memory_service, credential_service
self.agent = agent
self.session_service = session_service
self.plugin_manager = PluginManager(plugins=plugins)
self.app_name = app_name

def run_async(
self,
*,
user_id: str,
session_id: str,
invocation_id=None,
new_message=None,
state_delta=None,
run_config=None,
):
captured['new_message'] = new_message
return _empty_async_generator()

async def close(self):
pass

monkeypatch.setattr('google.adk.runners.Runner', StubRunner)

tool_agent = Agent(name='tool_agent', model='test-model')
agent_tool = AgentTool(agent=tool_agent)
root_agent = Agent(name='root_agent', model='test-model', tools=[agent_tool])

session_service = InMemorySessionService()
session = await session_service.create_session(
app_name='test_app', user_id='user'
)
invocation_context = InvocationContext(
artifact_service=InMemoryArtifactService(),
session_service=session_service,
memory_service=InMemoryMemoryService(),
plugin_manager=PluginManager(),
invocation_id='test-invocation',
agent=root_agent,
session=session,
run_config=RunConfig(),
)
tool_context = ToolContext(invocation_context)

await agent_tool.run_async(
args={'request': 'find me Nike running shoes'},
tool_context=tool_context,
)

assert captured['new_message'] is not None
text = captured['new_message'].parts[0].text
assert text == 'find me Nike running shoes'