Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 108 additions & 43 deletions tests/tools/paddleocr/test_file_input.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,38 @@
import base64
import os
import sys
from unittest.mock import MagicMock

import pytest
import yaml

# Mock paddleocr module before any imports
mock_paddleocr = MagicMock()

# Mock public API classes
mock_paddleocr.PaddleOCRClient = MagicMock
mock_paddleocr.OCROptions = lambda **kw: MagicMock()
mock_paddleocr.PPStructureV3Options = lambda **kw: MagicMock()
mock_paddleocr.PaddleOCRVLOptions = lambda **kw: MagicMock()
mock_paddleocr.AuthError = Exception
mock_paddleocr.PaddleOCRAPIError = Exception

# Mock internal modules for backward compatibility
mock_paddleocr._api_client = MagicMock()
mock_paddleocr._api_client.PaddleOCRClient = mock_paddleocr.PaddleOCRClient
mock_paddleocr._api_client.models = MagicMock()
mock_paddleocr._api_client.models.OCROptions = mock_paddleocr.OCROptions
mock_paddleocr._api_client.models.PPStructureV3Options = mock_paddleocr.PPStructureV3Options
mock_paddleocr._api_client.models.PaddleOCRVLOptions = mock_paddleocr.PaddleOCRVLOptions
mock_paddleocr._api_client.errors = MagicMock()
mock_paddleocr._api_client.errors.AuthError = mock_paddleocr.AuthError
mock_paddleocr._api_client.errors.PaddleOCRAPIError = mock_paddleocr.PaddleOCRAPIError

sys.modules["paddleocr"] = mock_paddleocr
sys.modules["paddleocr._api_client"] = mock_paddleocr._api_client
sys.modules["paddleocr._api_client.models"] = mock_paddleocr._api_client.models
sys.modules["paddleocr._api_client.errors"] = mock_paddleocr._api_client.errors

REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
PLUGIN_DIR = os.path.join(REPO_ROOT, "tools", "paddleocr")
if PLUGIN_DIR not in sys.path:
Expand Down Expand Up @@ -45,21 +73,29 @@ def test_file_upload_is_base64_encoded():
file_type=FileType.IMAGE,
)

payload, normalized_file_type = normalize_file_input(file, "auto")
input_value, is_temp_file, file_type_code = normalize_file_input(file, "auto")

assert payload == base64.b64encode(b"image-bytes").decode("utf-8")
assert normalized_file_type == 1
# New implementation saves to temp file for SDK
assert os.path.exists(input_value)
assert is_temp_file is True
assert file_type_code == 1
# Clean up
os.unlink(input_value)


def test_pdf_file_upload_infers_file_type():
file = make_file(
b"%PDF-1.7", filename="invoice.pdf", mime_type="application/pdf", extension=".pdf"
)

payload, normalized_file_type = normalize_file_input(file, "auto")
input_value, is_temp_file, file_type_code = normalize_file_input(file, "auto")

assert payload == base64.b64encode(b"%PDF-1.7").decode("utf-8")
assert normalized_file_type == 0
# New implementation saves to temp file for SDK
assert os.path.exists(input_value)
assert is_temp_file is True
assert file_type_code == 0
# Clean up
os.unlink(input_value)


def test_image_file_upload_infers_file_type_from_filename_when_mime_type_missing():
Expand All @@ -71,10 +107,14 @@ def test_image_file_upload_infers_file_type_from_filename_when_mime_type_missing
file_type=FileType.IMAGE,
)

payload, normalized_file_type = normalize_file_input(file, None)
input_value, is_temp_file, file_type_code = normalize_file_input(file, None)

assert payload == base64.b64encode(b"image-bytes").decode("utf-8")
assert normalized_file_type == 1
# New implementation saves to temp file for SDK
assert os.path.exists(input_value)
assert is_temp_file is True
assert file_type_code == 1
# Clean up
os.unlink(input_value)


def test_explicit_file_type_overrides_inference():
Expand All @@ -86,17 +126,22 @@ def test_explicit_file_type_overrides_inference():
file_type=FileType.IMAGE,
)

payload, normalized_file_type = normalize_file_input(file, "pdf")
input_value, is_temp_file, file_type_code = normalize_file_input(file, "pdf")

assert payload == base64.b64encode(b"image-bytes").decode("utf-8")
assert normalized_file_type == 0
# New implementation saves to temp file for SDK
assert os.path.exists(input_value)
assert is_temp_file is True
assert file_type_code == 0
# Clean up
os.unlink(input_value)


def test_legacy_file_string_is_passed_through():
payload, normalized_file_type = normalize_file_input("https://example.com/scan.pdf", "auto")
input_value, is_temp_file, file_type_code = normalize_file_input("https://example.com/scan.pdf", "auto")

assert payload == "https://example.com/scan.pdf"
assert normalized_file_type is None
assert input_value == "https://example.com/scan.pdf"
assert is_temp_file is False
assert file_type_code is None


