Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api/controllers/common/controller_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ class WorkflowListQuery(BaseModel):
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
credential_overrides: dict[str, str] | None = Field(
default=None,
description="Optional mapping of provider name to credential ID for overriding default credentials per provider.",
)


class WorkflowUpdatePayload(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions api/controllers/console/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class BaseMessagePayload(BaseModel):
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides")


class CompletionMessagePayload(BaseMessagePayload):
Expand Down
2 changes: 2 additions & 0 deletions api/controllers/console/explore/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class CompletionMessageExplorePayload(BaseModel):
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = Field(default="explore_app")
credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides")


class ChatMessagePayload(BaseModel):
Expand All @@ -57,6 +58,7 @@ class ChatMessagePayload(BaseModel):
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = Field(default="explore_app")
credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides")

@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
Expand Down
1 change: 1 addition & 0 deletions api/controllers/openapi/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ class AppRunRequest(BaseModel):
auto_generate_name: bool = True
workflow_id: str | None = None
workspace_id: UUIDStrOrEmpty | None = None
credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides")

@field_validator("conversation_id", mode="before")
@classmethod
Expand Down
2 changes: 2 additions & 0 deletions api/controllers/service_api/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class CompletionRequestPayload(BaseModel):
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = Field(default="dev")
trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping")
credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides")


class ChatRequestPayload(BaseModel):
Expand All @@ -69,6 +70,7 @@ class ChatRequestPayload(BaseModel):
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
trace_session_id: str | None = Field(default=None, description="Trace session ID for observability grouping")
credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides")

@field_validator("conversation_id", mode="before")
@classmethod
Expand Down
2 changes: 2 additions & 0 deletions api/controllers/web/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class CompletionMessagePayload(BaseModel):
default=None, description="Response mode: blocking or streaming"
)
retriever_from: str = Field(default="web_app", description="Source of retriever")
credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides")


class ChatMessagePayload(BaseModel):
Expand All @@ -66,6 +67,7 @@ class ChatMessagePayload(BaseModel):
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
retriever_from: str = Field(default="web_app", description="Source of retriever")
credential_overrides: dict[str, str] | None = Field(default=None, description="Dynamic credential overrides")

@field_validator("conversation_id", "parent_message_id")
@classmethod
Expand Down
1 change: 1 addition & 0 deletions api/core/app/apps/workflow/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def generate(
trace_manager=trace_manager,
workflow_execution_id=workflow_run_id,
extras=extras,
credential_overrides=args.get("credential_overrides"),
)

contexts.plugin_tool_providers.set({})
Expand Down
2 changes: 2 additions & 0 deletions api/core/app/apps/workflow/app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def run(self):
invoke_from=invoke_from,
root_node_id=self._root_node_id,
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
credential_overrides=self.application_generate_entity.credential_overrides,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
Expand Down Expand Up @@ -131,6 +132,7 @@ def run(self):
invoke_from=invoke_from,
root_node_id=root_node_id,
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
credential_overrides=self.application_generate_entity.credential_overrides,
)

# RUN WORKFLOW
Expand Down
2 changes: 2 additions & 0 deletions api/core/app/apps/workflow_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _init_graph(
user_id: str = "",
root_node_id: str | None = None,
trace_session_id: str | None = None,
credential_overrides: dict[str, str] | None = None,
) -> Graph:
"""
Init graph
Expand All @@ -140,6 +141,7 @@ def _init_graph(
user_from=user_from,
invoke_from=invoke_from,
trace_session_id=trace_session_id,
credential_overrides=credential_overrides,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow_id,
Expand Down
4 changes: 4 additions & 0 deletions api/core/app/entities/app_invoke_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class DifyRunContext(BaseModel):
user_from: UserFrom
invoke_from: InvokeFrom
trace_session_id: str | None = None
credential_overrides: dict[str, str] | None = None


def build_dify_run_context(
Expand All @@ -65,6 +66,7 @@ def build_dify_run_context(
user_from: UserFrom,
invoke_from: InvokeFrom,
trace_session_id: str | None = None,
credential_overrides: dict[str, str] | None = None,
extra_context: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
"""
Expand All @@ -81,6 +83,7 @@ def build_dify_run_context(
user_from=user_from,
invoke_from=invoke_from,
trace_session_id=trace_session_id,
credential_overrides=credential_overrides,
)
return run_context

