Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
a227f4f
test: improve coverage
ognis1205 Jul 15, 2025
de01d0f
chore: remove print
ognis1205 Jul 15, 2025
6541934
fix: remove await
ognis1205 Jul 18, 2025
77f3574
Merge branch 'main' into chore/improve-coverage-grpc-client
holtskinner Jul 21, 2025
b6523a1
Merge branch 'main' into chore/improve-coverage-grpc-client
holtskinner Jul 24, 2025
c25ccdb
Merge remote-tracking branch 'upstream/main' into chore/improve-cover…
ognis1205 Jul 31, 2025
c340c78
fix: update unit tests regarding refactoring
ognis1205 Aug 4, 2025
e984261
Merge remote-tracking branch 'upstream/main' into chore/improve-cover…
ognis1205 Aug 4, 2025
00be6cb
Merge remote-tracking branch 'upstream/main' into chore/improve-cover…
ognis1205 Aug 4, 2025
0a0417f
Merge remote-tracking branch 'upstream/main' into chore/improve-cover…
ognis1205 Aug 8, 2025
35378ab
Merge remote-tracking branch 'upstream/main' into chore/improve-cover…
ognis1205 Aug 13, 2025
f5b2340
fix: allow any characters in task name match pattern
ognis1205 Aug 13, 2025
ade67af
Merge remote-tracking branch 'upstream/main' into chore/improve-cover…
ognis1205 Aug 20, 2025
cfbd278
Merge remote-tracking branch 'upstream/main' into chore/improve-cover…
ognis1205 Aug 20, 2025
eb9d426
Merge remote-tracking branch 'upstream/main' into chore/improve-cover…
ognis1205 Sep 3, 2025
9c0a5be
Merge remote-tracking branch 'upstream/main' into chore/improve-cover…
ognis1205 Sep 5, 2025
fc69e82
Merge branch 'main' into chore/improve-coverage-grpc-client
holtskinner Sep 9, 2025
31daf62
Apply suggestions from code review
holtskinner Sep 9, 2025
dcc1e9f
Address Gemini code assist comments
holtskinner Sep 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@


# Regexp patterns for matching
_TASK_NAME_MATCH = re.compile(r'tasks/([\w-]+)')
_TASK_NAME_MATCH = re.compile(r'tasks/([^/]+)')
_TASK_PUSH_CONFIG_NAME_MATCH = re.compile(
r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)'
r'tasks/([^/]+)/pushNotificationConfigs/([^/]+)'
)


Expand Down
256 changes: 251 additions & 5 deletions tests/client/test_grpc_client.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,45 @@
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock

import grpc
import pytest

from a2a.client.transports.grpc import GrpcTransport
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
from a2a.types import (
AgentCapabilities,
AgentCard,
Artifact,
GetTaskPushNotificationConfigParams,
Message,
MessageSendParams,
Part,
PushNotificationAuthenticationInfo,
PushNotificationConfig,
Role,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskPushNotificationConfig,
TaskQueryParams,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
)
from a2a.utils import get_text_parts, proto_utils
from a2a.utils.errors import ServerError


# Fixtures
@pytest.fixture
def mock_grpc_stub() -> AsyncMock:
"""Provides a mock gRPC stub with methods mocked."""
stub = AsyncMock(spec=a2a_pb2_grpc.A2AServiceStub)
stub.SendMessage = AsyncMock()
stub.SendStreamingMessage = AsyncMock()
stub.SendStreamingMessage = MagicMock()
stub.GetTask = AsyncMock()
stub.CancelTask = AsyncMock()
stub.CreateTaskPushNotification = AsyncMock()
stub.GetTaskPushNotification = AsyncMock()
stub.CreateTaskPushNotificationConfig = AsyncMock()
stub.GetTaskPushNotificationConfig = AsyncMock()
return stub


Expand Down Expand Up @@ -93,6 +101,78 @@ def sample_message() -> Message:
)


