diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 49195415b..553766811 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -4,6 +4,7 @@ import types from collections.abc import Sequence +from enum import Enum from typing import Generic, Literal, TypeVar, Union, get_args, get_origin from pydantic import BaseModel @@ -46,7 +47,7 @@ class AcceptedUrlElicitation(BaseModel): # Primitive types allowed in elicitation schemas -_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) +_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool, Enum) def _validate_elicitation_schema(schema: type[BaseModel]) -> None: @@ -99,6 +100,10 @@ def _is_primitive_field(annotation: type) -> bool: arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES or _is_string_sequence(arg) for arg in args ) + # Handle Enum types + if isinstance(annotation, type) and issubclass(annotation, str) and issubclass(annotation, Enum): # type: ignore[arg-type] + return True + return False diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 597b29178..5d3b1a29a 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -2,6 +2,7 @@ Test the elicitation feature using stdio transport. """ +from enum import Enum from typing import Any import pytest @@ -147,6 +148,39 @@ async def elicitation_callback( assert "Validation failed as expected" in result.content[0].text assert field_name in result.content[0].text + # Test valid Enum types (should not fail validation) + class Status(str, Enum): + ACTIVE = "active" + INACTIVE = "inactive" + + class ValidStrEnumSchema(BaseModel): + status: Status = Field(description="Status using StrEnum") + + def create_valid_validation_tool(name: str, schema_class: type[BaseModel]): + @mcp.tool(name=name, description=f"Tool testing {name}") + async def tool(ctx: Context[ServerSession, None]) -> str: + # This should succeed without validation error + result = await ctx.elicit(message="Test valid schema", schema=schema_class) + return f"Success: {result.action}" + + return tool + + create_valid_validation_tool("valid_strenum", ValidStrEnumSchema) + + async def enum_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + # Return the required status field + return ElicitResult(action="accept", content={"status": "active"}) + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=enum_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool("valid_strenum", {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert "Success: accept" == result.content[0].text + @pytest.mark.anyio async def test_elicitation_with_optional_fields():