Expand Down Expand Up @@ -258,6 +261,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
# app config
app_config: WorkflowUIBasedAppConfig = None # type: ignore
workflow_execution_id: str
credential_overrides: dict[str, str] | None = None

class SingleIterationRunEntity(BaseModel):
"""
Expand Down
84 changes: 82 additions & 2 deletions api/core/app/llm/model_access.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
from __future__ import annotations

import json
import logging
from copy import deepcopy
from typing import Any

from sqlalchemy import select

from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
from core.db.session_factory import session_factory
from core.errors.error import ProviderTokenNotInitError
from core.helper import encrypter
from core.model_manager import ModelInstance, ModelManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.provider_manager import ProviderManager
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.nodes.llm.entities import ModelConfig
from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from graphon.nodes.llm.protocols import CredentialsProvider
from models.provider import ProviderCredential

logger = logging.getLogger(__name__)


class DifyCredentialsProvider:
Expand All @@ -36,8 +45,10 @@ def __init__(
*,
run_context: DifyRunContext,
provider_manager: ProviderManager | None = None,
credential_overrides: dict[str, str] | None = None,
) -> None:
self.tenant_id = run_context.tenant_id
self.credential_overrides = credential_overrides or {}
if provider_manager is None:
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
Expand All @@ -47,6 +58,21 @@ def __init__(
self.credentials_cache = {}

def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
# Check for per-run credential override for this provider
override_credential_id = self.credential_overrides.get(provider_name)
if override_credential_id:
cache_key_override = (f"__override__{provider_name}", override_credential_id)
if cache_key_override in self.credentials_cache:
return deepcopy(self.credentials_cache[cache_key_override])

credentials = self._fetch_by_credential_id(
credential_id=override_credential_id,
provider_name=provider_name,
)
self.credentials_cache[cache_key_override] = deepcopy(credentials)
return credentials

# Default resolution logic
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])

Expand All @@ -67,6 +93,52 @@ def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials

def _fetch_by_credential_id(
self,
credential_id: str,
provider_name: str,
) -> dict[str, Any]:
"""Resolve and decrypt credentials from a specific ProviderCredential record.

Security: The credential must belong to the same tenant as the current run.
"""
with session_factory.create_session() as session:
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
)
credential_record = session.scalar(stmt)

if credential_record is None:
raise ProviderTokenNotInitError(
f"Credential '{credential_id}' not found or does not belong to this workspace."
)

if credential_record.provider_name != provider_name:
raise ValueError(
f"Credential '{credential_id}' belongs to provider '{credential_record.provider_name}', "
f"but was requested for provider '{provider_name}'."
)

# Decrypt the encrypted_config JSON blob
try:
credentials: dict[str, Any] = json.loads(credential_record.encrypted_config)
except (json.JSONDecodeError, TypeError) as e:
raise ProviderTokenNotInitError(
f"Failed to parse credentials for credential '{credential_id}': {e}"
)

# Decrypt secret fields using tenant RSA key
for key, value in credentials.items():
if isinstance(value, str) and value:
try:
credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=value)
except Exception:
# Not an encrypted value or not valid base64, keep as-is
pass

return credentials


class DifyModelFactory:
tenant_id: str
Expand Down Expand Up @@ -98,7 +170,11 @@ def init_model_instance(self, provider_name: str, model_name: str) -> ModelInsta
)


def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, DifyModelFactory]:
def build_dify_model_access(
run_context: DifyRunContext,
*,
credential_overrides: dict[str, str] | None = None,
) -> tuple[CredentialsProvider, DifyModelFactory]:
"""Create LLM access adapters that share the same tenant-bound manager graph."""
provider_manager = create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
Expand All @@ -107,7 +183,11 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)

return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
DifyCredentialsProvider(
run_context=run_context,
provider_manager=provider_manager,
credential_overrides=credential_overrides,
),
DifyModelFactory(run_context=run_context, model_manager=model_manager),
)

Expand Down
5 changes: 4 additions & 1 deletion api/core/workflow/node_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@ def __init__(
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
)

self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(
self._dify_context,
credential_overrides=self._dify_context.credential_overrides,
)
self._agent_strategy_resolver = PluginAgentStrategyResolver()
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
self._agent_runtime_support = AgentRuntimeSupport()
Expand Down
Loading