def test_missing_file_input_raises_clear_error():
Expand All @@ -106,21 +151,44 @@ def test_missing_file_input_raises_clear_error():

def invoke_tool_with_mocked_api(monkeypatch, tool_cls, credentials, parameters):
captured = {}
module_name = tool_cls.__module__.split(".")[-1]

def fake_api_request(api_url, params, access_token):
captured["api_url"] = api_url
captured["params"] = params
captured["access_token"] = access_token
return {
"errorCode": 0,
"result": {
"ocrResults": [{"prunedResult": {"rec_texts": ["hello", "world"]}}],
"layoutParsingResults": [{"markdown": {"text": "# Parsed", "images": {}}}],
},
}

monkeypatch.setattr(f"tools.{module_name}.make_paddleocr_api_request", fake_api_request)

def fake_sdk_call(**kwargs):
captured["kwargs"] = kwargs
# Return mock result - use simple dict instead of SDK classes
if tool_cls == TextRecognitionTool:
return type("OCRResult", (), {"job_id": "test-job", "pages": [
type("OCRPage", (), {"pruned_result": {"rec_texts": ["hello", "world"]}, "ocr_image_url": None})()
]})()
else:
return type("DocParsingResult", (), {"job_id": "test-job", "pages": [
type("DocParsingPage", (), {"markdown_text": "# Parsed", "markdown_images": {}, "output_images": {}})()
]})()

# Mock the entire SDK module and client
fake_client = MagicMock()
fake_client.ocr = fake_sdk_call
fake_client.parse_document = fake_sdk_call

# Mock utils module functions
import tools.utils as utils_module
monkeypatch.setattr(utils_module, "get_sdk_client", lambda *args: fake_client)
monkeypatch.setattr(utils_module, "base64_to_temp_file", lambda *args: "temp_file.png")
monkeypatch.setattr(utils_module, "cleanup_temp_file", lambda *args: None)

# Mock in the specific tool module (they import these directly from utils)
if tool_cls == TextRecognitionTool:
import tools.text_recognition as tr_module
monkeypatch.setattr(tr_module, "get_sdk_client", lambda *args: fake_client)
monkeypatch.setattr(tr_module, "cleanup_temp_file", lambda *args: None)
elif tool_cls == DocumentParsingTool:
import tools.document_parsing as dp_module
monkeypatch.setattr(dp_module, "get_sdk_client", lambda *args: fake_client)
monkeypatch.setattr(dp_module, "cleanup_temp_file", lambda *args: None)
else:
import tools.document_parsing_vl as dpv_module
monkeypatch.setattr(dpv_module, "get_sdk_client", lambda *args: fake_client)
monkeypatch.setattr(dpv_module, "cleanup_temp_file", lambda *args: None)

tool = tool_cls.from_credentials(credentials)
list(tool._invoke(parameters))
return captured
Expand All @@ -145,11 +213,11 @@ def test_text_recognition_sends_normalized_file_to_api(monkeypatch):
{"file": file, "fileType": "auto", "visualize": False},
)

assert captured["api_url"] == "https://example.com/text-recognition"
assert captured["access_token"] == "token"
assert captured["params"]["file"] == base64.b64encode(b"image-bytes").decode("utf-8")
assert captured["params"]["fileType"] == 1
assert captured["params"]["visualize"] is False
# SDK receives file_path (temp file), not base64 directly
assert "file_path" in captured["kwargs"]
assert captured["kwargs"]["file_path"] == "temp_file.png"
assert captured["kwargs"]["options"] is not None
assert hasattr(captured["kwargs"]["options"], "visualize")


def test_document_parsing_sends_normalized_file_to_api(monkeypatch):
Expand All @@ -167,10 +235,9 @@ def test_document_parsing_sends_normalized_file_to_api(monkeypatch):
{"file": file, "fileType": "auto", "markdownIgnoreLabels": "header, footer"},
)

assert captured["api_url"] == "https://example.com/document-parsing"
assert captured["params"]["file"] == base64.b64encode(b"%PDF-1.7").decode("utf-8")
assert captured["params"]["fileType"] == 0
assert captured["params"]["markdownIgnoreLabels"] == ["header", "footer"]
assert "file_path" in captured["kwargs"]
assert captured["kwargs"]["file_path"] == "temp_file.png"
assert captured["kwargs"]["options"] is not None


def test_document_parsing_vl_sends_normalized_file_to_api(monkeypatch):
Expand All @@ -192,10 +259,8 @@ def test_document_parsing_vl_sends_normalized_file_to_api(monkeypatch):
{"file": file, "fileType": "auto", "promptLabel": "undefined"},
)

assert captured["api_url"] == "https://example.com/document-parsing-vl"
assert captured["params"]["file"] == base64.b64encode(b"image-bytes").decode("utf-8")
assert captured["params"]["fileType"] == 1
assert "promptLabel" not in captured["params"]
assert "file_path" in captured["kwargs"]
assert captured["kwargs"]["file_path"] == "temp_file.png"


