From 2be38bdac87ba2b6addb2fd97139d2b95f4cd388 Mon Sep 17 00:00:00 2001 From: Swapnil agarwal Date: Sat, 26 Jul 2025 18:49:55 -0700 Subject: [PATCH 1/5] chore: add test for non-blocking sendMessage --- .../test_default_request_handler.py | 74 ++++++++++++++++++- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index d440cf2f..6a868067 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -50,7 +50,9 @@ TextPart, UnsupportedOperationError, ) - +from a2a.utils import ( + new_task, +) class DummyAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): @@ -75,7 +77,6 @@ async def _run(self): async def cancel(self, context: RequestContext, event_queue: EventQueue): pass - # Helper to create a simple task for tests def create_sample_task( task_id='task1', status_state=TaskState.submitted, context_id='ctx1' @@ -579,6 +580,75 @@ async def test_on_message_send_task_id_mismatch(): assert 'Task ID mismatch' in exc_info.value.error.message # type: ignore +class HelloAgentExecutor(AgentExecutor): + async def execute(self, context: RequestContext, event_queue: EventQueue): + task = context.current_task + if not task: + task = new_task(context.message) # type: ignore + await event_queue.enqueue_event(task) + updater = TaskUpdater(event_queue, task.id, task.context_id) + + try: + parts = [Part(root=TextPart(text=f'I am working'))] + await updater.update_status( + TaskState.working, + message=updater.new_agent_message(parts), + ) + except RuntimeError as e: + # Stop processing when the event loop is closed + print("Runtim error", e) + return + await updater.add_artifact( + [Part(root=TextPart(text="Hello world!"))], + name='conversion_result', + ) + await updater.complete() + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + +@pytest.mark.asyncio +async def test_on_message_send_non_blocking(): + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() + + request_handler = DefaultRequestHandler( + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, + ) + params = MessageSendParams( + message=Message( + role=Role.user, + message_id='msg_push', + parts=[Part(root=TextPart(text=f'Hi'))] + ), + configuration=MessageSendConfiguration( + blocking=False, + accepted_output_modes=['text/plain'] + ) + ) + + result = await request_handler.on_message_send( + params, create_server_call_context() + ) + + assert result is not None + assert type(result) == Task + result.status.state = TaskState.submitted + + # Polling for 500ms until task is completed. + task: Task | None = None + for _ in range(5): + await asyncio.sleep(0.1) + task = await task_store.get(result.id) + assert task is not None + if task.status.state == TaskState.completed: + break + + assert task is not None + assert task.status.state == TaskState.completed + @pytest.mark.asyncio async def test_on_message_send_interrupted_flow(): """Test on_message_send when flow is interrupted (e.g., auth_required).""" From 3c7eca588e6cbb6a01b395d8290b0a530e74f953 Mon Sep 17 00:00:00 2001 From: Swapnil agarwal Date: Sat, 26 Jul 2025 18:58:20 -0700 Subject: [PATCH 2/5] address comments --- .../server/request_handlers/test_default_request_handler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 6a868067..8769dd6e 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -584,7 +584,8 @@ class HelloAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): task = context.current_task if not task: - task = new_task(context.message) # type: ignore + assert context.message is not None, "A message is required to create a new task" + task = new_task(context.message) # type: ignore await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) @@ -635,7 +636,7 @@ async def test_on_message_send_non_blocking(): assert result is not None assert type(result) == Task - result.status.state = TaskState.submitted + assert result.status.state == TaskState.submitted # Polling for 500ms until task is completed. task: Task | None = None From aa34a65fcd499957c2724cc442959267e5ba24da Mon Sep 17 00:00:00 2001 From: Swapnil agarwal Date: Sat, 26 Jul 2025 19:04:20 -0700 Subject: [PATCH 3/5] address more comments --- .../request_handlers/test_default_request_handler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 8769dd6e..2fae6868 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -1,5 +1,6 @@ import asyncio import time +import logging from unittest.mock import ( AsyncMock, @@ -595,9 +596,9 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): TaskState.working, message=updater.new_agent_message(parts), ) - except RuntimeError as e: + except Exception as e: # Stop processing when the event loop is closed - print("Runtim error", e) + logging.warning("Error: %s", e) return await updater.add_artifact( [Part(root=TextPart(text="Hello world!"))], @@ -635,7 +636,7 @@ async def test_on_message_send_non_blocking(): ) assert result is not None - assert type(result) == Task + assert isinstance(result, Task) assert result.status.state == TaskState.submitted # Polling for 500ms until task is completed. From 82070816dd209e8373eddf4e90af97982681ba55 Mon Sep 17 00:00:00 2001 From: Swapnil agarwal Date: Sat, 26 Jul 2025 19:06:44 -0700 Subject: [PATCH 4/5] lint code --- .../test_default_request_handler.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 2fae6868..f080d4a1 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -55,6 +55,7 @@ new_task, ) + class DummyAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): task_updater = TaskUpdater( @@ -78,6 +79,7 @@ async def _run(self): async def cancel(self, context: RequestContext, event_queue: EventQueue): pass + # Helper to create a simple task for tests def create_sample_task( task_id='task1', status_state=TaskState.submitted, context_id='ctx1' @@ -585,8 +587,10 @@ class HelloAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): task = context.current_task if not task: - assert context.message is not None, "A message is required to create a new task" - task = new_task(context.message) # type: ignore + assert context.message is not None, ( + 'A message is required to create a new task' + ) + task = new_task(context.message) # type: ignore await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) @@ -598,10 +602,10 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): ) except Exception as e: # Stop processing when the event loop is closed - logging.warning("Error: %s", e) + logging.warning('Error: %s', e) return await updater.add_artifact( - [Part(root=TextPart(text="Hello world!"))], + [Part(root=TextPart(text='Hello world!'))], name='conversion_result', ) await updater.complete() @@ -609,6 +613,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): async def cancel(self, context: RequestContext, event_queue: EventQueue): pass + @pytest.mark.asyncio async def test_on_message_send_non_blocking(): task_store = InMemoryTaskStore() @@ -623,12 +628,11 @@ async def test_on_message_send_non_blocking(): message=Message( role=Role.user, message_id='msg_push', - parts=[Part(root=TextPart(text=f'Hi'))] + parts=[Part(root=TextPart(text=f'Hi'))], ), configuration=MessageSendConfiguration( - blocking=False, - accepted_output_modes=['text/plain'] - ) + blocking=False, accepted_output_modes=['text/plain'] + ), ) result = await request_handler.on_message_send( @@ -651,6 +655,7 @@ async def test_on_message_send_non_blocking(): assert task is not None assert task.status.state == TaskState.completed + @pytest.mark.asyncio async def test_on_message_send_interrupted_flow(): """Test on_message_send when flow is interrupted (e.g., auth_required).""" From e52a2acc53f9ce7e928197b263f004edae08ea8b Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Tue, 29 Jul 2025 18:18:25 +0100 Subject: [PATCH 5/5] Lint fixes --- .../server/request_handlers/test_default_request_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index f080d4a1..6cb21662 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -1,6 +1,6 @@ import asyncio -import time import logging +import time from unittest.mock import ( AsyncMock, @@ -595,7 +595,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): updater = TaskUpdater(event_queue, task.id, task.context_id) try: - parts = [Part(root=TextPart(text=f'I am working'))] + parts = [Part(root=TextPart(text='I am working'))] await updater.update_status( TaskState.working, message=updater.new_agent_message(parts), @@ -628,7 +628,7 @@ async def test_on_message_send_non_blocking(): message=Message( role=Role.user, message_id='msg_push', - parts=[Part(root=TextPart(text=f'Hi'))], + parts=[Part(root=TextPart(text='Hi'))], ), configuration=MessageSendConfiguration( blocking=False, accepted_output_modes=['text/plain']