Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ async def _map_a2a_stream(
)

all_updates: list[AgentResponseUpdate] = []
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
async for item in a2a_stream:
if isinstance(item, A2AMessage):
# Process A2A Message
Expand All @@ -378,12 +379,21 @@ async def _map_a2a_stream(
all_updates.append(update)
yield update
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task):
task, _update_event = item
for update in self._updates_from_task(
task, update_event = item
updates = self._updates_from_task(
task,
update_event=update_event,
background=background,
emit_intermediate=emit_intermediate,
streamed_artifact_ids=streamed_artifact_ids_by_task.get(task.id),
)
if isinstance(update_event, TaskArtifactUpdateEvent) and any(
update.raw_representation is update_event for update in updates
):
streamed_artifact_ids_by_task.setdefault(task.id, set()).add(update_event.artifact.artifact_id)
if task.status.state in TERMINAL_TASK_STATES:
streamed_artifact_ids_by_task.pop(task.id, None)
for update in updates:
all_updates.append(update)
yield update
else:
Expand All @@ -403,8 +413,10 @@ def _updates_from_task(
self,
task: Task,
*,
update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None,
background: bool = False,
emit_intermediate: bool = False,
streamed_artifact_ids: set[str] | None = None,
) -> list[AgentResponseUpdate]:
"""Convert an A2A Task into AgentResponseUpdate(s).

Expand All @@ -418,8 +430,18 @@ def _updates_from_task(
"""
status = task.status

if emit_intermediate and update_event is not None:
if event_updates := self._updates_from_task_update_event(update_event):
return event_updates

if status.state in TERMINAL_TASK_STATES:
task_messages = self._parse_messages_from_task(task)
if task.artifacts is not None and streamed_artifact_ids:
task_messages = [
message
for message in task_messages
if getattr(message.raw_representation, "artifact_id", None) not in streamed_artifact_ids
]
if task_messages:
return [
AgentResponseUpdate(
Expand All @@ -431,6 +453,8 @@ def _updates_from_task(
)
for message in task_messages
]
if task.artifacts is not None:
return []
return [AgentResponseUpdate(contents=[], role="assistant", response_id=task.id, raw_representation=task)]

if background and status.state in IN_PROGRESS_TASK_STATES:
Expand Down Expand Up @@ -467,6 +491,44 @@ def _updates_from_task(

return []

def _updates_from_task_update_event(
self, update_event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent
) -> list[AgentResponseUpdate]:
"""Convert A2A task update events into streaming AgentResponseUpdates."""
if isinstance(update_event, TaskArtifactUpdateEvent):
contents = self._parse_contents_from_a2a(update_event.artifact.parts)
if not contents:
return []
return [
AgentResponseUpdate(
contents=contents,
role="assistant",
response_id=update_event.task_id,
message_id=update_event.artifact.artifact_id,
raw_representation=update_event,
)
]

if not isinstance(update_event, TaskStatusUpdateEvent):
return []

message = update_event.status.message
if message is None or not message.parts:
return []

contents = self._parse_contents_from_a2a(message.parts)
if not contents:
return []

return [
AgentResponseUpdate(
contents=contents,
role="assistant" if message.role == A2ARole.agent else "user",
response_id=update_event.task_id,
raw_representation=update_event,
)
]

@staticmethod
def _build_continuation_token(task: Task) -> A2AContinuationToken | None:
"""Build an A2AContinuationToken from an A2A Task if it is still in progress."""
Expand Down
199 changes: 199 additions & 0 deletions python/packages/a2a/tests/test_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
FileWithUri,
Part,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
)
from a2a.types import Message as A2AMessage
Expand Down Expand Up @@ -1189,4 +1191,201 @@ async def test_streaming_working_update_with_empty_parts_is_skipped(
assert updates[0].contents[0].text == "Result"


async def test_streaming_artifact_update_event_yields_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that streaming artifact update events yield incremental content."""
task = Task(id="task-art", context_id="ctx-art", status=TaskStatus(state=TaskState.working, message=None))
artifact = Artifact(
artifact_id="artifact-1",
parts=[Part(root=TextPart(text="Hello"))],
)
update_event = TaskArtifactUpdateEvent(task_id="task-art", context_id="ctx-art", artifact=artifact, append=False)
mock_a2a_client.responses.append((task, update_event))

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 1
assert updates[0].text == "Hello"
assert updates[0].message_id == "artifact-1"
assert updates[0].raw_representation == update_event


async def test_streaming_status_update_event_yields_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that streaming status update events surface message content directly from the update event."""
update_event = TaskStatusUpdateEvent(
task_id="task-status",
context_id="ctx-status",
status=TaskStatus(
state=TaskState.working,
message=A2AMessage(
message_id=str(uuid4()),
role=A2ARole.agent,
parts=[Part(root=TextPart(text="Still working"))],
),
),
final=False,
)
task = Task(id="task-status", context_id="ctx-status", status=TaskStatus(state=TaskState.working, message=None))
mock_a2a_client.responses.append((task, update_event))

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 1
assert updates[0].text == "Still working"
assert updates[0].role == "assistant"
assert updates[0].raw_representation == update_event


async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_artifacts(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that streamed artifact chunks are not re-emitted from the final terminal task."""
working_task = Task(id="task-art-dup", context_id="ctx-art-dup", status=TaskStatus(state=TaskState.working))
first_chunk = TaskArtifactUpdateEvent(
task_id="task-art-dup",
context_id="ctx-art-dup",
artifact=Artifact(
artifact_id="artifact-dup",
parts=[Part(root=TextPart(text="Hello "))],
),
append=False,
)
second_chunk = TaskArtifactUpdateEvent(
task_id="task-art-dup",
context_id="ctx-art-dup",
artifact=Artifact(
artifact_id="artifact-dup",
parts=[Part(root=TextPart(text="world"))],
),
append=True,
)
terminal_task = Task(
id="task-art-dup",
context_id="ctx-art-dup",
status=TaskStatus(state=TaskState.completed, message=None),
artifacts=[
Artifact(
artifact_id="artifact-dup",
parts=[Part(root=TextPart(text="Hello world"))],
)
],
)
terminal_event = TaskStatusUpdateEvent(
task_id="task-art-dup",
context_id="ctx-art-dup",
status=TaskStatus(state=TaskState.completed, message=None),
final=True,
)

mock_a2a_client.responses.extend(
[
(working_task, first_chunk),
(working_task, second_chunk),
(terminal_task, terminal_event),
]
)

stream = a2a_agent.run("Hello", stream=True)
updates: list[AgentResponseUpdate] = []
async for update in stream:
updates.append(update)
response = await stream.get_final_response()

assert [update.text for update in updates] == ["Hello ", "world"]
assert response.text == "Hello world"
assert len(response.messages) == 1


async def test_streaming_terminal_task_artifacts_are_emitted_when_terminal_event_has_no_content(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that terminal task artifacts are still emitted when the final status event has no message."""
terminal_task = Task(
id="task-art-final",
context_id="ctx-art-final",
status=TaskStatus(state=TaskState.completed, message=None),
artifacts=[
Artifact(
artifact_id="artifact-final",
parts=[Part(root=TextPart(text="Final artifact"))],
)
],
)
terminal_event = TaskStatusUpdateEvent(
task_id="task-art-final",
context_id="ctx-art-final",
status=TaskStatus(state=TaskState.completed, message=None),
final=True,
)
mock_a2a_client.responses.append((terminal_task, terminal_event))

updates: list[AgentResponseUpdate] = []
async for update in a2a_agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 1
assert updates[0].text == "Final artifact"
assert updates[0].message_id == "artifact-final"


async def test_streaming_terminal_task_only_emits_unstreamed_artifacts(
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
) -> None:
"""Test that the terminal task only emits artifacts that were not already streamed incrementally."""
working_task = Task(id="task-art-mixed", context_id="ctx-art-mixed", status=TaskStatus(state=TaskState.working))
streamed_chunk = TaskArtifactUpdateEvent(
task_id="task-art-mixed",
context_id="ctx-art-mixed",
artifact=Artifact(
artifact_id="artifact-streamed",
parts=[Part(root=TextPart(text="Hello"))],
),
append=False,
)
terminal_task = Task(
id="task-art-mixed",
context_id="ctx-art-mixed",
status=TaskStatus(state=TaskState.completed, message=None),
artifacts=[
Artifact(
artifact_id="artifact-streamed",
parts=[Part(root=TextPart(text="Hello"))],
),
Artifact(
artifact_id="artifact-final",
parts=[Part(root=TextPart(text="Goodbye"))],
),
],
)
terminal_event = TaskStatusUpdateEvent(
task_id="task-art-mixed",
context_id="ctx-art-mixed",
status=TaskStatus(state=TaskState.completed, message=None),
final=True,
)

mock_a2a_client.responses.extend(
[
(working_task, streamed_chunk),
(terminal_task, terminal_event),
]
)

stream = a2a_agent.run("Hello", stream=True)
updates: list[AgentResponseUpdate] = []
async for update in stream:
updates.append(update)
response = await stream.get_final_response()

assert [update.text for update in updates] == ["Hello", "Goodbye"]
assert [message.text for message in response.messages] == ["Hello", "Goodbye"]


# endregion
Loading