diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 4b6d9cc19b..dbae17619a 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -365,6 +365,7 @@ async def _map_a2a_stream( ) all_updates: list[AgentResponseUpdate] = [] + streamed_artifact_ids_by_task: dict[str, set[str]] = {} async for item in a2a_stream: if isinstance(item, A2AMessage): # Process A2A Message @@ -378,12 +379,21 @@ async def _map_a2a_stream( all_updates.append(update) yield update elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task): - task, _update_event = item - for update in self._updates_from_task( + task, update_event = item + updates = self._updates_from_task( task, + update_event=update_event, background=background, emit_intermediate=emit_intermediate, + streamed_artifact_ids=streamed_artifact_ids_by_task.get(task.id), + ) + if isinstance(update_event, TaskArtifactUpdateEvent) and any( + update.raw_representation is update_event for update in updates ): + streamed_artifact_ids_by_task.setdefault(task.id, set()).add(update_event.artifact.artifact_id) + if task.status.state in TERMINAL_TASK_STATES: + streamed_artifact_ids_by_task.pop(task.id, None) + for update in updates: all_updates.append(update) yield update else: @@ -403,8 +413,10 @@ def _updates_from_task( self, task: Task, *, + update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None, background: bool = False, emit_intermediate: bool = False, + streamed_artifact_ids: set[str] | None = None, ) -> list[AgentResponseUpdate]: """Convert an A2A Task into AgentResponseUpdate(s). @@ -418,8 +430,21 @@ def _updates_from_task( """ status = task.status + if ( + emit_intermediate + and update_event is not None + and (event_updates := self._updates_from_task_update_event(update_event)) + ): + return event_updates + if status.state in TERMINAL_TASK_STATES: task_messages = self._parse_messages_from_task(task) + if task.artifacts is not None and streamed_artifact_ids: + task_messages = [ + message + for message in task_messages + if getattr(message.raw_representation, "artifact_id", None) not in streamed_artifact_ids + ] if task_messages: return [ AgentResponseUpdate( @@ -431,6 +456,8 @@ def _updates_from_task( ) for message in task_messages ] + if task.artifacts is not None: + return [] return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)] if background and status.state in IN_PROGRESS_TASK_STATES: @@ -467,6 +494,44 @@ def _updates_from_task( return [] + def _updates_from_task_update_event( + self, update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ) -> list[AgentResponseUpdate]: + """Convert A2A task update events into streaming AgentResponseUpdates.""" + if isinstance(update_event, TaskArtifactUpdateEvent): + contents = self._parse_contents_from_a2a(update_event.artifact.parts) + if not contents: + return [] + return [ + AgentResponseUpdate( + contents=contents, + role="assistant", + response_id=update_event.task_id, + message_id=update_event.artifact.artifact_id, + raw_representation=update_event, + ) + ] + + if not isinstance(update_event, TaskStatusUpdateEvent): + return [] + + message = update_event.status.message + if message is None or not message.parts: + return [] + + contents = self._parse_contents_from_a2a(message.parts) + if not contents: + return [] + + return [ + AgentResponseUpdate( + contents=contents, + role="assistant" if message.role == A2ARole.agent else "user", + response_id=update_event.task_id, + raw_representation=update_event, + ) + ] + @staticmethod def _build_continuation_token(task: Task) -> A2AContinuationToken | None: """Build an A2AContinuationToken from an A2A Task if it is still in progress.""" diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 0d81179cd1..72cbd73586 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -14,8 +14,10 @@ FileWithUri, Part, Task, + TaskArtifactUpdateEvent, TaskState, TaskStatus, + TaskStatusUpdateEvent, TextPart, ) from a2a.types import Message as A2AMessage @@ -1189,4 +1191,201 @@ async def test_streaming_working_update_with_empty_parts_is_skipped( assert updates[0].contents[0].text == "Result" +async def test_streaming_artifact_update_event_yields_content( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that streaming artifact update events yield incremental content.""" + task = Task(id="task-art", context_id="ctx-art", status=TaskStatus(state=TaskState.working, message=None)) + artifact = Artifact( + artifact_id="artifact-1", + parts=[Part(root=TextPart(text="Hello"))], + ) + update_event = TaskArtifactUpdateEvent(task_id="task-art", context_id="ctx-art", artifact=artifact, append=False) + mock_a2a_client.responses.append((task, update_event)) + + updates: list[AgentResponseUpdate] = [] + async for update in a2a_agent.run("Hello", stream=True): + updates.append(update) + + assert len(updates) == 1 + assert updates[0].text == "Hello" + assert updates[0].message_id == "artifact-1" + assert updates[0].raw_representation == update_event + + +async def test_streaming_status_update_event_yields_content( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that streaming status update events surface message content directly from the update event.""" + update_event = TaskStatusUpdateEvent( + task_id="task-status", + context_id="ctx-status", + status=TaskStatus( + state=TaskState.working, + message=A2AMessage( + message_id=str(uuid4()), + role=A2ARole.agent, + parts=[Part(root=TextPart(text="Still working"))], + ), + ), + final=False, + ) + task = Task(id="task-status", context_id="ctx-status", status=TaskStatus(state=TaskState.working, message=None)) + mock_a2a_client.responses.append((task, update_event)) + + updates: list[AgentResponseUpdate] = [] + async for update in a2a_agent.run("Hello", stream=True): + updates.append(update) + + assert len(updates) == 1 + assert updates[0].text == "Still working" + assert updates[0].role == "assistant" + assert updates[0].raw_representation == update_event + + +async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_artifacts( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that streamed artifact chunks are not re-emitted from the final terminal task.""" + working_task = Task(id="task-art-dup", context_id="ctx-art-dup", status=TaskStatus(state=TaskState.working)) + first_chunk = TaskArtifactUpdateEvent( + task_id="task-art-dup", + context_id="ctx-art-dup", + artifact=Artifact( + artifact_id="artifact-dup", + parts=[Part(root=TextPart(text="Hello "))], + ), + append=False, + ) + second_chunk = TaskArtifactUpdateEvent( + task_id="task-art-dup", + context_id="ctx-art-dup", + artifact=Artifact( + artifact_id="artifact-dup", + parts=[Part(root=TextPart(text="world"))], + ), + append=True, + ) + terminal_task = Task( + id="task-art-dup", + context_id="ctx-art-dup", + status=TaskStatus(state=TaskState.completed, message=None), + artifacts=[ + Artifact( + artifact_id="artifact-dup", + parts=[Part(root=TextPart(text="Hello world"))], + ) + ], + ) + terminal_event = TaskStatusUpdateEvent( + task_id="task-art-dup", + context_id="ctx-art-dup", + status=TaskStatus(state=TaskState.completed, message=None), + final=True, + ) + + mock_a2a_client.responses.extend( + [ + (working_task, first_chunk), + (working_task, second_chunk), + (terminal_task, terminal_event), + ] + ) + + stream = a2a_agent.run("Hello", stream=True) + updates: list[AgentResponseUpdate] = [] + async for update in stream: + updates.append(update) + response = await stream.get_final_response() + + assert [update.text for update in updates] == ["Hello ", "world"] + assert response.text == "Hello world" + assert len(response.messages) == 1 + + +async def test_streaming_terminal_task_artifacts_are_emitted_when_terminal_event_has_no_content( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that terminal task artifacts are still emitted when the final status event has no message.""" + terminal_task = Task( + id="task-art-final", + context_id="ctx-art-final", + status=TaskStatus(state=TaskState.completed, message=None), + artifacts=[ + Artifact( + artifact_id="artifact-final", + parts=[Part(root=TextPart(text="Final artifact"))], + ) + ], + ) + terminal_event = TaskStatusUpdateEvent( + task_id="task-art-final", + context_id="ctx-art-final", + status=TaskStatus(state=TaskState.completed, message=None), + final=True, + ) + mock_a2a_client.responses.append((terminal_task, terminal_event)) + + updates: list[AgentResponseUpdate] = [] + async for update in a2a_agent.run("Hello", stream=True): + updates.append(update) + + assert len(updates) == 1 + assert updates[0].text == "Final artifact" + assert updates[0].message_id == "artifact-final" + + +async def test_streaming_terminal_task_only_emits_unstreamed_artifacts( + a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient +) -> None: + """Test that the terminal task only emits artifacts that were not already streamed incrementally.""" + working_task = Task(id="task-art-mixed", context_id="ctx-art-mixed", status=TaskStatus(state=TaskState.working)) + streamed_chunk = TaskArtifactUpdateEvent( + task_id="task-art-mixed", + context_id="ctx-art-mixed", + artifact=Artifact( + artifact_id="artifact-streamed", + parts=[Part(root=TextPart(text="Hello"))], + ), + append=False, + ) + terminal_task = Task( + id="task-art-mixed", + context_id="ctx-art-mixed", + status=TaskStatus(state=TaskState.completed, message=None), + artifacts=[ + Artifact( + artifact_id="artifact-streamed", + parts=[Part(root=TextPart(text="Hello"))], + ), + Artifact( + artifact_id="artifact-final", + parts=[Part(root=TextPart(text="Goodbye"))], + ), + ], + ) + terminal_event = TaskStatusUpdateEvent( + task_id="task-art-mixed", + context_id="ctx-art-mixed", + status=TaskStatus(state=TaskState.completed, message=None), + final=True, + ) + + mock_a2a_client.responses.extend( + [ + (working_task, streamed_chunk), + (terminal_task, terminal_event), + ] + ) + + stream = a2a_agent.run("Hello", stream=True) + updates: list[AgentResponseUpdate] = [] + async for update in stream: + updates.append(update) + response = await stream.get_final_response() + + assert [update.text for update in updates] == ["Hello", "Goodbye"] + assert [message.text for message in response.messages] == ["Hello", "Goodbye"] + + # endregion