def load_tool_yaml(tool_name: str) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion tools/paddleocr/manifest.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version: 0.2.6
version: 0.2.7
type: plugin
author: langgenius
name: paddleocr
Expand Down
51 changes: 21 additions & 30 deletions tools/paddleocr/provider/paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tools.document_parsing import DocumentParsingTool
from tools.document_parsing_vl import DocumentParsingVlTool
from tools.text_recognition import TextRecognitionTool
from tools.utils import call_paddleocr_api, get_sdk_client


class PaddleocrProvider(ToolProvider):
Expand All @@ -15,36 +16,26 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None:
"AI Studio access token must be provided"
)

api_url_keys = (
"text_recognition_api_url",
"document_parsing_api_url",
"document_parsing_vl_api_url",
)
tool_classes = (
TextRecognitionTool,
DocumentParsingTool,
DocumentParsingVlTool,
)
# Get base_url (optional, uses SDK default if not provided)
base_url = credentials.get("base_url")

# Test with OCR (works for all models)
test_file = "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_002.png"

if not any(key in credentials for key in api_url_keys):
raise ToolProviderCredentialValidationError(
"You should provide at least one API URL"
try:
client_config = get_sdk_client(
access_token=credentials["aistudio_access_token"],
base_url=base_url,
)

for api_url_key, tool_cls in zip(api_url_keys, tool_classes):
if api_url_key in credentials:
try:
self._test_tool_validation(tool_cls, credentials, test_file)
except Exception as e:
raise ToolProviderCredentialValidationError(
f"Invalid credentials for {tool_cls.__name__}"
) from e

def _test_tool_validation(
self, tool_cls, credentials: dict[str, Any], test_file: str
) -> None:
tool = tool_cls.from_credentials(credentials)

for _ in tool.invoke(tool_parameters={"file": test_file}):
break
call_paddleocr_api(
model="PP-OCRv5",
file_url=test_file,
file_path=None,
options={},
client_config=client_config,
is_document_parsing=False,
)
except Exception as e:
raise ToolProviderCredentialValidationError(
f"Validation failed: {e}"
) from e
40 changes: 8 additions & 32 deletions tools/paddleocr/provider/paddleocr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,39 +34,15 @@ credentials_for_provider:
en_US: Get your AI Studio access token
zh_Hans: 获取星河社区访问令牌
url: https://aistudio.baidu.com/index/accessToken
text_recognition_api_url:
base_url:
type: text-input
required: false
label:
en_US: Text Recognition API URL
zh_Hans: 文字识别 API URL
en_US: Base URL (Optional)
zh_Hans: Base URL(可选)
placeholder:
en_US: Text Recognition API URL
zh_Hans: 文字识别 API URL
en_US: https://paddleocr.aistudio-app.com
zh_Hans: https://paddleocr.aistudio-app.com
help:
en_US: Click the "API" button in the upper-left corner, select "Text recognition(PP-OCRv5)", and copy the `API_URL`.
zh_Hans: 点击左上角的“API”,选择“文字识别(PP-OCRv5)”并复制 `API_URL`
url: https://aistudio.baidu.com/paddleocr/task
document_parsing_api_url:
type: text-input
label:
en_US: Document Parsing API URL
zh_Hans: 文档解析 API URL
placeholder:
en_US: Document Parsing API URL
zh_Hans: 文档解析 API URL
help:
en_US: Click the "API" button in the upper-left corner, select "Document parsing(PP-StructureV3)", and copy the `API_URL`.
zh_Hans: 点击左上角的“API”,选择“文档解析(PP-StructureV3)”并复制 `API_URL`
url: https://aistudio.baidu.com/paddleocr/task
document_parsing_vl_api_url:
type: text-input
label:
en_US: Large Model Document Parsing API URL
zh_Hans: 大模型文档解析 API URL
placeholder:
en_US: Large Model Document Parsing API URL
zh_Hans: 大模型文档解析 API URL
help:
en_US: Click the "API" button in the upper-left corner, select "Large Model document parsing(PaddleOCR-VL)", and copy the `API_URL`.
zh_Hans: 点击左上角的“API”,选择“大模型文档解析(PaddleOCR-VL)”并复制 `API_URL`
url: https://aistudio.baidu.com/paddleocr/task
en_US: Leave empty to use the default PaddleOCR service. Only needed for self-hosted deployments.
zh_Hans: 留空则使用默认的 PaddleOCR 服务。仅自建服务时需要填写。
5 changes: 3 additions & 2 deletions tools/paddleocr/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
[project]
name = "paddleocr"
name = "paddleocr-dify"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"

# Managed with uv; refresh the lockfile with `uv lock`.
dependencies = [
"dify_plugin>=0.9.0",
"requests>=2.34.2",
]

# uv run black . -C -l 100 && uv run ruff check --fix
[dependency-groups]
dev = []
dev = []
Loading