Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
import time

from unittest.mock import (
Expand Down Expand Up @@ -50,6 +51,9 @@
TextPart,
UnsupportedOperationError,
)
from a2a.utils import (
new_task,
)


class DummyAgentExecutor(AgentExecutor):
Expand Down Expand Up @@ -579,6 +583,79 @@ async def test_on_message_send_task_id_mismatch():
assert 'Task ID mismatch' in exc_info.value.error.message # type: ignore


class HelloAgentExecutor(AgentExecutor):
async def execute(self, context: RequestContext, event_queue: EventQueue):
task = context.current_task
if not task:
assert context.message is not None, (
'A message is required to create a new task'
)
task = new_task(context.message) # type: ignore
await event_queue.enqueue_event(task)
updater = TaskUpdater(event_queue, task.id, task.context_id)

try:
parts = [Part(root=TextPart(text='I am working'))]
await updater.update_status(
TaskState.working,
message=updater.new_agent_message(parts),
)
except Exception as e:
# Stop processing when the event loop is closed
logging.warning('Error: %s', e)
return
await updater.add_artifact(
[Part(root=TextPart(text='Hello world!'))],
name='conversion_result',
)
await updater.complete()

async def cancel(self, context: RequestContext, event_queue: EventQueue):
pass


@pytest.mark.asyncio
async def test_on_message_send_non_blocking():
task_store = InMemoryTaskStore()
push_store = InMemoryPushNotificationConfigStore()

request_handler = DefaultRequestHandler(
agent_executor=HelloAgentExecutor(),
task_store=task_store,
push_config_store=push_store,
)
params = MessageSendParams(
message=Message(
role=Role.user,
message_id='msg_push',
parts=[Part(root=TextPart(text='Hi'))],
),
configuration=MessageSendConfiguration(
blocking=False, accepted_output_modes=['text/plain']
),
)

result = await request_handler.on_message_send(
params, create_server_call_context()
)

assert result is not None
assert isinstance(result, Task)
assert result.status.state == TaskState.submitted

# Polling for 500ms until task is completed.
task: Task | None = None
for _ in range(5):
await asyncio.sleep(0.1)
task = await task_store.get(result.id)
assert task is not None
if task.status.state == TaskState.completed:
break

assert task is not None
assert task.status.state == TaskState.completed


@pytest.mark.asyncio
async def test_on_message_send_interrupted_flow():
"""Test on_message_send when flow is interrupted (e.g., auth_required)."""
Expand Down
Loading