Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion api/controllers/console/workspace/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
setup_required,
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
Expand Down Expand Up @@ -210,6 +210,30 @@ class MCPProviderBasePayload(BaseModel):
configuration: dict[str, Any] | None = Field(default_factory=dict)
headers: dict[str, Any] | None = Field(default_factory=dict)
authentication: dict[str, Any] | None = Field(default_factory=dict)
# None means "leave unchanged" on update; the controller resolves it to a
# concrete IdentityMode before calling the service (see _resolve_identity_mode).
identity_mode: IdentityMode | None = None


def _resolve_identity_mode(requested: IdentityMode | None, *, current: IdentityMode) -> IdentityMode:
"""Resolve the effective MCP identity_mode for a create/update request.

Keeps two API-layer concerns out of the service so the service always
receives a concrete value:

* ``None`` means "leave unchanged" (update semantics) — fall back to
``current`` (``IdentityMode.OFF`` for a brand-new provider).
* Identity forwarding is an enterprise-only capability. On non-enterprise
deployments any non-OFF value is coerced back to OFF so a persisted row
can never imply forwarding that the runtime won't perform. This gates the
API surface to match the backend gate in
``MCPTool._forwarding_requested`` — both the API and the backend
invocation must be gated on ``dify_config.ENTERPRISE_ENABLED``.
"""
mode = current if requested is None else requested
if mode != IdentityMode.OFF and not dify_config.ENTERPRISE_ENABLED:
return IdentityMode.OFF
return mode


class MCPProviderCreatePayload(MCPProviderBasePayload):
Expand Down Expand Up @@ -1000,6 +1024,7 @@ def post(self):
headers=payload.headers or {},
configuration=configuration,
authentication=authentication,
identity_mode=_resolve_identity_mode(payload.identity_mode, current=IdentityMode.OFF),
)

# 2) Try to fetch tools immediately after creation so they appear without a second save.
Expand Down Expand Up @@ -1054,6 +1079,11 @@ def put(self):
# Step 3: Perform database update in a transaction
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
# Resolve "leave unchanged" (None) against the stored value, and gate
# the result on ENTERPRISE_ENABLED — both are API-layer concerns, so
# the service receives a concrete IdentityMode.
existing = service.get_provider(provider_id=payload.provider_id, tenant_id=current_tenant_id)
identity_mode = _resolve_identity_mode(payload.identity_mode, current=IdentityMode(existing.identity_mode))
service.update_provider(
tenant_id=current_tenant_id,
provider_id=payload.provider_id,
Expand All @@ -1067,6 +1097,7 @@ def put(self):
configuration=configuration,
authentication=authentication,
validation_result=validation_result,
identity_mode=identity_mode,
)

return {"result": "success"}
Expand Down
11 changes: 11 additions & 0 deletions api/core/entities/mcp_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ class MCPSupportGrantType(StrEnum):
REFRESH_TOKEN = "refresh_token"


class IdentityMode(StrEnum):
"""How Dify forwards the end-user's identity to an MCP server."""

OFF = "off"
IDP_TOKEN = "idp_token"


class MCPAuthentication(BaseModel):
client_id: str
client_secret: str | None = None
Expand Down Expand Up @@ -76,6 +83,8 @@ class MCPProviderEntity(BaseModel):
created_at: datetime
updated_at: datetime

identity_mode: IdentityMode = IdentityMode.OFF

@classmethod
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
"""Create entity from database model with decryption"""
Expand All @@ -96,6 +105,7 @@ def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
icon=db_provider.icon or "",
created_at=db_provider.created_at,
updated_at=db_provider.updated_at,
identity_mode=IdentityMode(db_provider.identity_mode),
)

@property
Expand Down Expand Up @@ -170,6 +180,7 @@ def to_api_response(self, user_name: str | None = None, include_sensitive: bool
"updated_at": int(self.updated_at.timestamp()),
"label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
"description": I18nObject(en_US="", zh_Hans="").to_dict(),
"identity_mode": self.identity_mode,
}

# Add configuration
Expand Down
6 changes: 6 additions & 0 deletions api/core/mcp/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
provider_entity: MCPProviderEntity | None = None,
authorization_code: str | None = None,
by_server_id: bool = False,
forward_identity_active: bool = False,
):
"""
Initialize the MCP client with auth retry capability.
Expand All @@ -52,12 +53,15 @@ def __init__(
provider_entity: Provider entity for authentication
authorization_code: Optional authorization code for initial auth
by_server_id: Whether to look up provider by server ID
forward_identity_active: If True, suppress the static-OAuth retry
on 401 — the forwarded identity must propagate as-is.
"""
super().__init__(server_url, headers, timeout, sse_read_timeout)