@pytest.fixture
def sample_artifact() -> Artifact:
"""Provides a sample Artifact object."""
return Artifact(
artifact_id='artifact-1',
name='example.txt',
description='An example artifact',
parts=[Part(root=TextPart(text='Hi there'))],
metadata={},
extensions=[],
)


@pytest.fixture
def sample_task_status_update_event() -> TaskStatusUpdateEvent:
"""Provides a sample TaskStatusUpdateEvent."""
return TaskStatusUpdateEvent(
task_id='task-1',
context_id='ctx-1',
status=TaskStatus(state=TaskState.working),
final=False,
metadata={},
)


@pytest.fixture
def sample_task_artifact_update_event(
sample_artifact,
) -> TaskArtifactUpdateEvent:
"""Provides a sample TaskArtifactUpdateEvent."""
return TaskArtifactUpdateEvent(
task_id='task-1',
context_id='ctx-1',
artifact=sample_artifact,
append=True,
last_chunk=True,
metadata={},
)


@pytest.fixture
def sample_authentication_info() -> PushNotificationAuthenticationInfo:
"""Provides a sample AuthenticationInfo object."""
return PushNotificationAuthenticationInfo(
schemes=['apikey', 'oauth2'], credentials='secret-token'
)


@pytest.fixture
def sample_push_notification_config(
sample_authentication_info: PushNotificationAuthenticationInfo,
) -> PushNotificationConfig:
"""Provides a sample PushNotificationConfig object."""
return PushNotificationConfig(
id='config-1',
url='https://example.com/notify',
token='example-token',
authentication=sample_authentication_info,
)


@pytest.fixture
def sample_task_push_notification_config(
sample_push_notification_config: PushNotificationConfig,
) -> TaskPushNotificationConfig:
"""Provides a sample TaskPushNotificationConfig object."""
return TaskPushNotificationConfig(
task_id='task-1',
push_notification_config=sample_push_notification_config,
)


@pytest.mark.asyncio
async def test_send_message_task_response(
grpc_transport: GrpcTransport,
Expand Down Expand Up @@ -134,6 +214,57 @@ async def test_send_message_message_response(
)


@pytest.mark.asyncio
async def test_send_message_streaming( # noqa: PLR0913
grpc_transport: GrpcTransport,
mock_grpc_stub: AsyncMock,
sample_message_send_params: MessageSendParams,
sample_message: Message,
sample_task: Task,
sample_task_status_update_event: TaskStatusUpdateEvent,
sample_task_artifact_update_event: TaskArtifactUpdateEvent,
):
"""Test send_message_streaming that yields responses."""
stream = MagicMock()
stream.read = AsyncMock(
side_effect=[
a2a_pb2.StreamResponse(
msg=proto_utils.ToProto.message(sample_message)
),
a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(sample_task)),
a2a_pb2.StreamResponse(
status_update=proto_utils.ToProto.task_status_update_event(
sample_task_status_update_event
)
),
a2a_pb2.StreamResponse(
artifact_update=proto_utils.ToProto.task_artifact_update_event(
sample_task_artifact_update_event
)
),
grpc.aio.EOF,
]
)
mock_grpc_stub.SendStreamingMessage.return_value = stream

responses = [
response
async for response in grpc_transport.send_message_streaming(
sample_message_send_params
)
]

mock_grpc_stub.SendStreamingMessage.assert_called_once()
assert isinstance(responses[0], Message)
assert responses[0].message_id == sample_message.message_id
assert isinstance(responses[1], Task)
assert responses[1].id == sample_task.id
assert isinstance(responses[2], TaskStatusUpdateEvent)
assert responses[2].task_id == sample_task_status_update_event.task_id
assert isinstance(responses[3], TaskArtifactUpdateEvent)
assert responses[3].task_id == sample_task_artifact_update_event.task_id


@pytest.mark.asyncio
async def test_get_task(
grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task
Expand Down Expand Up @@ -188,3 +319,118 @@ async def test_cancel_task(
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}')
)
assert response.status.state == TaskState.canceled


