From ecb1d93955bd17a8dbef4488893fcd4e65088a58 Mon Sep 17 00:00:00 2001 From: WH-2099 Date: Tue, 19 May 2026 18:25:02 +0800 Subject: [PATCH 1/6] feat: add LLM polling protocol hooks --- .../core/entities/plugin/request.py | 19 + src/dify_plugin/core/plugin_executor.py | 69 ++++ src/dify_plugin/entities/model/llm.py | 49 ++- src/dify_plugin/entities/model/schema.py | 1 + .../interfaces/model/large_language_model.py | 148 ++++++++ src/dify_plugin/plugin.py | 16 + tests/test_model_polling.py | 325 ++++++++++++++++++ 7 files changed, 625 insertions(+), 2 deletions(-) create mode 100644 tests/test_model_polling.py diff --git a/src/dify_plugin/core/entities/plugin/request.py b/src/dify_plugin/core/entities/plugin/request.py index ea6ce963..b7426452 100644 --- a/src/dify_plugin/core/entities/plugin/request.py +++ b/src/dify_plugin/core/entities/plugin/request.py @@ -58,6 +58,8 @@ class ModelActions(StrEnum): ValidateProviderCredentials = "validate_provider_credentials" ValidateModelCredentials = "validate_model_credentials" InvokeLLM = "invoke_llm" + StartPolling = "start_polling" + CheckPolling = "check_polling" GetLLMNumTokens = "get_llm_num_tokens" InvokeTextEmbedding = "invoke_text_embedding" InvokeMultimodalEmbedding = "invoke_multimodal_embedding" @@ -193,11 +195,28 @@ class ModelInvokeLLMRequest(PluginAccessModelRequest, PromptMessageMixin): model_parameters: dict[str, Any] stop: list[str] | None tools: list[PromptMessageTool] | None + json_schema: dict[str, Any] | None = None stream: bool = True model_config = ConfigDict(protected_namespaces=()) +class ModelStartPollingRequest(ModelInvokeLLMRequest): + action: ModelActions = ModelActions.StartPolling + stream: bool = False + + workflow_run_id: str + node_id: str + + +class ModelCheckPollingRequest(PluginAccessModelRequest): + action: ModelActions = ModelActions.CheckPolling + + workflow_run_id: str + node_id: str + plugin_state: dict[str, Any] + + class ModelGetLLMNumTokens(PluginAccessModelRequest, PromptMessageMixin): action: ModelActions = ModelActions.GetLLMNumTokens diff --git a/src/dify_plugin/core/plugin_executor.py b/src/dify_plugin/core/plugin_executor.py index 8e4a48d1..6ae39f19 100644 --- a/src/dify_plugin/core/plugin_executor.py +++ b/src/dify_plugin/core/plugin_executor.py @@ -17,6 +17,7 @@ DatasourceValidateCredentialsRequest, DynamicParameterFetchParameterOptionsRequest, EndpointInvokeRequest, + ModelCheckPollingRequest, ModelGetAIModelSchemas, ModelGetLLMNumTokens, ModelGetTextEmbeddingNumTokens, @@ -29,6 +30,7 @@ ModelInvokeSpeech2TextRequest, ModelInvokeTextEmbeddingRequest, ModelInvokeTTSRequest, + ModelStartPollingRequest, ModelValidateModelCredentialsRequest, ModelValidateProviderCredentialsRequest, OAuthGetAuthorizationUrlRequest, @@ -261,6 +263,73 @@ def invoke_llm(self, session: Session, data: ModelInvokeLLMRequest) -> object: msg, ) + def start_llm_polling( + self, + session: Session, + data: ModelStartPollingRequest, + ) -> object: + del session + model_instance = self.registration.get_model_instance( + data.provider, + data.model_type, + ) + if isinstance(model_instance, LargeLanguageModel): + if not model_instance.supports_polling(data.model, data.credentials): + msg = ( + f"Model `{data.model}` for provider `{data.provider}` " + "does not support polling" + ) + raise ValueError(msg) + + return model_instance.start_polling( + model=data.model, + credentials=data.credentials, + prompt_messages=data.prompt_messages, + model_parameters=data.model_parameters, + tools=data.tools, + stop=data.stop, + stream=data.stream, + user=data.user_id, + json_schema=data.json_schema, + workflow_run_id=data.workflow_run_id, + node_id=data.node_id, + ) + msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" + raise ValueError( + msg, + ) + + def check_llm_polling( + self, + session: Session, + data: ModelCheckPollingRequest, + ) -> object: + del session + model_instance = self.registration.get_model_instance( + data.provider, + data.model_type, + ) + if isinstance(model_instance, LargeLanguageModel): + if not model_instance.supports_polling(data.model, data.credentials): + msg = ( + f"Model `{data.model}` for provider `{data.provider}` " + "does not support polling" + ) + raise ValueError(msg) + + return model_instance.check_polling( + model=data.model, + credentials=data.credentials, + plugin_state=data.plugin_state, + user=data.user_id, + workflow_run_id=data.workflow_run_id, + node_id=data.node_id, + ) + msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" + raise ValueError( + msg, + ) + def get_llm_num_tokens( self, session: Session, diff --git a/src/dify_plugin/entities/model/llm.py b/src/dify_plugin/entities/model/llm.py index 74a54a14..c4cd47ab 100644 --- a/src/dify_plugin/entities/model/llm.py +++ b/src/dify_plugin/entities/model/llm.py @@ -1,8 +1,9 @@ from collections.abc import Mapping from decimal import Decimal -from enum import Enum +from enum import Enum, StrEnum +from typing import Any -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from dify_plugin.entities.model import BaseModelConfig, ModelType, ModelUsage, PriceInfo from dify_plugin.entities.model.message import ( @@ -37,6 +38,12 @@ def value_of(cls, value: str) -> "LLMMode": raise ValueError(msg) +class LLMPollingStatus(StrEnum): + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + + class LLMUsage(ModelUsage): """Model class for llm usage.""" @@ -174,6 +181,44 @@ def to_llm_result_chunk_with_structured_output( ) +class LLMPollingResult(BaseModel): + """Model class for llm polling result.""" + + status: LLMPollingStatus + plugin_state: dict[str, Any] | None = None + result: LLMResult | LLMResultWithStructuredOutput | None = None + error: str | None = None + next_check_after_seconds: int | None = None + expires_after_seconds: int | None = None + max_attempts: int | None = None + + @model_validator(mode="after") + def validate_status_payload(self) -> "LLMPollingResult": + if self.status is LLMPollingStatus.RUNNING and self.plugin_state is None: + msg = "plugin_state is required when polling status is running." + raise ValueError(msg) + + if self.status is LLMPollingStatus.SUCCEEDED and self.result is None: + msg = "result is required when polling status is succeeded." + raise ValueError(msg) + + if self.status is LLMPollingStatus.FAILED and not self.error: + msg = "error is required when polling status is failed." + raise ValueError(msg) + + for field_name in ( + "next_check_after_seconds", + "expires_after_seconds", + "max_attempts", + ): + value = getattr(self, field_name) + if value is not None and value <= 0: + msg = f"{field_name} must be greater than 0." + raise ValueError(msg) + + return self + + class SummaryResult(BaseModel): """Model class for summary result.""" diff --git a/src/dify_plugin/entities/model/schema.py b/src/dify_plugin/entities/model/schema.py index c3825e42..a5b6ce47 100644 --- a/src/dify_plugin/entities/model/schema.py +++ b/src/dify_plugin/entities/model/schema.py @@ -232,6 +232,7 @@ class ModelFeature(Enum): VIDEO = "video" AUDIO = "audio" STRUCTURED_OUTPUT = "structured-output" + POLLING = "polling" @docs( diff --git a/src/dify_plugin/interfaces/model/large_language_model.py b/src/dify_plugin/interfaces/model/large_language_model.py index 9f6d52ea..70a925a6 100644 --- a/src/dify_plugin/interfaces/model/large_language_model.py +++ b/src/dify_plugin/interfaces/model/large_language_model.py @@ -1,12 +1,15 @@ +import inspect import logging import re import time from abc import abstractmethod from collections.abc import Generator, Mapping +from typing import Any from pydantic import ConfigDict from dify_plugin.entities.model import ( + ModelFeature, ModelPropertyKey, ModelType, ParameterRule, @@ -15,6 +18,7 @@ ) from dify_plugin.entities.model.llm import ( LLMMode, + LLMPollingResult, LLMResult, LLMResultChunk, LLMResultChunkDelta, @@ -74,6 +78,51 @@ def _invoke( """ raise NotImplementedError + def _start_polling( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = False, + user: str | None = None, + *, + workflow_run_id: str, + node_id: str, + json_schema: dict[str, Any] | None = None, + ) -> LLMPollingResult: + """Start a polling-based large language model invocation.""" + del ( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + workflow_run_id, + node_id, + json_schema, + ) + raise NotImplementedError + + def _check_polling( + self, + model: str, + credentials: dict, + plugin_state: dict[str, Any], + user: str | None = None, + *, + workflow_run_id: str, + node_id: str, + ) -> LLMPollingResult: + """Check a polling-based large language model invocation.""" + del model, credentials, plugin_state, user, workflow_run_id, node_id + raise NotImplementedError + @abstractmethod def get_num_tokens( self, @@ -136,6 +185,30 @@ def get_model_mode(self, model: str, credentials: Mapping | None = None) -> LLMM return mode + def supports_polling(self, model: str, credentials: Mapping | None = None) -> bool: + model_schema = self.get_model_schema(model, credentials) + has_feature = bool( + model_schema + and model_schema.features + and ModelFeature.POLLING in model_schema.features + ) + base_start_polling = inspect.getattr_static( + LargeLanguageModel, + "_start_polling", + ) + base_check_polling = inspect.getattr_static( + LargeLanguageModel, + "_check_polling", + ) + start_polling = inspect.getattr_static(type(self), "_start_polling") + check_polling = inspect.getattr_static(type(self), "_check_polling") + has_methods = ( + start_polling is not base_start_polling + and check_polling is not base_check_polling + ) + + return has_feature and has_methods + def _calc_response_usage( self, model: str, @@ -685,6 +758,81 @@ def _wrap_thinking_by_reasoning_content( # For executor use only # ############################################################ + def start_polling( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = False, + user: str | None = None, + json_schema: dict[str, Any] | None = None, + *, + workflow_run_id: str, + node_id: str, + ) -> LLMPollingResult: + """Start a polling-based large language model invocation.""" + if not self.supports_polling(model, credentials): + msg = f"Model `{model}` does not support polling." + raise NotImplementedError(msg) + + if model_parameters is None: + model_parameters = {} + + model_parameters = self._validate_and_filter_model_parameters( + model, + model_parameters, + credentials, + ) + + with self.timing_context(): + try: + return self._start_polling( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + workflow_run_id=workflow_run_id, + node_id=node_id, + json_schema=json_schema, + ) + except Exception as e: + raise self._transform_invoke_error(e) from e + + def check_polling( + self, + model: str, + credentials: dict, + plugin_state: dict[str, Any], + user: str | None = None, + *, + workflow_run_id: str, + node_id: str, + ) -> LLMPollingResult: + """Check a polling-based large language model invocation.""" + if not self.supports_polling(model, credentials): + msg = f"Model `{model}` does not support polling." + raise NotImplementedError(msg) + + with self.timing_context(): + try: + return self._check_polling( + model=model, + credentials=credentials, + plugin_state=plugin_state, + user=user, + workflow_run_id=workflow_run_id, + node_id=node_id, + ) + except Exception as e: + raise self._transform_invoke_error(e) from e + def invoke( self, model: str, diff --git a/src/dify_plugin/plugin.py b/src/dify_plugin/plugin.py index 74e5d798..2fda3307 100644 --- a/src/dify_plugin/plugin.py +++ b/src/dify_plugin/plugin.py @@ -277,6 +277,22 @@ def _register_request_routes(self) -> None: ), ) + self.register_route( + self.plugin_executer.start_llm_polling, + lambda data: ( + data.get("type") == PluginInvokeType.Model.value + and data.get("action") == ModelActions.StartPolling.value + ), + ) + + self.register_route( + self.plugin_executer.check_llm_polling, + lambda data: ( + data.get("type") == PluginInvokeType.Model.value + and data.get("action") == ModelActions.CheckPolling.value + ), + ) + self.register_route( self.plugin_executer.get_llm_num_tokens, lambda data: ( diff --git a/tests/test_model_polling.py b/tests/test_model_polling.py new file mode 100644 index 00000000..31928d1e --- /dev/null +++ b/tests/test_model_polling.py @@ -0,0 +1,325 @@ +from collections.abc import Generator, Mapping +from typing import Any + +import pytest + +from dify_plugin.config.config import DifyPluginEnv +from dify_plugin.core.entities.plugin.request import ( + ModelActions, + ModelCheckPollingRequest, + ModelStartPollingRequest, +) +from dify_plugin.core.plugin_executor import PluginExecutor +from dify_plugin.core.runtime import Session +from dify_plugin.entities import I18nObject +from dify_plugin.entities.model import AIModelEntity, FetchFrom, ModelFeature, ModelType +from dify_plugin.entities.model.llm import ( + LLMPollingResult, + LLMPollingStatus, + LLMResult, + LLMResultChunk, + LLMUsage, +) +from dify_plugin.entities.model.message import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + UserPromptMessage, +) +from dify_plugin.errors.model import InvokeError +from dify_plugin.interfaces.model.large_language_model import LargeLanguageModel + + +class ModelRegistration: + def __init__(self, model_instance: LargeLanguageModel) -> None: + self.model_instance = model_instance + self.provider: str | None = None + self.model_type: ModelType | None = None + + def get_model_instance( + self, + provider: str, + model_type: ModelType, + ) -> LargeLanguageModel: + self.provider = provider + self.model_type = model_type + return self.model_instance + + +class PollingLLM(LargeLanguageModel): + model_type = ModelType.LLM + + def __init__(self) -> None: + super().__init__( + model_schemas=[ + AIModelEntity( + model="llm", + label=I18nObject(en_us="llm"), + model_type=ModelType.LLM, + features=[ModelFeature.POLLING], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + parameter_rules=[], + ), + ], + ) + self.start_call: dict[str, Any] | None = None + self.check_call: dict[str, Any] | None = None + + def validate_credentials(self, model: str, credentials: Mapping) -> None: + del model, credentials + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return {} + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: + del ( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) + return _llm_result("done") + + def _start_polling( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = False, + user: str | None = None, + *, + workflow_run_id: str, + node_id: str, + json_schema: dict[str, Any] | None = None, + ) -> LLMPollingResult: + self.start_call = { + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + "workflow_run_id": workflow_run_id, + "node_id": node_id, + "json_schema": json_schema, + } + return LLMPollingResult( + status=LLMPollingStatus.RUNNING, + plugin_state={"job_id": "job-1"}, + next_check_after_seconds=15, + expires_after_seconds=1800, + max_attempts=60, + ) + + def _check_polling( + self, + model: str, + credentials: dict, + plugin_state: dict[str, Any], + user: str | None = None, + *, + workflow_run_id: str, + node_id: str, + ) -> LLMPollingResult: + self.check_call = { + "model": model, + "credentials": credentials, + "plugin_state": plugin_state, + "user": user, + "workflow_run_id": workflow_run_id, + "node_id": node_id, + } + return LLMPollingResult( + status=LLMPollingStatus.SUCCEEDED, + result=_llm_result("done"), + ) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: + del model, credentials, prompt_messages, tools + return 0 + + +class NonPollingLLM(PollingLLM): + def __init__(self) -> None: + super().__init__() + self.model_schemas[0].features = [] + + +def test_polling_requests_parse_daemon_payloads() -> None: + start_request = ModelStartPollingRequest( + user_id="user-1", + provider="provider", + model_type=ModelType.LLM, + model="llm", + credentials={"api_key": "key"}, + prompt_messages=[{"role": "user", "content": "hello"}], + model_parameters={}, + stop=[], + tools=[], + json_schema={"type": "object"}, + workflow_run_id="wr-1", + node_id="node-1", + ) + assert start_request.action == ModelActions.StartPolling + assert start_request.stream is False + assert isinstance(start_request.prompt_messages[0], UserPromptMessage) + assert start_request.json_schema == {"type": "object"} + + check_request = ModelCheckPollingRequest( + user_id="user-1", + provider="provider", + model_type=ModelType.LLM, + model="llm", + credentials={"api_key": "key"}, + workflow_run_id="wr-1", + node_id="node-1", + plugin_state={"job_id": "job-1"}, + ) + assert check_request.action == ModelActions.CheckPolling + assert check_request.plugin_state == {"job_id": "job-1"} + + +def test_executor_starts_llm_polling() -> None: + model = PollingLLM() + executor = PluginExecutor(DifyPluginEnv(), ModelRegistration(model)) + + response = executor.start_llm_polling( + Session.empty_session(), + ModelStartPollingRequest( + user_id="user-1", + provider="provider", + model_type=ModelType.LLM, + model="llm", + credentials={"api_key": "key"}, + prompt_messages=[UserPromptMessage(content="hello")], + model_parameters={"temperature": 0.2}, + stop=[], + tools=[], + json_schema={"type": "object"}, + workflow_run_id="wr-1", + node_id="node-1", + ), + ) + + assert isinstance(response, LLMPollingResult) + assert response.status == LLMPollingStatus.RUNNING + assert response.plugin_state == {"job_id": "job-1"} + assert response.next_check_after_seconds == 15 + assert response.expires_after_seconds == 1800 + assert response.max_attempts == 60 + assert model.start_call is not None + assert model.supports_polling("llm", {"api_key": "key"}) + assert model.start_call["workflow_run_id"] == "wr-1" + assert model.start_call["node_id"] == "node-1" + assert model.start_call["json_schema"] == {"type": "object"} + assert model.start_call["model_parameters"] == {} + + +def test_executor_checks_llm_polling() -> None: + model = PollingLLM() + executor = PluginExecutor(DifyPluginEnv(), ModelRegistration(model)) + + response = executor.check_llm_polling( + Session.empty_session(), + ModelCheckPollingRequest( + user_id="user-1", + provider="provider", + model_type=ModelType.LLM, + model="llm", + credentials={"api_key": "key"}, + workflow_run_id="wr-1", + node_id="node-1", + plugin_state={"job_id": "job-1"}, + ), + ) + + assert isinstance(response, LLMPollingResult) + assert response.status == LLMPollingStatus.SUCCEEDED + assert response.result is not None + assert response.result.message.content == "done" + assert model.check_call is not None + assert model.check_call["plugin_state"] == {"job_id": "job-1"} + assert model.check_call["workflow_run_id"] == "wr-1" + assert model.check_call["node_id"] == "node-1" + + +def test_executor_rejects_llm_without_polling_feature() -> None: + model = NonPollingLLM() + executor = PluginExecutor(DifyPluginEnv(), ModelRegistration(model)) + + with pytest.raises(ValueError, match="does not support polling"): + executor.start_llm_polling( + Session.empty_session(), + ModelStartPollingRequest( + user_id="user-1", + provider="provider", + model_type=ModelType.LLM, + model="llm", + credentials={"api_key": "key"}, + prompt_messages=[UserPromptMessage(content="hello")], + model_parameters={}, + stop=[], + tools=[], + workflow_run_id="wr-1", + node_id="node-1", + ), + ) + + +def test_polling_result_validates_state_payloads() -> None: + with pytest.raises(ValueError, match="plugin_state is required"): + LLMPollingResult(status=LLMPollingStatus.RUNNING) + + with pytest.raises(ValueError, match="result is required"): + LLMPollingResult(status=LLMPollingStatus.SUCCEEDED) + + with pytest.raises(ValueError, match="error is required"): + LLMPollingResult(status=LLMPollingStatus.FAILED) + + +@pytest.mark.parametrize( + "field_name", + ["next_check_after_seconds", "expires_after_seconds", "max_attempts"], +) +def test_polling_result_rejects_non_positive_limits(field_name: str) -> None: + with pytest.raises(ValueError, match=f"{field_name} must be greater than 0"): + LLMPollingResult( + status=LLMPollingStatus.RUNNING, + plugin_state={"job_id": "job-1"}, + **{field_name: 0}, + ) + + +def _llm_result(content: str) -> LLMResult: + return LLMResult( + model="llm", + message=AssistantPromptMessage(content=content), + usage=LLMUsage.empty_usage(), + ) From cb8cfb60360dec6e8934ba11a06cc9615d5beae2 Mon Sep 17 00:00:00 2001 From: WH-2099 Date: Tue, 19 May 2026 18:30:57 +0800 Subject: [PATCH 2/6] chore: bump SDK version to 0.9.0 --- README.md | 14 ++++++++++++++ pyproject.toml | 2 +- uv.lock | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d2f3e944..917f5e5b 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,19 @@ just test # Run all tests just build # Build source and wheel distributions ``` +## LLM Polling Support + +SDK 0.9.0 adds protocol hooks for polling-based LLM providers. A model can now +declare the `polling` feature and implement polling hooks, allowing plugins to +submit long-running provider jobs and return later checks through a short +request/response flow. + +Polling results use three states: + +- `running` returns plugin-owned state for the next check. +- `succeeded` returns the final LLM result. +- `failed` returns a terminal error. + ## Version Management This SDK follows Semantic Versioning (a.b.c): @@ -73,3 +86,4 @@ For the manifest specification, we've introduced two versioning fields: | 1.10.0 | 0.6.0 | Support Trigger functionality for plugins | | 1.11.0 | 0.7.0 | Support Multimodal Reranking / Embeddings | | 1.14.0 | 0.8.1 | Dependency and project structure cleanup | +| 1.14.0 | 0.9.0 | Support polling-based LLM plugin protocol hooks | diff --git a/pyproject.toml b/pyproject.toml index 5f3a5864..be91cfb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = 'dify_plugin' -version = '0.8.1' +version = '0.9.0' description = 'Dify Plugin SDK' authors = [{ name = 'langgenius', email = 'hello@dify.ai' }] dependencies = [ diff --git a/uv.lock b/uv.lock index 3cd7f721..914b1641 100644 --- a/uv.lock +++ b/uv.lock @@ -245,7 +245,7 @@ wheels = [ [[package]] name = "dify-plugin" -version = "0.8.1" +version = "0.9.0" source = { editable = "." } dependencies = [ { name = "dpkt" }, From 5e320addffbbbf7cbbe0223d8ca373ac3b526352 Mon Sep 17 00:00:00 2001 From: WH-2099 Date: Tue, 19 May 2026 18:33:28 +0800 Subject: [PATCH 3/6] docs: clarify LLM polling support wording --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 917f5e5b..125d32a0 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,9 @@ just build # Build source and wheel distributions ## LLM Polling Support -SDK 0.9.0 adds protocol hooks for polling-based LLM providers. A model can now -declare the `polling` feature and implement polling hooks, allowing plugins to -submit long-running provider jobs and return later checks through a short +SDK 0.9.0 adds support for polling-based LLM invocation. A model can now +declare the `polling` feature and implement polling methods, allowing plugins +to submit long-running provider jobs and return later checks through a short request/response flow. Polling results use three states: @@ -86,4 +86,4 @@ For the manifest specification, we've introduced two versioning fields: | 1.10.0 | 0.6.0 | Support Trigger functionality for plugins | | 1.11.0 | 0.7.0 | Support Multimodal Reranking / Embeddings | | 1.14.0 | 0.8.1 | Dependency and project structure cleanup | -| 1.14.0 | 0.9.0 | Support polling-based LLM plugin protocol hooks | +| 1.14.0 | 0.9.0 | Support polling-based LLM plugin invocations | From bf4c952773f8bfb95a69b0d57726db48f68042c5 Mon Sep 17 00:00:00 2001 From: WH-2099 Date: Tue, 19 May 2026 18:36:53 +0800 Subject: [PATCH 4/6] fix: compare polling status by value --- src/dify_plugin/entities/model/llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dify_plugin/entities/model/llm.py b/src/dify_plugin/entities/model/llm.py index c4cd47ab..2e94beb6 100644 --- a/src/dify_plugin/entities/model/llm.py +++ b/src/dify_plugin/entities/model/llm.py @@ -194,15 +194,15 @@ class LLMPollingResult(BaseModel): @model_validator(mode="after") def validate_status_payload(self) -> "LLMPollingResult": - if self.status is LLMPollingStatus.RUNNING and self.plugin_state is None: + if self.status == LLMPollingStatus.RUNNING and self.plugin_state is None: msg = "plugin_state is required when polling status is running." raise ValueError(msg) - if self.status is LLMPollingStatus.SUCCEEDED and self.result is None: + if self.status == LLMPollingStatus.SUCCEEDED and self.result is None: msg = "result is required when polling status is succeeded." raise ValueError(msg) - if self.status is LLMPollingStatus.FAILED and not self.error: + if self.status == LLMPollingStatus.FAILED and not self.error: msg = "error is required when polling status is failed." raise ValueError(msg) From 7a011f49f3810a1b0b06c8db9367b317590e2cc3 Mon Sep 17 00:00:00 2001 From: WH-2099 Date: Wed, 20 May 2026 12:09:00 +0800 Subject: [PATCH 5/6] fix: address llm polling review feedback --- .../core/entities/plugin/request.py | 10 +- src/dify_plugin/core/plugin_executor.py | 80 +++--- src/dify_plugin/entities/model/llm.py | 29 +- .../interfaces/model/large_language_model.py | 16 +- tests/test_model_polling.py | 271 +++++++++++------- 5 files changed, 228 insertions(+), 178 deletions(-) diff --git a/src/dify_plugin/core/entities/plugin/request.py b/src/dify_plugin/core/entities/plugin/request.py index b7426452..01ff65d2 100644 --- a/src/dify_plugin/core/entities/plugin/request.py +++ b/src/dify_plugin/core/entities/plugin/request.py @@ -1,8 +1,8 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any +from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator from dify_plugin.entities.datasource import ( GetOnlineDocumentPageContentRequest, @@ -195,7 +195,7 @@ class ModelInvokeLLMRequest(PluginAccessModelRequest, PromptMessageMixin): model_parameters: dict[str, Any] stop: list[str] | None tools: list[PromptMessageTool] | None - json_schema: dict[str, Any] | None = None + json_schema: dict[str, JsonValue] | None = None stream: bool = True model_config = ConfigDict(protected_namespaces=()) @@ -203,7 +203,7 @@ class ModelInvokeLLMRequest(PluginAccessModelRequest, PromptMessageMixin): class ModelStartPollingRequest(ModelInvokeLLMRequest): action: ModelActions = ModelActions.StartPolling - stream: bool = False + stream: Literal[False] = False workflow_run_id: str node_id: str @@ -214,7 +214,7 @@ class ModelCheckPollingRequest(PluginAccessModelRequest): workflow_run_id: str node_id: str - plugin_state: dict[str, Any] + plugin_state: dict[str, JsonValue] class ModelGetLLMNumTokens(PluginAccessModelRequest, PromptMessageMixin): diff --git a/src/dify_plugin/core/plugin_executor.py b/src/dify_plugin/core/plugin_executor.py index 6ae39f19..55594a2b 100644 --- a/src/dify_plugin/core/plugin_executor.py +++ b/src/dify_plugin/core/plugin_executor.py @@ -273,30 +273,31 @@ def start_llm_polling( data.provider, data.model_type, ) - if isinstance(model_instance, LargeLanguageModel): - if not model_instance.supports_polling(data.model, data.credentials): - msg = ( - f"Model `{data.model}` for provider `{data.provider}` " - "does not support polling" - ) - raise ValueError(msg) + if not isinstance(model_instance, LargeLanguageModel): + msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" + raise TypeError( + msg, + ) - return model_instance.start_polling( - model=data.model, - credentials=data.credentials, - prompt_messages=data.prompt_messages, - model_parameters=data.model_parameters, - tools=data.tools, - stop=data.stop, - stream=data.stream, - user=data.user_id, - json_schema=data.json_schema, - workflow_run_id=data.workflow_run_id, - node_id=data.node_id, + if not model_instance.supports_polling(data.model, data.credentials): + msg = ( + f"Model `{data.model}` for provider `{data.provider}` " + "does not support polling" ) - msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" - raise ValueError( - msg, + raise ValueError(msg) + + return model_instance.start_polling( + model=data.model, + credentials=data.credentials, + prompt_messages=data.prompt_messages, + model_parameters=data.model_parameters, + tools=data.tools, + stop=data.stop, + stream=data.stream, + user=data.user_id, + json_schema=data.json_schema, + workflow_run_id=data.workflow_run_id, + node_id=data.node_id, ) def check_llm_polling( @@ -309,25 +310,26 @@ def check_llm_polling( data.provider, data.model_type, ) - if isinstance(model_instance, LargeLanguageModel): - if not model_instance.supports_polling(data.model, data.credentials): - msg = ( - f"Model `{data.model}` for provider `{data.provider}` " - "does not support polling" - ) - raise ValueError(msg) + if not isinstance(model_instance, LargeLanguageModel): + msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" + raise TypeError( + msg, + ) - return model_instance.check_polling( - model=data.model, - credentials=data.credentials, - plugin_state=data.plugin_state, - user=data.user_id, - workflow_run_id=data.workflow_run_id, - node_id=data.node_id, + if not model_instance.supports_polling(data.model, data.credentials): + msg = ( + f"Model `{data.model}` for provider `{data.provider}` " + "does not support polling" ) - msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" - raise ValueError( - msg, + raise ValueError(msg) + + return model_instance.check_polling( + model=data.model, + credentials=data.credentials, + plugin_state=data.plugin_state, + user=data.user_id, + workflow_run_id=data.workflow_run_id, + node_id=data.node_id, ) def get_llm_num_tokens( diff --git a/src/dify_plugin/entities/model/llm.py b/src/dify_plugin/entities/model/llm.py index 2e94beb6..e84a8247 100644 --- a/src/dify_plugin/entities/model/llm.py +++ b/src/dify_plugin/entities/model/llm.py @@ -1,9 +1,16 @@ from collections.abc import Mapping from decimal import Decimal from enum import Enum, StrEnum -from typing import Any -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + JsonValue, + PositiveInt, + field_validator, + model_validator, +) from dify_plugin.entities.model import BaseModelConfig, ModelType, ModelUsage, PriceInfo from dify_plugin.entities.model.message import ( @@ -185,12 +192,12 @@ class LLMPollingResult(BaseModel): """Model class for llm polling result.""" status: LLMPollingStatus - plugin_state: dict[str, Any] | None = None + plugin_state: dict[str, JsonValue] | None = None result: LLMResult | LLMResultWithStructuredOutput | None = None error: str | None = None - next_check_after_seconds: int | None = None - expires_after_seconds: int | None = None - max_attempts: int | None = None + next_check_after_seconds: PositiveInt | None = None + expires_after_seconds: PositiveInt | None = None + max_attempts: PositiveInt | None = None @model_validator(mode="after") def validate_status_payload(self) -> "LLMPollingResult": @@ -206,16 +213,6 @@ def validate_status_payload(self) -> "LLMPollingResult": msg = "error is required when polling status is failed." raise ValueError(msg) - for field_name in ( - "next_check_after_seconds", - "expires_after_seconds", - "max_attempts", - ): - value = getattr(self, field_name) - if value is not None and value <= 0: - msg = f"{field_name} must be greater than 0." - raise ValueError(msg) - return self diff --git a/src/dify_plugin/interfaces/model/large_language_model.py b/src/dify_plugin/interfaces/model/large_language_model.py index 70a925a6..536b25fb 100644 --- a/src/dify_plugin/interfaces/model/large_language_model.py +++ b/src/dify_plugin/interfaces/model/large_language_model.py @@ -4,9 +4,9 @@ import time from abc import abstractmethod from collections.abc import Generator, Mapping -from typing import Any +from typing import Literal -from pydantic import ConfigDict +from pydantic import ConfigDict, JsonValue from dify_plugin.entities.model import ( ModelFeature, @@ -86,12 +86,12 @@ def _start_polling( model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = False, + stream: Literal[False] = False, user: str | None = None, *, workflow_run_id: str, node_id: str, - json_schema: dict[str, Any] | None = None, + json_schema: dict[str, JsonValue] | None = None, ) -> LLMPollingResult: """Start a polling-based large language model invocation.""" del ( @@ -113,7 +113,7 @@ def _check_polling( self, model: str, credentials: dict, - plugin_state: dict[str, Any], + plugin_state: dict[str, JsonValue], user: str | None = None, *, workflow_run_id: str, @@ -766,9 +766,9 @@ def start_polling( model_parameters: dict | None = None, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = False, + stream: Literal[False] = False, user: str | None = None, - json_schema: dict[str, Any] | None = None, + json_schema: dict[str, JsonValue] | None = None, *, workflow_run_id: str, node_id: str, @@ -809,7 +809,7 @@ def check_polling( self, model: str, credentials: dict, - plugin_state: dict[str, Any], + plugin_state: dict[str, JsonValue], user: str | None = None, *, workflow_run_id: str, diff --git a/tests/test_model_polling.py b/tests/test_model_polling.py index 31928d1e..3e1cfc5e 100644 --- a/tests/test_model_polling.py +++ b/tests/test_model_polling.py @@ -1,7 +1,9 @@ from collections.abc import Generator, Mapping -from typing import Any +from dataclasses import dataclass +from typing import Any, Literal import pytest +from pydantic import JsonValue from dify_plugin.config.config import DifyPluginEnv from dify_plugin.core.entities.plugin.request import ( @@ -30,6 +32,105 @@ from dify_plugin.interfaces.model.large_language_model import LargeLanguageModel +@dataclass(frozen=True) +class PollingScenario: + user_id: str = "user-1" + provider: str = "provider" + model: str = "llm" + api_key: str = "key" + workflow_run_id: str = "wr-1" + node_id: str = "node-1" + job_id: str = "job-1" + prompt_content: str = "hello" + result_content: str = "done" + next_check_after_seconds: int = 15 + expires_after_seconds: int = 1800 + max_attempts: int = 60 + + @property + def credentials(self) -> dict[str, str]: + return {"api_key": self.api_key} + + @property + def json_schema(self) -> dict[str, JsonValue]: + return {"type": "object"} + + @property + def plugin_state(self) -> dict[str, JsonValue]: + return {"job_id": self.job_id} + + @property + def daemon_prompt_messages(self) -> list[dict[str, str]]: + return [{"role": "user", "content": self.prompt_content}] + + @property + def prompt_messages(self) -> list[UserPromptMessage]: + return [UserPromptMessage(content=self.prompt_content)] + + def model_entity(self) -> AIModelEntity: + return AIModelEntity( + model=self.model, + label=I18nObject(en_us=self.model), + model_type=ModelType.LLM, + features=[ModelFeature.POLLING], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + parameter_rules=[], + ) + + def start_request( + self, + *, + prompt_messages: object | None = None, + model_parameters: dict[str, object] | None = None, + json_schema: dict[str, JsonValue] | None = None, + stream: bool | None = None, + ) -> ModelStartPollingRequest: + data: dict[str, object] = { + "user_id": self.user_id, + "provider": self.provider, + "model_type": ModelType.LLM, + "model": self.model, + "credentials": self.credentials, + "prompt_messages": prompt_messages or self.prompt_messages, + "model_parameters": model_parameters or {}, + "stop": [], + "tools": [], + "workflow_run_id": self.workflow_run_id, + "node_id": self.node_id, + } + if json_schema is not None: + data["json_schema"] = json_schema + if stream is not None: + data["stream"] = stream + + return ModelStartPollingRequest(**data) + + def check_request( + self, + *, + plugin_state: dict[str, JsonValue] | None = None, + ) -> ModelCheckPollingRequest: + data: dict[str, object] = { + "user_id": self.user_id, + "provider": self.provider, + "model_type": ModelType.LLM, + "model": self.model, + "credentials": self.credentials, + "workflow_run_id": self.workflow_run_id, + "node_id": self.node_id, + "plugin_state": plugin_state or self.plugin_state, + } + return ModelCheckPollingRequest(**data) + + def llm_result(self, content: str | None = None) -> LLMResult: + return LLMResult( + model=self.model, + message=AssistantPromptMessage(content=content or self.result_content), + usage=LLMUsage.empty_usage(), + ) + + class ModelRegistration: def __init__(self, model_instance: LargeLanguageModel) -> None: self.model_instance = model_instance @@ -49,19 +150,10 @@ def get_model_instance( class PollingLLM(LargeLanguageModel): model_type = ModelType.LLM - def __init__(self) -> None: + def __init__(self, scenario: PollingScenario | None = None) -> None: + self.scenario = scenario or PollingScenario() super().__init__( - model_schemas=[ - AIModelEntity( - model="llm", - label=I18nObject(en_us="llm"), - model_type=ModelType.LLM, - features=[ModelFeature.POLLING], - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties={}, - parameter_rules=[], - ), - ], + model_schemas=[self.scenario.model_entity()], ) self.start_call: dict[str, Any] | None = None self.check_call: dict[str, Any] | None = None @@ -94,7 +186,7 @@ def _invoke( stream, user, ) - return _llm_result("done") + return self.scenario.llm_result() def _start_polling( self, @@ -104,12 +196,12 @@ def _start_polling( model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = False, + stream: Literal[False] = False, user: str | None = None, *, workflow_run_id: str, node_id: str, - json_schema: dict[str, Any] | None = None, + json_schema: dict[str, JsonValue] | None = None, ) -> LLMPollingResult: self.start_call = { "model": model, @@ -126,17 +218,17 @@ def _start_polling( } return LLMPollingResult( status=LLMPollingStatus.RUNNING, - plugin_state={"job_id": "job-1"}, - next_check_after_seconds=15, - expires_after_seconds=1800, - max_attempts=60, + plugin_state=self.scenario.plugin_state, + next_check_after_seconds=self.scenario.next_check_after_seconds, + expires_after_seconds=self.scenario.expires_after_seconds, + max_attempts=self.scenario.max_attempts, ) def _check_polling( self, model: str, credentials: dict, - plugin_state: dict[str, Any], + plugin_state: dict[str, JsonValue], user: str | None = None, *, workflow_run_id: str, @@ -152,7 +244,7 @@ def _check_polling( } return LLMPollingResult( status=LLMPollingStatus.SUCCEEDED, - result=_llm_result("done"), + result=self.scenario.llm_result(), ) def get_num_tokens( @@ -167,129 +259,94 @@ def get_num_tokens( class NonPollingLLM(PollingLLM): - def __init__(self) -> None: - super().__init__() + def __init__(self, scenario: PollingScenario | None = None) -> None: + super().__init__(scenario) self.model_schemas[0].features = [] def test_polling_requests_parse_daemon_payloads() -> None: - start_request = ModelStartPollingRequest( - user_id="user-1", - provider="provider", - model_type=ModelType.LLM, - model="llm", - credentials={"api_key": "key"}, - prompt_messages=[{"role": "user", "content": "hello"}], - model_parameters={}, - stop=[], - tools=[], - json_schema={"type": "object"}, - workflow_run_id="wr-1", - node_id="node-1", + scenario = PollingScenario() + + start_request = scenario.start_request( + prompt_messages=scenario.daemon_prompt_messages, + json_schema=scenario.json_schema, ) assert start_request.action == ModelActions.StartPolling assert start_request.stream is False assert isinstance(start_request.prompt_messages[0], UserPromptMessage) - assert start_request.json_schema == {"type": "object"} - - check_request = ModelCheckPollingRequest( - user_id="user-1", - provider="provider", - model_type=ModelType.LLM, - model="llm", - credentials={"api_key": "key"}, - workflow_run_id="wr-1", - node_id="node-1", - plugin_state={"job_id": "job-1"}, - ) + assert start_request.json_schema == scenario.json_schema + + check_request = scenario.check_request() assert check_request.action == ModelActions.CheckPolling - assert check_request.plugin_state == {"job_id": "job-1"} + assert check_request.plugin_state == scenario.plugin_state + + +def test_start_polling_request_rejects_streaming() -> None: + scenario = PollingScenario() + + with pytest.raises(ValueError, match="Input should be False"): + scenario.start_request( + prompt_messages=scenario.daemon_prompt_messages, + stream=True, + ) def test_executor_starts_llm_polling() -> None: - model = PollingLLM() + scenario = PollingScenario() + model = PollingLLM(scenario) executor = PluginExecutor(DifyPluginEnv(), ModelRegistration(model)) response = executor.start_llm_polling( Session.empty_session(), - ModelStartPollingRequest( - user_id="user-1", - provider="provider", - model_type=ModelType.LLM, - model="llm", - credentials={"api_key": "key"}, - prompt_messages=[UserPromptMessage(content="hello")], + scenario.start_request( model_parameters={"temperature": 0.2}, - stop=[], - tools=[], - json_schema={"type": "object"}, - workflow_run_id="wr-1", - node_id="node-1", + json_schema=scenario.json_schema, ), ) assert isinstance(response, LLMPollingResult) assert response.status == LLMPollingStatus.RUNNING - assert response.plugin_state == {"job_id": "job-1"} - assert response.next_check_after_seconds == 15 - assert response.expires_after_seconds == 1800 - assert response.max_attempts == 60 + assert response.plugin_state == scenario.plugin_state + assert response.next_check_after_seconds == scenario.next_check_after_seconds + assert response.expires_after_seconds == scenario.expires_after_seconds + assert response.max_attempts == scenario.max_attempts assert model.start_call is not None - assert model.supports_polling("llm", {"api_key": "key"}) - assert model.start_call["workflow_run_id"] == "wr-1" - assert model.start_call["node_id"] == "node-1" - assert model.start_call["json_schema"] == {"type": "object"} + assert model.supports_polling(scenario.model, scenario.credentials) + assert model.start_call["workflow_run_id"] == scenario.workflow_run_id + assert model.start_call["node_id"] == scenario.node_id + assert model.start_call["json_schema"] == scenario.json_schema assert model.start_call["model_parameters"] == {} def test_executor_checks_llm_polling() -> None: - model = PollingLLM() + scenario = PollingScenario() + model = PollingLLM(scenario) executor = PluginExecutor(DifyPluginEnv(), ModelRegistration(model)) response = executor.check_llm_polling( Session.empty_session(), - ModelCheckPollingRequest( - user_id="user-1", - provider="provider", - model_type=ModelType.LLM, - model="llm", - credentials={"api_key": "key"}, - workflow_run_id="wr-1", - node_id="node-1", - plugin_state={"job_id": "job-1"}, - ), + scenario.check_request(), ) assert isinstance(response, LLMPollingResult) assert response.status == LLMPollingStatus.SUCCEEDED assert response.result is not None - assert response.result.message.content == "done" + assert response.result.message.content == scenario.result_content assert model.check_call is not None - assert model.check_call["plugin_state"] == {"job_id": "job-1"} - assert model.check_call["workflow_run_id"] == "wr-1" - assert model.check_call["node_id"] == "node-1" + assert model.check_call["plugin_state"] == scenario.plugin_state + assert model.check_call["workflow_run_id"] == scenario.workflow_run_id + assert model.check_call["node_id"] == scenario.node_id def test_executor_rejects_llm_without_polling_feature() -> None: - model = NonPollingLLM() + scenario = PollingScenario() + model = NonPollingLLM(scenario) executor = PluginExecutor(DifyPluginEnv(), ModelRegistration(model)) with pytest.raises(ValueError, match="does not support polling"): executor.start_llm_polling( Session.empty_session(), - ModelStartPollingRequest( - user_id="user-1", - provider="provider", - model_type=ModelType.LLM, - model="llm", - credentials={"api_key": "key"}, - prompt_messages=[UserPromptMessage(content="hello")], - model_parameters={}, - stop=[], - tools=[], - workflow_run_id="wr-1", - node_id="node-1", - ), + scenario.start_request(), ) @@ -309,17 +366,11 @@ def test_polling_result_validates_state_payloads() -> None: ["next_check_after_seconds", "expires_after_seconds", "max_attempts"], ) def test_polling_result_rejects_non_positive_limits(field_name: str) -> None: - with pytest.raises(ValueError, match=f"{field_name} must be greater than 0"): + scenario = PollingScenario() + + with pytest.raises(ValueError, match="Input should be greater than 0"): LLMPollingResult( status=LLMPollingStatus.RUNNING, - plugin_state={"job_id": "job-1"}, + plugin_state=scenario.plugin_state, **{field_name: 0}, ) - - -def _llm_result(content: str) -> LLMResult: - return LLMResult( - model="llm", - message=AssistantPromptMessage(content=content), - usage=LLMUsage.empty_usage(), - ) From 58464a246727cdc95c8aea8c4a9d23e6fb8ae00e Mon Sep 17 00:00:00 2001 From: WH-2099 Date: Wed, 20 May 2026 12:13:00 +0800 Subject: [PATCH 6/6] docs: update polling compatibility version --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 125d32a0..8bec5eb1 100644 --- a/README.md +++ b/README.md @@ -86,4 +86,4 @@ For the manifest specification, we've introduced two versioning fields: | 1.10.0 | 0.6.0 | Support Trigger functionality for plugins | | 1.11.0 | 0.7.0 | Support Multimodal Reranking / Embeddings | | 1.14.0 | 0.8.1 | Dependency and project structure cleanup | -| 1.14.0 | 0.9.0 | Support polling-based LLM plugin invocations | +| 1.14.2 | 0.9.0 | Support polling-based LLM plugin invocations |