diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index d440cf2f..6cb21662 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -1,4 +1,5 @@ import asyncio +import logging import time from unittest.mock import ( @@ -50,6 +51,9 @@ TextPart, UnsupportedOperationError, ) +from a2a.utils import ( + new_task, +) class DummyAgentExecutor(AgentExecutor): @@ -579,6 +583,79 @@ async def test_on_message_send_task_id_mismatch(): assert 'Task ID mismatch' in exc_info.value.error.message # type: ignore +class HelloAgentExecutor(AgentExecutor): + async def execute(self, context: RequestContext, event_queue: EventQueue): + task = context.current_task + if not task: + assert context.message is not None, ( + 'A message is required to create a new task' + ) + task = new_task(context.message) # type: ignore + await event_queue.enqueue_event(task) + updater = TaskUpdater(event_queue, task.id, task.context_id) + + try: + parts = [Part(root=TextPart(text='I am working'))] + await updater.update_status( + TaskState.working, + message=updater.new_agent_message(parts), + ) + except Exception as e: + # Stop processing when the event loop is closed + logging.warning('Error: %s', e) + return + await updater.add_artifact( + [Part(root=TextPart(text='Hello world!'))], + name='conversion_result', + ) + await updater.complete() + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +@pytest.mark.asyncio +async def test_on_message_send_non_blocking(): + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() + + request_handler = DefaultRequestHandler( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + ) + params = MessageSendParams( + message=Message( + role=Role.user, + message_id='msg_push', + parts=[Part(root=TextPart(text='Hi'))], + ), + configuration=MessageSendConfiguration( + blocking=False, accepted_output_modes=['text/plain'] + ), + ) + + result = await request_handler.on_message_send( + params, create_server_call_context() + ) + + assert result is not None + assert isinstance(result, Task) + assert result.status.state == TaskState.submitted + + # Polling for 500ms until task is completed. + task: Task | None = None + for _ in range(5): + await asyncio.sleep(0.1) + task = await task_store.get(result.id) + assert task is not None + if task.status.state == TaskState.completed: + break + + assert task is not None + assert task.status.state == TaskState.completed + + @pytest.mark.asyncio async def test_on_message_send_interrupted_flow(): """Test on_message_send when flow is interrupted (e.g., auth_required)."""