@pytest.mark.asyncio
async def test_set_task_callback_with_valid_task(
grpc_transport: GrpcTransport,
mock_grpc_stub: AsyncMock,
sample_task_push_notification_config: TaskPushNotificationConfig,
):
"""Test setting a task push notification config with a valid task id."""
mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = (
proto_utils.ToProto.task_push_notification_config(
sample_task_push_notification_config
)
)

response = await grpc_transport.set_task_callback(
sample_task_push_notification_config
)

mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once_with(
a2a_pb2.CreateTaskPushNotificationConfigRequest(
parent=f'tasks/{sample_task_push_notification_config.task_id}',
config_id=sample_task_push_notification_config.push_notification_config.id,
config=proto_utils.ToProto.task_push_notification_config(
sample_task_push_notification_config
),
)
)
assert response.task_id == sample_task_push_notification_config.task_id


@pytest.mark.asyncio
async def test_set_task_callback_with_invalid_task(
grpc_transport: GrpcTransport,
mock_grpc_stub: AsyncMock,
sample_task_push_notification_config: TaskPushNotificationConfig,
):
"""Test setting a task push notification config with an invalid task id."""
mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig(
name=(
f'invalid-path-to-tasks/{sample_task_push_notification_config.task_id}/'
f'pushNotificationConfigs/{sample_task_push_notification_config.push_notification_config.id}'
),
push_notification_config=proto_utils.ToProto.push_notification_config(
sample_task_push_notification_config.push_notification_config
),
)

with pytest.raises(ServerError) as exc_info:
await grpc_transport.set_task_callback(
sample_task_push_notification_config
)
assert (
'Bad TaskPushNotificationConfig resource name'
in exc_info.value.error.message
)


@pytest.mark.asyncio
async def test_get_task_callback_with_valid_task(
grpc_transport: GrpcTransport,
mock_grpc_stub: AsyncMock,
sample_task_push_notification_config: TaskPushNotificationConfig,
):
"""Test retrieving a task push notification config with a valid task id."""
mock_grpc_stub.GetTaskPushNotificationConfig.return_value = (
proto_utils.ToProto.task_push_notification_config(
sample_task_push_notification_config
)
)
params = GetTaskPushNotificationConfigParams(
id=sample_task_push_notification_config.task_id,
push_notification_config_id=sample_task_push_notification_config.push_notification_config.id,
)

response = await grpc_transport.get_task_callback(params)

mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once_with(
a2a_pb2.GetTaskPushNotificationConfigRequest(
name=(
f'tasks/{params.id}/'
f'pushNotificationConfigs/{params.push_notification_config_id}'
),
)
)
assert response.task_id == sample_task_push_notification_config.task_id


@pytest.mark.asyncio
async def test_get_task_callback_with_invalid_task(
grpc_transport: GrpcTransport,
mock_grpc_stub: AsyncMock,
sample_task_push_notification_config: TaskPushNotificationConfig,
):
"""Test retrieving a task push notification config with an invalid task id."""
mock_grpc_stub.GetTaskPushNotificationConfig.return_value = a2a_pb2.TaskPushNotificationConfig(
name=(
f'invalid-path-to-tasks/{sample_task_push_notification_config.task_id}/'
f'pushNotificationConfigs/{sample_task_push_notification_config.push_notification_config.id}'
),
push_notification_config=proto_utils.ToProto.push_notification_config(
sample_task_push_notification_config.push_notification_config
),
)
params = GetTaskPushNotificationConfigParams(
id=sample_task_push_notification_config.task_id,
push_notification_config_id=sample_task_push_notification_config.push_notification_config.id,
)

with pytest.raises(ServerError) as exc_info:
await grpc_transport.get_task_callback(params)
assert (
'Bad TaskPushNotificationConfig resource name'
in exc_info.value.error.message
)
Loading