Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dify_plugin/core/entities/plugin/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class ModelCheckPollingRequest(PluginAccessModelRequest):

workflow_run_id: str
node_id: str
plugin_state: dict[str, JsonValue]
plugin_state: dict[str, JsonValue] = Field(min_length=1)


class ModelGetLLMNumTokens(PluginAccessModelRequest, PromptMessageMixin):
Expand Down
2 changes: 1 addition & 1 deletion src/dify_plugin/entities/model/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class LLMPollingResult(BaseModel):

@model_validator(mode="after")
def validate_status_payload(self) -> "LLMPollingResult":
if self.status == LLMPollingStatus.RUNNING and self.plugin_state is None:
if self.status == LLMPollingStatus.RUNNING and not self.plugin_state:
msg = "plugin_state is required when polling status is running."
raise ValueError(msg)

Expand Down
32 changes: 26 additions & 6 deletions tests/test_model_polling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Literal

import pytest
from pydantic import JsonValue
from pydantic import JsonValue, ValidationError

from dify_plugin.config.config import DifyPluginEnv
from dify_plugin.core.entities.plugin.request import (
Expand Down Expand Up @@ -284,13 +284,30 @@ def test_polling_requests_parse_daemon_payloads() -> None:
def test_start_polling_request_rejects_streaming() -> None:
scenario = PollingScenario()

with pytest.raises(ValueError, match="Input should be False"):
with pytest.raises(ValidationError, match="Input should be False"):
scenario.start_request(
prompt_messages=scenario.daemon_prompt_messages,
stream=True,
)


def test_check_polling_request_rejects_empty_plugin_state() -> None:
scenario = PollingScenario()
data: dict[str, object] = {
"user_id": scenario.user_id,
"provider": scenario.provider,
"model_type": ModelType.LLM,
"model": scenario.model,
"credentials": scenario.credentials,
"workflow_run_id": scenario.workflow_run_id,
"node_id": scenario.node_id,
"plugin_state": {},
}

with pytest.raises(ValidationError, match="at least 1 item"):
ModelCheckPollingRequest(**data)


def test_executor_starts_llm_polling() -> None:
scenario = PollingScenario()
model = PollingLLM(scenario)
Expand Down Expand Up @@ -351,13 +368,16 @@ def test_executor_rejects_llm_without_polling_feature() -> None:


def test_polling_result_validates_state_payloads() -> None:
with pytest.raises(ValueError, match="plugin_state is required"):
with pytest.raises(ValidationError, match="plugin_state is required"):
LLMPollingResult(status=LLMPollingStatus.RUNNING)

with pytest.raises(ValueError, match="result is required"):
with pytest.raises(ValidationError, match="plugin_state is required"):
LLMPollingResult(status=LLMPollingStatus.RUNNING, plugin_state={})

with pytest.raises(ValidationError, match="result is required"):
LLMPollingResult(status=LLMPollingStatus.SUCCEEDED)

with pytest.raises(ValueError, match="error is required"):
with pytest.raises(ValidationError, match="error is required"):
LLMPollingResult(status=LLMPollingStatus.FAILED)


Expand All @@ -368,7 +388,7 @@ def test_polling_result_validates_state_payloads() -> None:
def test_polling_result_rejects_non_positive_limits(field_name: str) -> None:
scenario = PollingScenario()

with pytest.raises(ValueError, match="Input should be greater than 0"):
with pytest.raises(ValidationError, match="Input should be greater than 0"):
LLMPollingResult(
status=LLMPollingStatus.RUNNING,
plugin_state=scenario.plugin_state,
Expand Down