diff --git a/src/agentex/lib/sdk/fastacp/base/base_acp_server.py b/src/agentex/lib/sdk/fastacp/base/base_acp_server.py index b625eaa1c..25b2fefb4 100644 --- a/src/agentex/lib/sdk/fastacp/base/base_acp_server.py +++ b/src/agentex/lib/sdk/fastacp/base/base_acp_server.py @@ -81,6 +81,9 @@ def __init__(self): # Agent info to return in healthz self.agent_id: str | None = None + # Optional agent card for registration metadata + self._agent_card: Any | None = None + @classmethod def create(cls): """Create and initialize BaseACPServer instance""" @@ -98,7 +101,7 @@ def get_lifespan_function(self): async def lifespan_context(app: FastAPI): # noqa: ARG001 env_vars = EnvironmentVariables.refresh() if env_vars.AGENTEX_BASE_URL: - await register_agent(env_vars) + await register_agent(env_vars, agent_card=self._agent_card) self.agent_id = env_vars.AGENT_ID else: logger.warning("AGENTEX_BASE_URL not set, skipping agent registration") diff --git a/src/agentex/lib/sdk/fastacp/fastacp.py b/src/agentex/lib/sdk/fastacp/fastacp.py index 0e32c3460..9e3ae78ec 100644 --- a/src/agentex/lib/sdk/fastacp/fastacp.py +++ b/src/agentex/lib/sdk/fastacp/fastacp.py @@ -2,7 +2,7 @@ import os import inspect -from typing import Literal +from typing import Any, Literal from pathlib import Path from typing_extensions import deprecated @@ -88,7 +88,10 @@ def locate_build_info_path() -> None: @staticmethod def create( - acp_type: Literal["sync", "async", "agentic"], config: BaseACPConfig | None = None, **kwargs + acp_type: Literal["sync", "async", "agentic"], + config: BaseACPConfig | None = None, + agent_card: Any | None = None, + **kwargs, ) -> BaseACPServer | SyncACP | AsyncBaseACP | TemporalACP: """Main factory method to create any ACP type @@ -102,10 +105,17 @@ def create( if acp_type == "sync": sync_config = config if isinstance(config, SyncACPConfig) else None - return FastACP.create_sync_acp(sync_config, **kwargs) + instance = FastACP.create_sync_acp(sync_config, **kwargs) elif acp_type == "async" or acp_type == "agentic": if config is None: config = AsyncACPConfig(type="base") if not isinstance(config, AsyncACPConfig): raise ValueError("AsyncACPConfig is required for async/agentic ACP type") - return FastACP.create_async_acp(config, **kwargs) + instance = FastACP.create_async_acp(config, **kwargs) + else: + raise ValueError(f"Unknown acp_type: {acp_type}") + + if agent_card is not None: + instance._agent_card = agent_card # type: ignore[attr-defined] + + return instance diff --git a/src/agentex/lib/sdk/state_machine/__init__.py b/src/agentex/lib/sdk/state_machine/__init__.py index 92dc35fea..6013d28f6 100644 --- a/src/agentex/lib/sdk/state_machine/__init__.py +++ b/src/agentex/lib/sdk/state_machine/__init__.py @@ -1,6 +1,16 @@ +from agentex.lib.types.agent_card import AgentCard, AgentLifecycle, LifecycleState + from .state import State from .noop_workflow import NoOpWorkflow from .state_machine import StateMachine from .state_workflow import StateWorkflow -__all__ = ["StateMachine", "StateWorkflow", "State", "NoOpWorkflow"] +__all__ = [ + "StateMachine", + "StateWorkflow", + "State", + "NoOpWorkflow", + "AgentCard", + "AgentLifecycle", + "LifecycleState", +] diff --git a/src/agentex/lib/sdk/state_machine/state_machine.py b/src/agentex/lib/sdk/state_machine/state_machine.py index 6f2acded7..f1e5c4239 100644 --- a/src/agentex/lib/sdk/state_machine/state_machine.py +++ b/src/agentex/lib/sdk/state_machine/state_machine.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from enum import Enum from typing import Any, Generic, TypeVar from agentex.lib import adk @@ -129,6 +130,28 @@ async def reset_to_initial_state(self): span.output = {"output_state": self._initial_state} # type: ignore[assignment,union-attr] await adk.tracing.end_span(trace_id=self._task_id, span=span) + def get_lifecycle(self) -> dict[str, Any]: + """Export the state machine's lifecycle as a dict suitable for AgentCard.""" + states = [] + for state in self._state_map.values(): + workflow = state.workflow + states.append({ + "name": state.name, + "description": workflow.description, + "waits_for_input": workflow.waits_for_input, + "accepts": list(workflow.accepts), + "transitions": [ + t.value if isinstance(t, Enum) else str(t) + for t in workflow.transitions + ], + }) + initial: str = self._initial_state.value if isinstance(self._initial_state, Enum) else self._initial_state + + return { + "states": states, + "initial_state": initial, + } + def dump(self) -> dict[str, Any]: """ Save the current state of the state machine to a serializable dictionary. diff --git a/src/agentex/lib/sdk/state_machine/state_workflow.py b/src/agentex/lib/sdk/state_machine/state_workflow.py index cca7f46ad..dc5f5ff83 100644 --- a/src/agentex/lib/sdk/state_machine/state_workflow.py +++ b/src/agentex/lib/sdk/state_machine/state_workflow.py @@ -11,6 +11,11 @@ class StateWorkflow(ABC): + description: str = "" + waits_for_input: bool = False + accepts: list[str] = [] + transitions: list[str] = [] + @abstractmethod async def execute( self, state_machine: "StateMachine", state_machine_data: BaseModel | None = None diff --git a/src/agentex/lib/types/agent_card.py b/src/agentex/lib/types/agent_card.py new file mode 100644 index 000000000..def4464c6 --- /dev/null +++ b/src/agentex/lib/types/agent_card.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import types +import typing +from enum import Enum +from typing import TYPE_CHECKING, Any, get_args, get_origin + +from pydantic import BaseModel + +if TYPE_CHECKING: + from agentex.lib.sdk.state_machine.state import State + + +class LifecycleState(BaseModel): + name: str + description: str = "" + waits_for_input: bool = False + accepts: list[str] = [] + transitions: list[str] = [] + + +class AgentLifecycle(BaseModel): + states: list[LifecycleState] + initial_state: str + queries: list[str] = [] + + +class AgentCard(BaseModel): + protocol: str = "acp" + lifecycle: AgentLifecycle | None = None + data_events: list[str] = [] + input_types: list[str] = [] + output_schema: dict | None = None + + @classmethod + def from_states( + cls, + initial_state: str | Enum, + states: list[State], + output_event_model: type[BaseModel] | None = None, + extra_input_types: list[str] | None = None, + queries: list[str] | None = None, + ) -> AgentCard: + """Build an AgentCard directly from a list[State] + initial_state. + + Agents can share their `states` list between the StateMachine and acp.py + without constructing a temporary StateMachine instance. + """ + lifecycle_states = [ + LifecycleState( + name=state.name, + description=state.workflow.description, + waits_for_input=state.workflow.waits_for_input, + accepts=list(state.workflow.accepts), + transitions=[ + t.value if isinstance(t, Enum) else str(t) + for t in state.workflow.transitions + ], + ) + for state in states + ] + + initial = initial_state.value if isinstance(initial_state, Enum) else initial_state + + data_events: list[str] = [] + output_schema: dict | None = None + if output_event_model: + data_events = extract_literal_values(output_event_model, "type") + output_schema = output_event_model.model_json_schema() + + derived_input_types: set[str] = set() + for ls in lifecycle_states: + derived_input_types.update(ls.accepts) + + return cls( + lifecycle=AgentLifecycle( + states=lifecycle_states, + initial_state=initial, + queries=queries or [], + ), + data_events=data_events, + input_types=sorted(derived_input_types | set(extra_input_types or [])), + output_schema=output_schema, + ) + + @classmethod + def from_state_machine( + cls, + state_machine: Any, + output_event_model: type[BaseModel] | None = None, + extra_input_types: list[str] | None = None, + queries: list[str] | None = None, + ) -> AgentCard: + """Build an AgentCard from a StateMachine instance. Delegates to from_states().""" + lifecycle = state_machine.get_lifecycle() + states_data = lifecycle["states"] + initial = lifecycle["initial_state"] + + # Reconstruct lightweight State-like objects from the lifecycle dict + # so we can reuse from_states logic via the dict path + data_events: list[str] = [] + output_schema: dict | None = None + if output_event_model: + data_events = extract_literal_values(output_event_model, "type") + output_schema = output_event_model.model_json_schema() + + derived_input_types: set[str] = set() + lifecycle_states = [] + for s in states_data: + derived_input_types.update(s.get("accepts", [])) + lifecycle_states.append(LifecycleState( + name=s["name"], + description=s.get("description", ""), + waits_for_input=s.get("waits_for_input", False), + accepts=s.get("accepts", []), + transitions=s.get("transitions", []), + )) + + return cls( + lifecycle=AgentLifecycle( + states=lifecycle_states, + initial_state=initial, + queries=queries or [], + ), + data_events=data_events, + input_types=sorted(derived_input_types | set(extra_input_types or [])), + output_schema=output_schema, + ) + + +def extract_literal_values(model: type[BaseModel], field: str) -> list[str]: + """Extract allowed values from a Literal[...] type annotation on a Pydantic model field.""" + field_info = model.model_fields.get(field) + if field_info is None: + return [] + + annotation = field_info.annotation + if annotation is None: + return [] + + # Unwrap Optional (Union[X, None] or PEP 604 X | None) to get the inner type + if get_origin(annotation) is typing.Union or isinstance(annotation, types.UnionType): + args = [a for a in get_args(annotation) if a is not type(None)] + annotation = args[0] if len(args) == 1 else annotation + + if get_origin(annotation) is typing.Literal: + return list(get_args(annotation)) + + return [] diff --git a/src/agentex/lib/utils/registration.py b/src/agentex/lib/utils/registration.py index 30dd836f7..c5823f90a 100644 --- a/src/agentex/lib/utils/registration.py +++ b/src/agentex/lib/utils/registration.py @@ -31,7 +31,7 @@ def get_build_info(): except Exception: return None -async def register_agent(env_vars: EnvironmentVariables): +async def register_agent(env_vars: EnvironmentVariables, agent_card=None): """Register this agent with the Agentex server""" if not env_vars.AGENTEX_BASE_URL: logger.warning("AGENTEX_BASE_URL is not set, skipping registration") @@ -45,13 +45,20 @@ async def register_agent(env_vars: EnvironmentVariables): ) # Prepare registration data + registration_metadata = get_build_info() + if agent_card is not None: + card_data = agent_card.model_dump() if hasattr(agent_card, "model_dump") else agent_card + if registration_metadata is None: + registration_metadata = {} + registration_metadata["agent_card"] = card_data + registration_data = { "name": env_vars.AGENT_NAME, "description": description, "acp_url": full_acp_url, "acp_type": env_vars.ACP_TYPE, "principal_context": get_auth_principal(env_vars), - "registration_metadata": get_build_info() + "registration_metadata": registration_metadata, } if env_vars.AGENT_ID: diff --git a/tests/lib/test_agent_card.py b/tests/lib/test_agent_card.py new file mode 100644 index 000000000..746b6edf0 --- /dev/null +++ b/tests/lib/test_agent_card.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +from enum import Enum +from typing import Literal, override +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import BaseModel + +from agentex.lib.types.agent_card import AgentCard, extract_literal_values +from agentex.lib.sdk.state_machine import State, StateMachine, StateWorkflow +from agentex.lib.utils.model_utils import BaseModel as AgentexBaseModel + +# --- Fixtures & helpers --- + +class SampleState(str, Enum): + WAITING = "waiting" + PROCESSING = "processing" + DONE = "done" + + +class WaitingWorkflow(StateWorkflow): + description = "Waiting for input" + waits_for_input = True + accepts = ["text", "doc_upload"] + transitions = [SampleState.PROCESSING] + + @override + async def execute(self, state_machine, state_machine_data=None): + return SampleState.PROCESSING + + +class ProcessingWorkflow(StateWorkflow): + description = "Processing data" + accepts = ["text"] + transitions = [SampleState.DONE, SampleState.WAITING] + + @override + async def execute(self, state_machine, state_machine_data=None): + return SampleState.DONE + + +class DoneWorkflow(StateWorkflow): + description = "Terminal state" + transitions = [] + + @override + async def execute(self, state_machine, state_machine_data=None): + return SampleState.DONE + + +class SampleData(AgentexBaseModel): + pass + + +class SampleStateMachine(StateMachine[SampleData]): + @override + async def terminal_condition(self): + return self.get_current_state() == SampleState.DONE + + +class SampleOutputEvent(BaseModel): + type: Literal["plan_update", "status_change", "report_done"] + data: dict = {} + + +@pytest.fixture +def sample_states(): + return [ + State(name=SampleState.WAITING, workflow=WaitingWorkflow()), + State(name=SampleState.PROCESSING, workflow=ProcessingWorkflow()), + State(name=SampleState.DONE, workflow=DoneWorkflow()), + ] + + +@pytest.fixture +def sample_sm(sample_states): + return SampleStateMachine(initial_state=SampleState.WAITING, states=sample_states) + + +# --- extract_literal_values --- + +class TestExtractLiteralValues: + def test_literal_field(self): + class M(BaseModel): + type: Literal["a", "b", "c"] + + assert extract_literal_values(M, "type") == ["a", "b", "c"] + + def test_optional_literal_field(self): + """typing.Optional[Literal[...]] should unwrap correctly.""" + class M(BaseModel): + type: Literal["x", "y"] | None = None + + result = extract_literal_values(M, "type") + assert result == ["x", "y"] + + def test_non_literal_field(self): + class M(BaseModel): + name: str + + assert extract_literal_values(M, "name") == [] + + def test_missing_field(self): + class M(BaseModel): + name: str + + assert extract_literal_values(M, "nonexistent") == [] + + def test_int_literal(self): + class M(BaseModel): + code: Literal[1, 2, 3] + + assert extract_literal_values(M, "code") == [1, 2, 3] + + +# --- StateWorkflow defaults --- + +class TestStateWorkflowDefaults: + def test_default_attrs(self): + assert StateWorkflow.description == "" + assert StateWorkflow.waits_for_input is False + assert StateWorkflow.accepts == [] + assert StateWorkflow.transitions == [] + + def test_subclass_overrides(self): + assert WaitingWorkflow.description == "Waiting for input" + assert WaitingWorkflow.waits_for_input is True + assert WaitingWorkflow.accepts == ["text", "doc_upload"] + assert WaitingWorkflow.transitions == [SampleState.PROCESSING] + + def test_subclass_defaults_not_shared(self): + """Each subclass's list attrs are independent objects.""" + assert WaitingWorkflow.accepts is not ProcessingWorkflow.accepts + assert WaitingWorkflow.transitions is not ProcessingWorkflow.transitions + + +# --- StateMachine.get_lifecycle --- + +class TestGetLifecycle: + def test_structure(self, sample_sm): + lifecycle = sample_sm.get_lifecycle() + + assert "states" in lifecycle + assert "initial_state" in lifecycle + assert lifecycle["initial_state"] == "waiting" + assert len(lifecycle["states"]) == 3 + + def test_state_fields(self, sample_sm): + lifecycle = sample_sm.get_lifecycle() + states_by_name = {s["name"]: s for s in lifecycle["states"]} + + waiting = states_by_name["waiting"] + assert waiting["description"] == "Waiting for input" + assert waiting["waits_for_input"] is True + assert waiting["accepts"] == ["text", "doc_upload"] + assert waiting["transitions"] == ["processing"] + + processing = states_by_name["processing"] + assert processing["description"] == "Processing data" + assert processing["waits_for_input"] is False + assert processing["accepts"] == ["text"] + assert set(processing["transitions"]) == {"done", "waiting"} + + def test_enum_values_resolved(self, sample_sm): + """Enum state names and transitions should be resolved to .value strings.""" + lifecycle = sample_sm.get_lifecycle() + for state in lifecycle["states"]: + assert isinstance(state["name"], str) + for t in state["transitions"]: + assert isinstance(t, str) + + +# --- AgentCard direct construction --- + +class TestAgentCardDirect: + def test_simple_agent(self): + card = AgentCard(input_types=["text"], data_events=["result"]) + assert card.protocol == "acp" + assert card.lifecycle is None + assert card.input_types == ["text"] + assert card.data_events == ["result"] + assert card.output_schema is None + + def test_defaults(self): + card = AgentCard() + assert card.protocol == "acp" + assert card.lifecycle is None + assert card.data_events == [] + assert card.input_types == [] + assert card.output_schema is None + + def test_serialization_roundtrip(self): + card = AgentCard(input_types=["text"], data_events=["result"]) + dumped = card.model_dump() + restored = AgentCard.model_validate(dumped) + assert restored == card + + +# --- AgentCard.from_states --- + +class TestAgentCardFromStates: + def test_lifecycle_derivation(self, sample_states): + card = AgentCard.from_states(initial_state=SampleState.WAITING, states=sample_states) + + assert card.lifecycle is not None + assert card.lifecycle.initial_state == "waiting" + assert len(card.lifecycle.states) == 3 + + def test_initial_state_string(self, sample_states): + card = AgentCard.from_states(initial_state="waiting", states=sample_states) + assert card.lifecycle is not None + assert card.lifecycle.initial_state == "waiting" + + def test_input_types_union(self, sample_states): + card = AgentCard.from_states(initial_state=SampleState.WAITING, states=sample_states) + assert card.input_types == ["doc_upload", "text"] + + def test_extra_input_types(self, sample_states): + card = AgentCard.from_states( + initial_state=SampleState.WAITING, + states=sample_states, + extra_input_types=["admin_command"], + ) + assert card.input_types == ["admin_command", "doc_upload", "text"] + + def test_data_events_and_schema(self, sample_states): + card = AgentCard.from_states( + initial_state=SampleState.WAITING, + states=sample_states, + output_event_model=SampleOutputEvent, + queries=["get_current_state"], + ) + assert card.data_events == ["plan_update", "status_change", "report_done"] + assert card.output_schema is not None + assert card.lifecycle is not None + assert card.lifecycle.queries == ["get_current_state"] + + def test_state_fields(self, sample_states): + card = AgentCard.from_states(initial_state=SampleState.WAITING, states=sample_states) + assert card.lifecycle is not None + states_by_name = {s.name: s for s in card.lifecycle.states} + + waiting = states_by_name["waiting"] + assert waiting.description == "Waiting for input" + assert waiting.waits_for_input is True + assert waiting.accepts == ["text", "doc_upload"] + assert waiting.transitions == ["processing"] + + def test_matches_from_state_machine(self, sample_states, sample_sm): + """from_states and from_state_machine should produce identical cards.""" + card_states = AgentCard.from_states( + initial_state=SampleState.WAITING, + states=sample_states, + output_event_model=SampleOutputEvent, + queries=["get_current_state"], + ) + card_sm = AgentCard.from_state_machine( + state_machine=sample_sm, + output_event_model=SampleOutputEvent, + queries=["get_current_state"], + ) + assert card_states == card_sm + + +# --- AgentCard.from_state_machine --- + +class TestAgentCardFromStateMachine: + def test_lifecycle_derivation(self, sample_sm): + card = AgentCard.from_state_machine(state_machine=sample_sm) + + assert card.lifecycle is not None + assert card.lifecycle.initial_state == "waiting" + assert len(card.lifecycle.states) == 3 + + def test_input_types_union(self, sample_sm): + """input_types should be the sorted union of all per-state accepts.""" + card = AgentCard.from_state_machine(state_machine=sample_sm) + assert card.input_types == ["doc_upload", "text"] + + def test_extra_input_types(self, sample_sm): + card = AgentCard.from_state_machine( + state_machine=sample_sm, + extra_input_types=["admin_command"], + ) + assert "admin_command" in card.input_types + assert card.input_types == ["admin_command", "doc_upload", "text"] + + def test_data_events_extraction(self, sample_sm): + card = AgentCard.from_state_machine( + state_machine=sample_sm, + output_event_model=SampleOutputEvent, + ) + assert card.data_events == ["plan_update", "status_change", "report_done"] + + def test_output_schema_generation(self, sample_sm): + card = AgentCard.from_state_machine( + state_machine=sample_sm, + output_event_model=SampleOutputEvent, + ) + assert card.output_schema is not None + assert "properties" in card.output_schema + assert "type" in card.output_schema["properties"] + + def test_queries(self, sample_sm): + card = AgentCard.from_state_machine( + state_machine=sample_sm, + queries=["get_current_state", "get_progress"], + ) + assert card.lifecycle is not None + assert card.lifecycle.queries == ["get_current_state", "get_progress"] + + def test_no_output_model(self, sample_sm): + card = AgentCard.from_state_machine(state_machine=sample_sm) + assert card.data_events == [] + assert card.output_schema is None + + +# --- register_agent agent_card merging --- + +class TestRegisterAgentCardMerge: + @pytest.fixture + def mock_env_vars(self): + """Minimal EnvironmentVariables mock for register_agent.""" + mock = type("EnvVars", (), { + "AGENTEX_BASE_URL": "http://localhost:5003", + "ACP_URL": "http://localhost", + "ACP_PORT": "8000", + "AGENT_NAME": "test-agent", + "AGENT_DESCRIPTION": "Test agent", + "ACP_TYPE": "sync", + "AUTH_PRINCIPAL_B64": None, + "AGENT_ID": None, + "AGENT_INPUT_TYPE": None, + "AGENT_API_KEY": None, + })() + return mock + + def _make_mock_client(self): + """Create a mock httpx.AsyncClient that returns a successful registration response.""" + mock_response = MagicMock() + mock_response.status_code = 200 + # httpx Response.json() is sync, not async + mock_response.json.return_value = { + "id": "agent-123", + "name": "test-agent", + "agent_api_key": "key-123", + } + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + return mock_client + + async def test_agent_card_merged_into_metadata(self, mock_env_vars): + card = AgentCard(input_types=["text"], data_events=["result"]) + mock_client = self._make_mock_client() + + with patch("agentex.lib.utils.registration.get_build_info", return_value={"version": "1.0"}): + with patch("agentex.lib.utils.registration.httpx.AsyncClient", return_value=mock_client): + from agentex.lib.utils.registration import register_agent + await register_agent(mock_env_vars, agent_card=card) + + sent_data = mock_client.post.call_args.kwargs["json"] + metadata = sent_data["registration_metadata"] + + assert "agent_card" in metadata + assert metadata["agent_card"]["input_types"] == ["text"] + assert metadata["agent_card"]["data_events"] == ["result"] + assert metadata["version"] == "1.0" + + async def test_none_preserved_when_no_card_no_build_info(self, mock_env_vars): + mock_client = self._make_mock_client() + + with patch("agentex.lib.utils.registration.get_build_info", return_value=None): + with patch("agentex.lib.utils.registration.httpx.AsyncClient", return_value=mock_client): + from agentex.lib.utils.registration import register_agent + await register_agent(mock_env_vars, agent_card=None) + + sent_data = mock_client.post.call_args.kwargs["json"] + assert sent_data["registration_metadata"] is None + + async def test_card_creates_metadata_when_build_info_none(self, mock_env_vars): + card = AgentCard(input_types=["text"]) + mock_client = self._make_mock_client() + + with patch("agentex.lib.utils.registration.get_build_info", return_value=None): + with patch("agentex.lib.utils.registration.httpx.AsyncClient", return_value=mock_client): + from agentex.lib.utils.registration import register_agent + await register_agent(mock_env_vars, agent_card=card) + + sent_data = mock_client.post.call_args.kwargs["json"] + metadata = sent_data["registration_metadata"] + assert metadata is not None + assert "agent_card" in metadata