self.provider_entity = provider_entity
self.authorization_code = authorization_code
self.by_server_id = by_server_id
self.forward_identity_active = forward_identity_active
self._has_retried = False

def _handle_auth_error(self, error: MCPAuthError) -> None:
Expand All @@ -73,6 +77,8 @@ def _handle_auth_error(self, error: MCPAuthError) -> None:
Raises:
MCPAuthError: If authentication fails or max retries reached
"""
if self.forward_identity_active:
raise error
if not self.provider_entity:
raise error
if self._has_retried:
Expand Down
6 changes: 6 additions & 0 deletions api/core/tools/entities/api_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class ToolProviderApiEntity(BaseModel):
configuration: MCPConfiguration | None = Field(
default=None, description="The timeout and sse_read_timeout of the MCP tool"
)
# M3 — user-identity forwarding selector. Round-tripped through the
# console API so the create/edit modal can hydrate the toggle state.
identity_mode: str = Field(default="off", description="Identity-forwarding mechanism: 'off' or 'idp_token'")
# Workflow
workflow_app_id: str | None = Field(default=None, description="The app id of the workflow tool")

Expand Down Expand Up @@ -92,6 +95,9 @@ def to_dict(self):
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
optional_fields.update(self.optional_field("original_headers", self.original_headers))
# M3 — forwarding selector. Always emit ("off" is a valid
# value that the UI must hydrate, not skip).
optional_fields["identity_mode"] = self.identity_mode
case ToolProviderType.WORKFLOW:
optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id))
case _:
Expand Down
7 changes: 6 additions & 1 deletion api/core/tools/mcp_tool/provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Self

from core.entities.mcp_provider import MCPProviderEntity
from core.entities.mcp_provider import IdentityMode, MCPProviderEntity
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
Expand Down Expand Up @@ -28,6 +28,7 @@ def __init__(
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
identity_mode: IdentityMode = IdentityMode.OFF,
):
super().__init__(entity)
self.entity: ToolProviderEntityWithPlugin = entity
Expand All @@ -37,6 +38,7 @@ def __init__(
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.identity_mode: IdentityMode = identity_mode

@property
def provider_type(self) -> ToolProviderType:
Expand Down Expand Up @@ -105,6 +107,7 @@ def from_entity(cls, entity: MCPProviderEntity) -> Self:
headers=entity.headers,
timeout=entity.timeout,
sse_read_timeout=entity.sse_read_timeout,
identity_mode=entity.identity_mode,
)

def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
Expand Down Expand Up @@ -134,6 +137,7 @@ def get_tool(self, tool_name: str) -> MCPTool:
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
identity_mode=self.identity_mode,
)

def get_tools(self) -> list[MCPTool]:
Expand All @@ -151,6 +155,7 @@ def get_tools(self) -> list[MCPTool]:
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
identity_mode=self.identity_mode,
)
for tool_entity in self.entity.tools
]
77 changes: 75 additions & 2 deletions api/core/tools/mcp_tool/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from collections.abc import Generator, Mapping
from typing import Any, cast

from configs import dify_config
from core.entities.mcp_provider import IdentityMode
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPConnectionError
from core.mcp.types import (
Expand All @@ -25,6 +27,11 @@

logger = logging.getLogger(__name__)

# Custom header used to carry the forwarded SSO access token. Picked to avoid
# stomping on the workspace-scoped Authorization header (provider OAuth /
# user-supplied custom credentials), which would silently break those flows.
FORWARDED_IDENTITY_HEADER = "X-Dify-SSO-Access-Token"


class MCPTool(Tool):
def __init__(
Expand All @@ -38,6 +45,7 @@ def __init__(
headers: dict[str, str] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
identity_mode: IdentityMode = IdentityMode.OFF,
):
super().__init__(entity, runtime)
self.tenant_id = tenant_id
Expand All @@ -47,6 +55,7 @@ def __init__(
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.identity_mode: IdentityMode = identity_mode
self._latest_usage = LLMUsage.empty_usage()

def tool_provider_type(self) -> ToolProviderType:
Expand All @@ -60,7 +69,7 @@ def _invoke(
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
result = self.invoke_remote_mcp_tool(tool_parameters)
result = self.invoke_remote_mcp_tool(tool_parameters, user_id=user_id, app_id=app_id)

# Extract usage metadata from MCP protocol's _meta field
self._latest_usage = self._derive_usage_from_result(result)
Expand Down Expand Up @@ -234,6 +243,7 @@ def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
headers=self.headers,
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
identity_mode=self.identity_mode,
)

def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
Expand All @@ -246,7 +256,26 @@ def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
if value is not None and not (isinstance(value, str) and value.strip() == "")
}

def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
@property
def _forwarding_requested(self) -> bool:
"""True only when the configured identity_mode wants forwarding AND
the deployment actually has the enterprise side that can mint tokens.
Non-enterprise installs treat the DB value as a no-op — a stale row
won't trigger a 5xx against a missing inner-API endpoint."""
return self.identity_mode != IdentityMode.OFF and dify_config.ENTERPRISE_ENABLED

