diff --git a/api/controllers/common/controller_schemas.py b/api/controllers/common/controller_schemas.py index c12d57647375e6..f4215e47623c24 100644 --- a/api/controllers/common/controller_schemas.py +++ b/api/controllers/common/controller_schemas.py @@ -63,6 +63,10 @@ class WorkflowListQuery(BaseModel): class WorkflowRunPayload(BaseModel): inputs: dict[str, Any] files: list[dict[str, Any]] | None = None + credential_overrides: dict[str, str] | None = Field( + default=None, + description="Optional mapping of provider name to credential ID for overriding default credentials per provider.", + ) class WorkflowUpdatePayload(BaseModel): diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 78643ad0762af7..2837509934c0ac 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -67,6 +67,7 @@ class BaseMessagePayload(BaseModel): files: list[Any] | None = Field(default=None, description="Uploaded files") response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode") retriever_from: str = Field(default="dev", description="Retriever source") + credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides") class CompletionMessagePayload(BaseMessagePayload): diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index d1ae6526c689f6..806a56440c74fe 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -48,6 +48,7 @@ class CompletionMessageExplorePayload(BaseModel): files: list[dict[str, Any]] | None = None response_mode: Literal["blocking", "streaming"] | None = None retriever_from: str = Field(default="explore_app") + credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides") class ChatMessagePayload(BaseModel): @@ -57,6 +58,7 @@ class ChatMessagePayload(BaseModel): conversation_id: str | None = None parent_message_id: str | None = None retriever_from: str = Field(default="explore_app") + credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides") @field_validator("conversation_id", "parent_message_id", mode="before") @classmethod diff --git a/api/controllers/openapi/_models.py b/api/controllers/openapi/_models.py index 59b2e5176ea7fd..8e15b7534e900c 100644 --- a/api/controllers/openapi/_models.py +++ b/api/controllers/openapi/_models.py @@ -283,6 +283,7 @@ class AppRunRequest(BaseModel): auto_generate_name: bool = True workflow_id: str | None = None workspace_id: UUIDStrOrEmpty | None = None + credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides") @field_validator("conversation_id", mode="before") @classmethod diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index c2294a3fc1c8ec..f167bbe9d657f7 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -57,6 +57,7 @@ class CompletionRequestPayload(BaseModel): response_mode: Literal["blocking", "streaming"] | None = None retriever_from: str = Field(default="dev") trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping") + credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides") class ChatRequestPayload(BaseModel): @@ -69,6 +70,7 @@ class ChatRequestPayload(BaseModel): auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping") + credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides") @field_validator("conversation_id", mode="before") @classmethod diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index d4c02b65921cd3..ec3748318b62bb 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -54,6 +54,7 @@ class CompletionMessagePayload(BaseModel): default=None, description="Response mode: blocking or streaming" ) retriever_from: str = Field(default="web_app", description="Source of retriever") + credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides") class ChatMessagePayload(BaseModel): @@ -66,6 +67,7 @@ class ChatMessagePayload(BaseModel): conversation_id: str | None = Field(default=None, description="Conversation ID") parent_message_id: str | None = Field(default=None, description="Parent message ID") retriever_from: str = Field(default="web_app", description="Source of retriever") + credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides") @field_validator("conversation_id", "parent_message_id") @classmethod diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index fb4f94bc878929..79e2acd79f938f 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -203,6 +203,7 @@ def generate( trace_manager=trace_manager, workflow_execution_id=workflow_run_id, extras=extras, + credential_overrides=args.get("credential_overrides"), ) contexts.plugin_tool_providers.set({}) diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index ecb485885ffce9..8d7112aad239ae 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -88,6 +88,7 @@ def run(self): invoke_from=invoke_from, root_node_id=self._root_node_id, trace_session_id=self.application_generate_entity.extras.get("trace_session_id"), + credential_overrides=self.application_generate_entity.credential_overrides, ) elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( @@ -131,6 +132,7 @@ def run(self): invoke_from=invoke_from, root_node_id=root_node_id, trace_session_id=self.application_generate_entity.extras.get("trace_session_id"), + credential_overrides=self.application_generate_entity.credential_overrides, ) # RUN WORKFLOW diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 944860ee39c5df..7129f5b22a5067 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -119,6 +119,7 @@ def _init_graph( user_id: str = "", root_node_id: str | None = None, trace_session_id: str | None = None, + credential_overrides: dict[str, str] | None = None, ) -> Graph: """ Init graph @@ -140,6 +141,7 @@ def _init_graph( user_from=user_from, invoke_from=invoke_from, trace_session_id=trace_session_id, + credential_overrides=credential_overrides, ) graph_init_context = DifyGraphInitContext( workflow_id=workflow_id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 08ecc2097b3284..22e58dd8fcfffe 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -55,6 +55,7 @@ class DifyRunContext(BaseModel): user_from: UserFrom invoke_from: InvokeFrom trace_session_id: str | None = None + credential_overrides: dict[str, str] | None = None def build_dify_run_context( @@ -65,6 +66,7 @@ def build_dify_run_context( user_from: UserFrom, invoke_from: InvokeFrom, trace_session_id: str | None = None, + credential_overrides: dict[str, str] | None = None, extra_context: Mapping[str, Any] | None = None, ) -> dict[str, Any]: """ @@ -81,6 +83,7 @@ def build_dify_run_context( user_from=user_from, invoke_from=invoke_from, trace_session_id=trace_session_id, + credential_overrides=credential_overrides, ) return run_context @@ -258,6 +261,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): # app config app_config: WorkflowUIBasedAppConfig = None # type: ignore workflow_execution_id: str + credential_overrides: dict[str, str] | None = None class SingleIterationRunEntity(BaseModel): """ diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index 5631caa1a59416..f6b73f80477d05 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -1,10 +1,16 @@ from __future__ import annotations +import json +import logging from copy import deepcopy from typing import Any +from sqlalchemy import select + from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity +from core.db.session_factory import session_factory from core.errors.error import ProviderTokenNotInitError +from core.helper import encrypter from core.model_manager import ModelInstance, ModelManager from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager @@ -12,6 +18,9 @@ from graphon.nodes.llm.entities import ModelConfig from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError from graphon.nodes.llm.protocols import CredentialsProvider +from models.provider import ProviderCredential + +logger = logging.getLogger(__name__) class DifyCredentialsProvider: @@ -36,8 +45,10 @@ def __init__( *, run_context: DifyRunContext, provider_manager: ProviderManager | None = None, + credential_overrides: dict[str, str] | None = None, ) -> None: self.tenant_id = run_context.tenant_id + self.credential_overrides = credential_overrides or {} if provider_manager is None: provider_manager = create_plugin_provider_manager( tenant_id=run_context.tenant_id, @@ -47,6 +58,21 @@ def __init__( self.credentials_cache = {} def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + # Check for per-run credential override for this provider + override_credential_id = self.credential_overrides.get(provider_name) + if override_credential_id: + cache_key_override = (f"__override__{provider_name}", override_credential_id) + if cache_key_override in self.credentials_cache: + return deepcopy(self.credentials_cache[cache_key_override]) + + credentials = self._fetch_by_credential_id( + credential_id=override_credential_id, + provider_name=provider_name, + ) + self.credentials_cache[cache_key_override] = deepcopy(credentials) + return credentials + + # Default resolution logic if (provider_name, model_name) in self.credentials_cache: return deepcopy(self.credentials_cache[(provider_name, model_name)]) @@ -67,6 +93,52 @@ def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials) return credentials + def _fetch_by_credential_id( + self, + credential_id: str, + provider_name: str, + ) -> dict[str, Any]: + """Resolve and decrypt credentials from a specific ProviderCredential record. + + Security: The credential must belong to the same tenant as the current run. + """ + with session_factory.create_session() as session: + stmt = select(ProviderCredential).where( + ProviderCredential.id == credential_id, + ProviderCredential.tenant_id == self.tenant_id, + ) + credential_record = session.scalar(stmt) + + if credential_record is None: + raise ProviderTokenNotInitError( + f"Credential '{credential_id}' not found or does not belong to this workspace." + ) + + if credential_record.provider_name != provider_name: + raise ValueError( + f"Credential '{credential_id}' belongs to provider '{credential_record.provider_name}', " + f"but was requested for provider '{provider_name}'." + ) + + # Decrypt the encrypted_config JSON blob + try: + credentials: dict[str, Any] = json.loads(credential_record.encrypted_config) + except (json.JSONDecodeError, TypeError) as e: + raise ProviderTokenNotInitError( + f"Failed to parse credentials for credential '{credential_id}': {e}" + ) + + # Decrypt secret fields using tenant RSA key + for key, value in credentials.items(): + if isinstance(value, str) and value: + try: + credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=value) + except Exception: + # Not an encrypted value or not valid base64, keep as-is + pass + + return credentials + class DifyModelFactory: tenant_id: str @@ -98,7 +170,11 @@ def init_model_instance(self, provider_name: str, model_name: str) -> ModelInsta ) -def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, DifyModelFactory]: +def build_dify_model_access( + run_context: DifyRunContext, + *, + credential_overrides: dict[str, str] | None = None, +) -> tuple[CredentialsProvider, DifyModelFactory]: """Create LLM access adapters that share the same tenant-bound manager graph.""" provider_manager = create_plugin_provider_manager( tenant_id=run_context.tenant_id, @@ -107,7 +183,11 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True) return ( - DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), + DifyCredentialsProvider( + run_context=run_context, + provider_manager=provider_manager, + credential_overrides=credential_overrides, + ), DifyModelFactory(run_context=run_context, model_manager=model_manager), ) diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index d02a45a944bec5..9766bca410722f 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -346,7 +346,10 @@ def __init__( ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access( + self._dify_context, + credential_overrides=self._dify_context.credential_overrides, + ) self._agent_strategy_resolver = PluginAgentStrategyResolver() self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider() self._agent_runtime_support = AgentRuntimeSupport()