def invoke_remote_mcp_tool(
self,
tool_parameters: dict[str, Any],
user_id: str | None = None,
app_id: str | None = None,
) -> CallToolResult:
# Fail closed: forwarding requires user_id (refuse before any DB I/O).
if self._forwarding_requested and not user_id:
raise ToolInvokeError(
"Forward-user-identity is enabled for this MCP provider but no end-user context was supplied."
)

headers = self.headers.copy() if self.headers else {}
tool_parameters = self._handle_none_parameter(tool_parameters)

Expand All @@ -271,6 +300,15 @@ def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolRes
if tokens and tokens.access_token:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"

# Forwarded identity rides in a custom header so workspace-scoped
# provider credentials (Authorization / custom Headers) keep working
# untouched. The MCP server is expected to read X-Dify-SSO-Access-Token
# when identity forwarding is configured.
forward_identity_active = False
if self._forwarding_requested and user_id:
self._inject_forwarded_identity(headers, user_id=user_id, app_id=app_id, audience=server_url)
forward_identity_active = True

# Step 2: Session is now closed, perform network operations without holding database connection
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
try:
Expand All @@ -280,9 +318,44 @@ def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolRes
timeout=self.timeout,
sse_read_timeout=self.sse_read_timeout,
provider_entity=provider_entity,
forward_identity_active=forward_identity_active,
) as mcp_client:
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPConnectionError as e:
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
except Exception as e:
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e

def _inject_forwarded_identity(
self,
headers: dict[str, str],
*,
user_id: str,
app_id: str | None,
audience: str,
) -> None:
"""Call the enterprise IssueMCPToken endpoint and stamp the issued
token into X-Dify-SSO-Access-Token.

A custom header is used (rather than Authorization) so it composes
with workspace-scoped provider credentials — the user may have OAuth
tokens or a custom Authorization header configured on the MCP
provider, and forwarding must not silently overwrite them.

Errors are surfaced as ToolInvokeError so the workflow halts with a
clear message instead of silently dropping identity and hitting the
MCP server unauthenticated.
"""
from services.enterprise.base import MCPTokenError
from services.enterprise.enterprise_service import EnterpriseService

try:
token, _expires_at = EnterpriseService.issue_mcp_token(
Comment thread
wylswz marked this conversation as resolved.
user_id=user_id,
tenant_id=self.tenant_id,
app_id=app_id,
audience=audience,
)
except MCPTokenError as e:
raise ToolInvokeError(f"Failed to obtain forwarded identity token: {e}") from e
headers[FORWARDED_IDENTITY_HEADER] = token
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""add identity mode to mcp tool provider

Revision ID: 3df4dbcc1e21
Revises: 8d4c2a1b9f03
Create Date: 2026-05-29 15:00:00.000000

Adds the `identity_mode` column to `tool_mcp_providers` to drive the M2 MCP
user-identity forwarding feature. Reserved values:

"off" — no header forwarded (default; pre-M2 behaviour).
"idp_token" — call dify-enterprise /inner/api/mcp/issue-token, stamp the
returned SSO access token on the outbound MCP request as
`X-Dify-SSO-Access-Token: <token>`.

The column is filled with the safe default "off" for existing rows so older
providers keep their current behaviour until an admin opts in.
"""

import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = "3df4dbcc1e21"
down_revision = "8d4c2a1b9f03"
branch_labels = None
depends_on = None


def upgrade():
op.add_column(
"tool_mcp_providers",
sa.Column(
"identity_mode",
sa.String(length=32),
nullable=False,
server_default=sa.text("'off'"),
),
)


def downgrade():
op.drop_column("tool_mcp_providers", "identity_mode")
Loading
Loading