diff --git a/.github/dependabot.yml b/.github/dependabot.yml index af0a8ed0ee4..ed93957ea4a 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,12 +1,26 @@ version: 2 updates: + - package-ecosystem: uv + directory: '/' + schedule: + interval: monthly + target-branch: 'dev' + - package-ecosystem: pip directory: '/backend' schedule: interval: monthly target-branch: 'dev' + + - package-ecosystem: npm + directory: '/' + schedule: + interval: monthly + target-branch: 'dev' + - package-ecosystem: 'github-actions' directory: '/' schedule: # Check for updates to GitHub Actions every week interval: monthly + target-branch: 'dev' diff --git a/.github/workflows/format-backend.yaml b/.github/workflows/format-backend.yaml index 44587669753..1bcdd92c1db 100644 --- a/.github/workflows/format-backend.yaml +++ b/.github/workflows/format-backend.yaml @@ -5,10 +5,18 @@ on: branches: - main - dev + paths: + - 'backend/**' + - 'pyproject.toml' + - 'uv.lock' pull_request: branches: - main - dev + paths: + - 'backend/**' + - 'pyproject.toml' + - 'uv.lock' jobs: build: @@ -17,7 +25,9 @@ jobs: strategy: matrix: - python-version: [3.11] + python-version: + - 3.11.x + - 3.12.x steps: - uses: actions/checkout@v4 @@ -25,7 +35,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: '${{ matrix.python-version }}' - name: Install dependencies run: | diff --git a/.github/workflows/format-build-frontend.yaml b/.github/workflows/format-build-frontend.yaml index 53d3aaa5ec8..9a007581ffe 100644 --- a/.github/workflows/format-build-frontend.yaml +++ b/.github/workflows/format-build-frontend.yaml @@ -5,10 +5,18 @@ on: branches: - main - dev + paths-ignore: + - 'backend/**' + - 'pyproject.toml' + - 'uv.lock' pull_request: branches: - main - dev + paths-ignore: + - 'backend/**' + - 'pyproject.toml' + - 'uv.lock' jobs: build: @@ -21,7 +29,7 @@ jobs: - name: Setup Node.js uses: actions/setup-node@v4 with: - node-version: '22' # Or specify any other version you want to use + node-version: '22' - name: Install Dependencies run: npm install diff --git a/CHANGELOG.md b/CHANGELOG.md index f6e8f7d297a..a11c2848ee6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,46 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.6.2] - 2025-04-06 + +### Added + +- 🌍 **Improved Global Language Support**: Expanded and refined translations across multiple languages to enhance clarity and consistency for international users. + +### Fixed + +- 🛠️ **Accurate Tool Descriptions from OpenAPI Servers**: External tools now use full endpoint descriptions instead of summaries when generating tool specifications—helping AI models understand tool purpose more precisely and choose the right tool more accurately in tool workflows. +- 🔧 **Precise Web Results Source Attribution**: Fixed a key issue where all web search results showed the same source ID—now each result gets its correct and distinct source, ensuring accurate citations and traceability. +- 🔍 **Clean Web Search Retrieval**: Web search now retains only results from URLs where real content was successfully fetched—improving accuracy and removing empty or broken links from citations. 
+- 🎵 **Audio File Upload Response Restored**: Resolved an issue where uploading audio files did not return valid responses, restoring smooth file handling for transcription and audio-based workflows. + +### Changed + +- 🧰 **General Backend Refactoring**: Multiple behind-the-scenes improvements streamline backend performance, reduce complexity, and ensure a more stable, maintainable system overall—making everything smoother without changing your workflow. + +## [0.6.1] - 2025-04-05 + +### Added + +- 🛠️ **Global Tool Servers Configuration**: Admins can now centrally configure global external tool servers from Admin Settings > Tools, allowing seamless sharing of tool integrations across all users without manual setup per user. +- 🔐 **Direct Tool Usage Permission for Users**: Introduced a new user-level permission toggle that grants non-admin users access to direct external tools, empowering broader team collaboration while maintaining control. +- 🧠 **Mistral OCR Content Extraction Support**: Added native support for Mistral OCR as a high-accuracy document loader, drastically improving text extraction from scanned documents in RAG workflows. +- 🖼️ **Tools Indicator UI Redesign**: Enhanced message input now smartly displays both built-in and external tools via a unified dropdown, making it simpler and more intuitive to activate tools during conversations. +- 📄 **RAG Prompt Improved and More Coherent**: Default RAG system prompt has been revised to be more clear and citation-focused—admins can leave the template field empty to use this new gold-standard prompt. +- 🧰 **Performance & Developer Improvements**: Major internal restructuring of several tool-related components, simplifying styling and merging external/internal handling logic, resulting in better maintainability and performance. +- 🌍 **Improved Translations**: Updated translations for Tibetan, Polish, Chinese (Simplified & Traditional), Arabic, Russian, Ukrainian, Dutch, Finnish, and French to improve clarity and consistency across the interface. + +### Fixed + +- 🔑 **External Tool Server API Key Bug Resolved**: Fixed a critical issue where authentication headers were not being sent when calling tools from external OpenAPI tool servers, ensuring full security and smooth tool operations. +- 🚫 **Conditional Export Button Visibility**: UI now gracefully hides export buttons when there's nothing to export in models, prompts, tools, or functions, improving visual clarity and reducing confusion. +- 🧪 **Hybrid Search Failure Recovery**: Resolved edge case in parallel hybrid search where empty or unindexed collections caused backend crashes—these are now cleanly skipped to ensure system stability. +- 📂 **Admin Folder Deletion Fix**: Addressed an issue where folders created in the admin workspace couldn't be deleted, restoring full organizational flexibility for admins. +- 🔐 **Improved Generic Error Feedback on Login**: Authentication errors now show simplified, non-revealing messages for privacy and improved UX, especially with federated logins. +- 📝 **Tool Message with Images Improved**: Enhanced how tool-generated messages with image outputs are shown in chat, making them more readable and consistent with the overall UI design. +- ⚙️ **Auto-Exclusion for Broken RAG Collections**: Auto-skips document collections that fail to fetch data or return "None", preventing silent errors and streamlining retrieval workflows. 
+- 📝 **Docling Text File Handling Fix**: Fixed file parsing inconsistency that broke docling-based RAG functionality for certain plain text files, ensuring wider file compatibility. + ## [0.6.0] - 2025-03-31 ### Added diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 0ac92bd23bd..8238f8a87ee 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -331,12 +331,14 @@ def __getattr__(self, key): # OAuth config #################################### + ENABLE_OAUTH_SIGNUP = PersistentConfig( "ENABLE_OAUTH_SIGNUP", "oauth.enable_signup", os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true", ) + OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig( "OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "oauth.merge_accounts_by_email", @@ -466,6 +468,7 @@ def __getattr__(self, key): os.environ.get("OAUTH_USERNAME_CLAIM", "name"), ) + OAUTH_PICTURE_CLAIM = PersistentConfig( "OAUTH_PICTURE_CLAIM", "oauth.oidc.avatar_claim", @@ -878,6 +881,17 @@ def oidc_oauth_register(client): pass OPENAI_API_BASE_URL = "https://api.openai.com/v1" +#################################### +# TOOL_SERVERS +#################################### + + +TOOL_SERVER_CONNECTIONS = PersistentConfig( + "TOOL_SERVER_CONNECTIONS", + "tool_server.connections", + [], +) + #################################### # WEBUI #################################### @@ -1034,6 +1048,11 @@ def oidc_oauth_register(client): == "true" ) +USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS = ( + os.environ.get("USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS", "False").lower() + == "true" +) + USER_PERMISSIONS_FEATURES_WEB_SEARCH = ( os.environ.get("USER_PERMISSIONS_FEATURES_WEB_SEARCH", "True").lower() == "true" ) @@ -1071,6 +1090,7 @@ def oidc_oauth_register(client): "temporary_enforced": USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED, }, "features": { + "direct_tool_servers": USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS, "web_search": USER_PERMISSIONS_FEATURES_WEB_SEARCH, "image_generation": USER_PERMISSIONS_FEATURES_IMAGE_GENERATION, "code_interpreter": USER_PERMISSIONS_FEATURES_CODE_INTERPRETER, @@ -1727,6 +1747,11 @@ class BannerModel(BaseModel): os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""), ) +MISTRAL_OCR_API_KEY = PersistentConfig( + "MISTRAL_OCR_API_KEY", + "rag.mistral_ocr_api_key", + os.getenv("MISTRAL_OCR_API_KEY", ""), +) BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig( "BYPASS_EMBEDDING_AND_RETRIEVAL", @@ -1875,7 +1900,7 @@ class BannerModel(BaseModel): ) DEFAULT_RAG_TEMPLATE = """### Task: -Respond to the user query using the provided context, incorporating inline citations in the format [source_id] **only when the tag is explicitly provided** in the context. +Respond to the user query using the provided context, incorporating inline citations in the format [id] **only when the tag includes an explicit id attribute** (e.g., ). ### Guidelines: - If you don't know the answer, clearly state that. @@ -1883,18 +1908,17 @@ class BannerModel(BaseModel): - Respond in the same language as the user's query. - If the context is unreadable or of poor quality, inform the user and provide the best possible answer. - If the answer isn't present in the context but you possess the knowledge, explain this to the user and provide the answer using your own understanding. -- **Only include inline citations using [source_id] (e.g., [1], [2]) when a `` tag is explicitly provided in the context.** -- Do not cite if the tag is not provided in the context. 
+- **Only include inline citations using [id] (e.g., [1], [2]) when the tag includes an id attribute.** +- Do not cite if the tag does not contain an id attribute. - Do not use XML tags in your response. - Ensure citations are concise and directly related to the information provided. ### Example of Citation: -If the user asks about a specific topic and the information is found in "whitepaper.pdf" with a provided , the response should include the citation like so: -* "According to the study, the proposed method increases efficiency by 20% [whitepaper.pdf]." -If no is present, the response should omit the citation. +If the user asks about a specific topic and the information is found in a source with a provided id attribute, the response should include the citation like in the following example: +* "According to the study, the proposed method increases efficiency by 20% [1]." ### Output: -Provide a clear and direct response to the user's query, including inline citations in the format [source_id] only when the tag is present in the context. +Provide a clear and direct response to the user's query, including inline citations in the format [id] only when the tag with id attribute is present in the context. {{CONTEXT}} diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index bb78d900346..c9ca059c223 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -105,6 +105,8 @@ OPENAI_API_CONFIGS, # Direct Connections ENABLE_DIRECT_CONNECTIONS, + # Tool Server Configs + TOOL_SERVER_CONNECTIONS, # Code Execution ENABLE_CODE_EXECUTION, CODE_EXECUTION_ENGINE, @@ -191,6 +193,7 @@ DOCLING_SERVER_URL, DOCUMENT_INTELLIGENCE_ENDPOINT, DOCUMENT_INTELLIGENCE_KEY, + MISTRAL_OCR_API_KEY, RAG_TOP_K, RAG_TOP_K_RERANKER, RAG_TEXT_SPLITTER, @@ -355,6 +358,7 @@ from open_webui.utils.auth import ( get_license_data, + get_http_authorization_cred, decode_token, get_admin_user, get_verified_user, @@ -477,6 +481,15 @@ async def lifespan(app: FastAPI): app.state.OPENAI_MODELS = {} +######################################## +# +# TOOL SERVERS +# +######################################## + +app.state.config.TOOL_SERVER_CONNECTIONS = TOOL_SERVER_CONNECTIONS +app.state.TOOL_SERVERS = [] + ######################################## # # DIRECT CONNECTIONS @@ -582,6 +595,7 @@ async def lifespan(app: FastAPI): app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY +app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME @@ -862,6 +876,10 @@ async def commit_session_after_request(request: Request, call_next): @app.middleware("http") async def check_url(request: Request, call_next): start_time = int(time.time()) + request.state.token = get_http_authorization_cred( + request.headers.get("Authorization") + ) + request.state.enable_api_key = app.state.config.ENABLE_API_KEY response = await call_next(request) process_time = int(time.time()) - start_time diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index 295d0414a75..24944bd8a44 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -20,6 +20,9 @@ YoutubeLoader, ) from langchain_core.documents import Document + +from open_webui.retrieval.loaders.mistral import 
MistralLoader + from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) @@ -181,13 +184,16 @@ def load( for doc in docs ] + def _is_text_file(self, file_ext: str, file_content_type: str) -> bool: + return file_ext in known_source_ext or ( + file_content_type and file_content_type.find("text/") >= 0 + ) + def _get_loader(self, filename: str, file_content_type: str, file_path: str): file_ext = filename.split(".")[-1].lower() if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): - if file_ext in known_source_ext or ( - file_content_type and file_content_type.find("text/") >= 0 - ): + if self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: loader = TikaLoader( @@ -196,11 +202,14 @@ def _get_loader(self, filename: str, file_content_type: str, file_path: str): mime_type=file_content_type, ) elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"): - loader = DoclingLoader( - url=self.kwargs.get("DOCLING_SERVER_URL"), - file_path=file_path, - mime_type=file_content_type, - ) + if self._is_text_file(file_ext, file_content_type): + loader = TextLoader(file_path, autodetect_encoding=True) + else: + loader = DoclingLoader( + url=self.kwargs.get("DOCLING_SERVER_URL"), + file_path=file_path, + mime_type=file_content_type, + ) elif ( self.engine == "document_intelligence" and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != "" @@ -222,6 +231,15 @@ def _get_loader(self, filename: str, file_content_type: str, file_path: str): api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"), api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"), ) + elif ( + self.engine == "mistral_ocr" + and self.kwargs.get("MISTRAL_OCR_API_KEY") != "" + and file_ext + in ["pdf"] # Mistral OCR currently only supports PDF and images + ): + loader = MistralLoader( + api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path + ) else: if file_ext == "pdf": loader = PyPDFLoader( @@ -257,9 +275,7 @@ def _get_loader(self, filename: str, file_content_type: str, file_path: str): loader = UnstructuredPowerPointLoader(file_path) elif file_ext == "msg": loader = OutlookMessageLoader(file_path) - elif file_ext in known_source_ext or ( - file_content_type and file_content_type.find("text/") >= 0 - ): + elif self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: loader = TextLoader(file_path, autodetect_encoding=True) diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py new file mode 100644 index 00000000000..8f3a960a283 --- /dev/null +++ b/backend/open_webui/retrieval/loaders/mistral.py @@ -0,0 +1,225 @@ +import requests +import logging +import os +import sys +from typing import List, Dict, Any + +from langchain_core.documents import Document +from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL + +logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +class MistralLoader: + """ + Loads documents by processing them through the Mistral OCR API. + """ + + BASE_API_URL = "https://api.mistral.ai/v1" + + def __init__(self, api_key: str, file_path: str): + """ + Initializes the loader. + + Args: + api_key: Your Mistral API key. + file_path: The local path to the PDF file to process. 
+ """ + if not api_key: + raise ValueError("API key cannot be empty.") + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found at {file_path}") + + self.api_key = api_key + self.file_path = file_path + self.headers = {"Authorization": f"Bearer {self.api_key}"} + + def _handle_response(self, response: requests.Response) -> Dict[str, Any]: + """Checks response status and returns JSON content.""" + try: + response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx) + # Handle potential empty responses for certain successful requests (e.g., DELETE) + if response.status_code == 204 or not response.content: + return {} # Return empty dict if no content + return response.json() + except requests.exceptions.HTTPError as http_err: + log.error(f"HTTP error occurred: {http_err} - Response: {response.text}") + raise + except requests.exceptions.RequestException as req_err: + log.error(f"Request exception occurred: {req_err}") + raise + except ValueError as json_err: # Includes JSONDecodeError + log.error(f"JSON decode error: {json_err} - Response: {response.text}") + raise # Re-raise after logging + + def _upload_file(self) -> str: + """Uploads the file to Mistral for OCR processing.""" + log.info("Uploading file to Mistral API") + url = f"{self.BASE_API_URL}/files" + file_name = os.path.basename(self.file_path) + + try: + with open(self.file_path, "rb") as f: + files = {"file": (file_name, f, "application/pdf")} + data = {"purpose": "ocr"} + + upload_headers = self.headers.copy() # Avoid modifying self.headers + + response = requests.post( + url, headers=upload_headers, files=files, data=data + ) + + response_data = self._handle_response(response) + file_id = response_data.get("id") + if not file_id: + raise ValueError("File ID not found in upload response.") + log.info(f"File uploaded successfully. 
File ID: {file_id}") + return file_id + except Exception as e: + log.error(f"Failed to upload file: {e}") + raise + + def _get_signed_url(self, file_id: str) -> str: + """Retrieves a temporary signed URL for the uploaded file.""" + log.info(f"Getting signed URL for file ID: {file_id}") + url = f"{self.BASE_API_URL}/files/{file_id}/url" + params = {"expiry": 1} + signed_url_headers = {**self.headers, "Accept": "application/json"} + + try: + response = requests.get(url, headers=signed_url_headers, params=params) + response_data = self._handle_response(response) + signed_url = response_data.get("url") + if not signed_url: + raise ValueError("Signed URL not found in response.") + log.info("Signed URL received.") + return signed_url + except Exception as e: + log.error(f"Failed to get signed URL: {e}") + raise + + def _process_ocr(self, signed_url: str) -> Dict[str, Any]: + """Sends the signed URL to the OCR endpoint for processing.""" + log.info("Processing OCR via Mistral API") + url = f"{self.BASE_API_URL}/ocr" + ocr_headers = { + **self.headers, + "Content-Type": "application/json", + "Accept": "application/json", + } + payload = { + "model": "mistral-ocr-latest", + "document": { + "type": "document_url", + "document_url": signed_url, + }, + "include_image_base64": False, + } + + try: + response = requests.post(url, headers=ocr_headers, json=payload) + ocr_response = self._handle_response(response) + log.info("OCR processing done.") + log.debug("OCR response: %s", ocr_response) + return ocr_response + except Exception as e: + log.error(f"Failed during OCR processing: {e}") + raise + + def _delete_file(self, file_id: str) -> None: + """Deletes the file from Mistral storage.""" + log.info(f"Deleting uploaded file ID: {file_id}") + url = f"{self.BASE_API_URL}/files/{file_id}" + # No specific Accept header needed, default or Authorization is usually sufficient + + try: + response = requests.delete(url, headers=self.headers) + delete_response = self._handle_response( + response + ) # Check status, ignore response body unless needed + log.info( + f"File deleted successfully: {delete_response}" + ) # Log the response if available + except Exception as e: + # Log error but don't necessarily halt execution if deletion fails + log.error(f"Failed to delete file ID {file_id}: {e}") + # Depending on requirements, you might choose to raise the error here + + def load(self) -> List[Document]: + """ + Executes the full OCR workflow: upload, get URL, process OCR, delete file. + + Returns: + A list of Document objects, one for each page processed. + """ + file_id = None + try: + # 1. Upload file + file_id = self._upload_file() + + # 2. Get Signed URL + signed_url = self._get_signed_url(file_id) + + # 3. Process OCR + ocr_response = self._process_ocr(signed_url) + + # 4. 
Process results + pages_data = ocr_response.get("pages") + if not pages_data: + log.warning("No pages found in OCR response.") + return [Document(page_content="No text content found", metadata={})] + + documents = [] + total_pages = len(pages_data) + for page_data in pages_data: + page_content = page_data.get("markdown") + page_index = page_data.get("index") # API uses 0-based index + + if page_content is not None and page_index is not None: + documents.append( + Document( + page_content=page_content, + metadata={ + "page": page_index, # 0-based index from API + "page_label": page_index + + 1, # 1-based label for convenience + "total_pages": total_pages, + # Add other relevant metadata from page_data if available/needed + # e.g., page_data.get('width'), page_data.get('height') + }, + ) + ) + else: + log.warning( + f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}" + ) + + if not documents: + # Case where pages existed but none had valid markdown/index + log.warning( + "OCR response contained pages, but none had valid content/index." + ) + return [ + Document( + page_content="No text content found in valid pages", metadata={} + ) + ] + + return documents + + except Exception as e: + log.error(f"An error occurred during the loading process: {e}") + # Return an empty list or a specific error document on failure + return [Document(page_content=f"Error during processing: {e}", metadata={})] + finally: + # 5. Delete file (attempt even if prior steps failed after upload) + if file_id: + try: + self._delete_file(file_id) + except Exception as del_e: + # Log deletion error, but don't overwrite original error if one occurred + log.error( + f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}" + ) diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 518a1213679..12d48f86903 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -320,10 +320,13 @@ def process_query(collection_name, query): log.exception(f"Error when querying the collection with hybrid_search: {e}") return None, e + # Prepare tasks for all collections and queries + # Avoid running any tasks for collections that failed to fetch data (have assigned None) tasks = [ - (collection_name, query) - for collection_name in collection_names - for query in queries + (cn, q) + for cn in collection_names + if collection_results[cn] is not None + for q in queries ] with ThreadPoolExecutor() as executor: @@ -354,7 +357,7 @@ def get_embedding_function( ): if embedding_engine == "": return lambda query, prefix=None, user=None: embedding_function.encode( - query, prompt=prefix if prefix else None + query, **({"prompt": prefix} if prefix else {}) ).tolist() elif embedding_engine in ["ollama", "openai"]: func = lambda query, prefix=None, user=None: generate_embeddings( diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 34a63ba3faf..67c2e9f2a15 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -194,8 +194,8 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): ciphers=LDAP_CIPHERS, ) except Exception as e: - log.error(f"An error occurred on TLS: {str(e)}") - raise HTTPException(400, detail=str(e)) + log.error(f"TLS configuration error: {str(e)}") + raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.") try: server = Server( @@ -232,7 +232,7 @@ async def ldap_auth(request: Request, response: 
Response, form_data: LdapForm): username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower() email = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"]) if not email or email == "" or email == "[]": - raise HTTPException(400, f"User {form_data.user} does not have email.") + raise HTTPException(400, "User does not have a valid email address.") else: email = email.lower() @@ -248,7 +248,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): authentication="SIMPLE", ) if not connection_user.bind(): - raise HTTPException(400, f"Authentication failed for {form_data.user}") + raise HTTPException(400, "Authentication failed.") user = Users.get_user_by_email(email) if not user: @@ -276,7 +276,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): except HTTPException: raise except Exception as err: - raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + log.error(f"LDAP user creation error: {str(err)}") + raise HTTPException( + 500, detail="Internal error occurred during LDAP user creation." + ) user = Auths.authenticate_user_by_trusted_header(email) @@ -312,12 +315,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) else: - raise HTTPException( - 400, - f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}", - ) + raise HTTPException(400, "User record mismatch.") except Exception as e: - raise HTTPException(400, detail=str(e)) + log.error(f"LDAP authentication error: {str(e)}") + raise HTTPException(400, detail="LDAP authentication failed.") ############################ @@ -519,7 +520,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm): else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) except Exception as err: - raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + log.error(f"Signup error: {str(err)}") + raise HTTPException(500, detail="An internal error occurred during signup.") @router.get("/signout") @@ -547,7 +549,11 @@ async def signout(request: Request, response: Response): detail="Failed to fetch OpenID configuration", ) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + log.error(f"OpenID signout error: {str(e)}") + raise HTTPException( + status_code=500, + detail="Failed to sign out from the OpenID provider.", + ) return {"status": True} @@ -591,7 +597,10 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) except Exception as err: - raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) + log.error(f"Add user error: {str(err)}") + raise HTTPException( + 500, detail="An internal error occurred while adding the user." 
+ ) ############################ @@ -764,11 +773,6 @@ async def update_ldap_server( if not value: raise HTTPException(400, detail=f"Required field {key} is empty") - if form_data.use_tls and not form_data.certificate_path: - raise HTTPException( - 400, detail="TLS is enabled but certificate file path is missing" - ) - request.app.state.config.LDAP_SERVER_LABEL = form_data.label request.app.state.config.LDAP_SERVER_HOST = form_data.host request.app.state.config.LDAP_SERVER_PORT = form_data.port diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 2a4c651f2a4..44b2ef40cfb 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,5 +1,5 @@ -from fastapi import APIRouter, Depends, Request -from pydantic import BaseModel +from fastapi import APIRouter, Depends, Request, HTTPException +from pydantic import BaseModel, ConfigDict from typing import Optional @@ -7,6 +7,8 @@ from open_webui.config import get_config, save_config from open_webui.config import BannerModel +from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data + router = APIRouter() @@ -66,6 +68,75 @@ async def set_direct_connections_config( } +############################ +# ToolServers Config +############################ + + +class ToolServerConnection(BaseModel): + url: str + path: str + auth_type: Optional[str] + key: Optional[str] + config: Optional[dict] + + model_config = ConfigDict(extra="allow") + + +class ToolServersConfigForm(BaseModel): + TOOL_SERVER_CONNECTIONS: list[ToolServerConnection] + + +@router.get("/tool_servers", response_model=ToolServersConfigForm) +async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)): + return { + "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, + } + + +@router.post("/tool_servers", response_model=ToolServersConfigForm) +async def set_tool_servers_config( + request: Request, + form_data: ToolServersConfigForm, + user=Depends(get_admin_user), +): + request.app.state.config.TOOL_SERVER_CONNECTIONS = [ + connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS + ] + + request.app.state.TOOL_SERVERS = await get_tool_servers_data( + request.app.state.config.TOOL_SERVER_CONNECTIONS + ) + + return { + "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, + } + + +@router.post("/tool_servers/verify") +async def verify_tool_servers_config( + request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user) +): + """ + Verify the connection to the tool server. 
+ """ + try: + + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + + url = f"{form_data.url}/{form_data.path}" + return await get_tool_server_data(token, url) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to connect to the tool server: {str(e)}", + ) + + ############################ # CodeInterpreterConfig ############################ diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 22e1269e378..c30366545e2 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -122,6 +122,7 @@ def upload_file( ]: file_path = Storage.get_file(file_path) result = transcribe(request, file_path) + process_file( request, ProcessFileForm(file_id=id, content=result.get("text", "")), @@ -129,7 +130,8 @@ def upload_file( ) elif file.content_type not in ["image/png", "image/jpeg", "image/gif"]: process_file(request, ProcessFileForm(file_id=id), user=user) - file_item = Files.get_file_by_id(id=id) + + file_item = Files.get_file_by_id(id=id) except Exception as e: log.exception(e) log.error(f"Error processing file: {file_item.id}") @@ -162,11 +164,16 @@ def upload_file( @router.get("/", response_model=list[FileModelResponse]) -async def list_files(user=Depends(get_verified_user)): +async def list_files(user=Depends(get_verified_user), content: bool = Query(True)): if user.role == "admin": files = Files.get_files() else: files = Files.get_files_by_user_id(user.id) + + if not content: + for file in files: + del file.data["content"] + return files diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index cf37f9329da..2c41c92854b 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -236,7 +236,8 @@ async def delete_folder_by_id( chat_delete_permission = has_permission( user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS ) - if not chat_delete_permission: + + if user.role != "admin" and not chat_delete_permission: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index fcb263d1e0a..775cd044656 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -1197,7 +1197,7 @@ class OpenAIChatMessageContent(BaseModel): class OpenAIChatMessage(BaseModel): role: str - content: Union[str, list[OpenAIChatMessageContent]] + content: Union[Optional[str], list[OpenAIChatMessageContent]] model_config = ConfigDict(extra="allow") diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 2bd908606ac..f31abd9ff09 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -124,7 +124,7 @@ def get_ef( def get_rf( - reranking_model: str, + reranking_model: Optional[str] = None, auto_update: bool = False, ): rf = None @@ -150,8 +150,8 @@ def get_rf( device=DEVICE_TYPE, trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, ) - except: - log.error("CrossEncoder error") + except Exception as e: + log.error(f"CrossEncoder: {e}") raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) return rf @@ -174,7 +174,7 @@ class ProcessUrlForm(CollectionNameForm): url: str -class SearchForm(CollectionNameForm): +class SearchForm(BaseModel): query: str @@ -364,6 +364,9 
@@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, }, + "mistral_ocr_config": { + "api_key": request.app.state.config.MISTRAL_OCR_API_KEY, + }, }, "chunk": { "text_splitter": request.app.state.config.TEXT_SPLITTER, @@ -427,11 +430,16 @@ class DocumentIntelligenceConfigForm(BaseModel): key: str +class MistralOCRConfigForm(BaseModel): + api_key: str + + class ContentExtractionConfig(BaseModel): engine: str = "" tika_server_url: Optional[str] = None docling_server_url: Optional[str] = None document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None + mistral_ocr_config: Optional[MistralOCRConfigForm] = None class ChunkParamUpdateForm(BaseModel): @@ -553,6 +561,10 @@ async def update_rag_config( request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = ( form_data.content_extraction.document_intelligence_config.key ) + if form_data.content_extraction.mistral_ocr_config is not None: + request.app.state.config.MISTRAL_OCR_API_KEY = ( + form_data.content_extraction.mistral_ocr_config.api_key + ) if form_data.chunk is not None: request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter @@ -659,6 +671,9 @@ async def update_rag_config( "endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, "key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, }, + "mistral_ocr_config": { + "api_key": request.app.state.config.MISTRAL_OCR_API_KEY, + }, }, "chunk": { "text_splitter": request.app.state.config.TEXT_SPLITTER, @@ -747,6 +762,9 @@ async def update_query_settings( form_data.hybrid if form_data.hybrid else False ) + if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: + request.app.state.rf = None + return { "status": True, "template": request.app.state.config.RAG_TEMPLATE, @@ -940,7 +958,7 @@ def process_file( if form_data.content: # Update the content in the file - # Usage: /files/{file_id}/data/content/update + # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline) try: # /files/{file_id}/data/content/update @@ -1007,6 +1025,7 @@ def process_file( PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY, ) docs = loader.load( file.filename, file.meta.get("content_type"), file_path @@ -1445,12 +1464,6 @@ async def process_web_search( log.debug(f"web_results: {web_results}") try: - collection_name = form_data.collection_name - if collection_name == "" or collection_name is None: - collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[ - :63 - ] - urls = [result.link for result in web_results] loader = get_web_loader( urls, @@ -1459,6 +1472,9 @@ async def process_web_search( trust_env=request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV, ) docs = await loader.aload() + urls = [ + doc.metadata["source"] for doc in docs + ] # only keep URLs which could be retrieved if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: return { @@ -1475,18 +1491,26 @@ async def process_web_search( "loaded_count": len(docs), } else: - await run_in_threadpool( - save_docs_to_vector_db, - request, - docs, - collection_name, - overwrite=True, - user=user, - ) + collection_names = [] + for doc_idx, doc in enumerate(docs): + 
collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[ + :63 + ] + + collection_names.append(collection_name) + + await run_in_threadpool( + save_docs_to_vector_db, + request, + [doc], + collection_name, + overwrite=True, + user=user, + ) return { "status": True, - "collection_name": collection_name, + "collection_names": collection_names, "filenames": urls, "loaded_count": len(docs), } @@ -1515,8 +1539,13 @@ def query_doc_handler( ): try: if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: + collection_results = {} + collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get( + collection_name=form_data.collection_name + ) return query_doc_with_hybrid_search( collection_name=form_data.collection_name, + collection_result=collection_results[form_data.collection_name], query=form_data.query, embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION( query, prefix=prefix, user=user diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index b63c9732af9..39fca43d3ef 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -653,17 +653,6 @@ async def generate_moa_response( detail="Model not found", ) - # Check if the user has a custom task model - # If the user has a custom task model, use that model - task_model_id = get_task_model_id( - model_id, - request.app.state.config.TASK_MODEL, - request.app.state.config.TASK_MODEL_EXTERNAL, - models, - ) - - log.debug(f"generating MOA model {task_model_id} for user {user.email} ") - template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE content = moa_response_generation_template( @@ -673,7 +662,7 @@ async def generate_moa_response( ) payload = { - "model": task_model_id, + "model": model_id, "messages": [{"role": "user", "content": content}], "stream": form_data.get("stream", False), "metadata": { diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 211264cde07..8a98b4e2023 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,6 +1,7 @@ import logging from pathlib import Path from typing import Optional +import time from open_webui.models.tools import ( ToolForm, @@ -18,6 +19,8 @@ from open_webui.utils.access_control import has_access, has_permission from open_webui.env import SRC_LOG_LEVELS +from open_webui.utils.tools import get_tool_servers_data + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -30,11 +33,51 @@ @router.get("/", response_model=list[ToolUserResponse]) -async def get_tools(user=Depends(get_verified_user)): - if user.role == "admin": - tools = Tools.get_tools() - else: - tools = Tools.get_tools_by_user_id(user.id, "read") +async def get_tools(request: Request, user=Depends(get_verified_user)): + + if not request.app.state.TOOL_SERVERS: + # If the tool servers are not set, we need to set them + # This is done only once when the server starts + # This is done to avoid loading the tool servers every time + + request.app.state.TOOL_SERVERS = await get_tool_servers_data( + request.app.state.config.TOOL_SERVER_CONNECTIONS + ) + + tools = Tools.get_tools() + for idx, server in enumerate(request.app.state.TOOL_SERVERS): + tools.append( + ToolUserResponse( + **{ + "id": f"server:{server['idx']}", + "user_id": f"server:{server['idx']}", + "name": server["openapi"] + .get("info", {}) + .get("title", "Tool Server"), + "meta": { + "description": server["openapi"] + .get("info", {}) + .get("description", 
""), + }, + "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[ + idx + ] + .get("config", {}) + .get("access_control", None), + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + ) + + if user.role != "admin": + tools = [ + tool + for tool in tools + if tool.user_id == user.id + or has_access(user.id, "read", tool.access_control) + ] + return tools diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 4cf9102e144..d1046bcedb8 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -93,6 +93,7 @@ class ChatPermissions(BaseModel): class FeaturesPermissions(BaseModel): + direct_tool_servers: bool = False web_search: bool = True image_generation: bool = True code_interpreter: bool = True diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 6dd3234b061..118ac049e2f 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -8,7 +8,9 @@ import os -from datetime import UTC, datetime, timedelta +from datetime import datetime, timedelta +import pytz +from pytz import UTC from typing import Optional, Union, List, Dict from open_webui.models.users import Users @@ -141,12 +143,14 @@ def create_api_key(): return f"sk-{key}" -def get_http_authorization_cred(auth_header: str): +def get_http_authorization_cred(auth_header: Optional[str]): + if not auth_header: + return None try: scheme, credentials = auth_header.split(" ") return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) except Exception: - raise ValueError(ERROR_MESSAGES.INVALID_TOKEN) + return None def get_current_user( @@ -180,7 +184,12 @@ def get_current_user( ).split(",") ] - if request.url.path not in allowed_paths: + # Check if the request path matches any allowed endpoint. 
+ if not any( + request.url.path == allowed + or request.url.path.startswith(allowed + "/") + for allowed in allowed_paths + ): raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED ) diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index a11aeb092c8..76c9db9eb1a 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -12,9 +12,9 @@ def get_sorted_filter_ids(model: dict): def get_priority(function_id): function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include vavles - return (function.valves if function.valves else {}).get("priority", 0) + if function is not None: + valves = Functions.get_function_valves_by_id(function_id) + return valves.get("priority", 0) if valves else 0 return 0 filter_ids = [function.id for function in Functions.get_global_filter_functions()] diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 532f1738778..badae990651 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -221,13 +221,23 @@ async def tool_call_handler(tool_call): except Exception as e: tool_result = str(e) + tool_result_files = [] + if isinstance(tool_result, list): + for item in tool_result: + # check if string + if isinstance(item, str) and item.startswith("data:"): + tool_result_files.append(item) + tool_result.remove(item) + if isinstance(tool_result, dict) or isinstance(tool_result, list): tool_result = json.dumps(tool_result, indent=2) if isinstance(tool_result, str): tool = tools[tool_function_name] - tool_id = tool.get("toolkit_id", "") - if tool.get("citation", False) or tool.get("direct", False): + tool_id = tool.get("tool_id", "") + if tool.get("metadata", {}).get("citation", False) or tool.get( + "direct", False + ): sources.append( { @@ -238,7 +248,7 @@ async def tool_call_handler(tool_call): else f"{tool_function_name}" ), }, - "document": [tool_result], + "document": [tool_result, *tool_result_files], "metadata": [ { "source": ( @@ -254,7 +264,7 @@ async def tool_call_handler(tool_call): sources.append( { "source": {}, - "document": [tool_result], + "document": [tool_result, *tool_result_files], "metadata": [ { "source": ( @@ -267,7 +277,11 @@ async def tool_call_handler(tool_call): } ) - if tools[tool_function_name].get("file_handler", False): + if ( + tools[tool_function_name] + .get("metadata", {}) + .get("file_handler", False) + ): skip_files = True # check if "tool_calls" in result @@ -385,24 +399,44 @@ async def chat_web_search_handler( all_results.append(results) files = form_data.get("files", []) - if results.get("collection_name"): - files.append( - { - "collection_name": results["collection_name"], - "name": searchQuery, - "type": "web_search", - "urls": results["filenames"], - } - ) + if results.get("collection_names"): + for col_idx, collection_name in enumerate( + results.get("collection_names") + ): + files.append( + { + "collection_name": collection_name, + "name": searchQuery, + "type": "web_search", + "urls": [results["filenames"][col_idx]], + } + ) elif results.get("docs"): - files.append( - { - "docs": results.get("docs", []), - "name": searchQuery, - "type": "web_search", - "urls": results["filenames"], - } - ) + # Invoked when bypass embedding and retrieval is set to True + docs = results["docs"] + + if len(docs) == len(results["filenames"]): + # the number of docs and filenames 
(urls) should be the same + for doc_idx, doc in enumerate(docs): + files.append( + { + "docs": [doc], + "name": searchQuery, + "type": "web_search", + "urls": [results["filenames"][doc_idx]], + } + ) + else: + # edge case when the number of docs and filenames (urls) are not the same + # this should not happen, but if it does, we will just append the docs + files.append( + { + "docs": results.get("docs", []), + "name": searchQuery, + "type": "web_search", + "urls": results["filenames"], + } + ) form_data["files"] = files except Exception as e: @@ -625,27 +659,28 @@ def apply_params_to_form_data(form_data, model): if "keep_alive" in params: form_data["keep_alive"] = params["keep_alive"] else: - if "seed" in params: + if "seed" in params and params["seed"] is not None: form_data["seed"] = params["seed"] - if "stop" in params: + if "stop" in params and params["stop"] is not None: form_data["stop"] = params["stop"] - if "temperature" in params: + if "temperature" in params and params["temperature"] is not None: form_data["temperature"] = params["temperature"] - if "max_tokens" in params: + if "max_tokens" in params and params["max_tokens"] is not None: form_data["max_tokens"] = params["max_tokens"] - if "top_p" in params: + if "top_p" in params and params["top_p"] is not None: form_data["top_p"] = params["top_p"] - if "frequency_penalty" in params: + if "frequency_penalty" in params and params["frequency_penalty"] is not None: form_data["frequency_penalty"] = params["frequency_penalty"] - if "reasoning_effort" in params: + if "reasoning_effort" in params and params["reasoning_effort"] is not None: form_data["reasoning_effort"] = params["reasoning_effort"] - if "logit_bias" in params: + + if "logit_bias" in params and params["logit_bias"] is not None: try: form_data["logit_bias"] = json.loads( convert_logit_bias_input_to_json(params["logit_bias"]) @@ -865,7 +900,9 @@ async def process_chat_payload(request, form_data, user, metadata, model): for source_idx, source in enumerate(sources): if "document" in source: for doc_idx, doc_context in enumerate(source["document"]): - context_string += f"{source_idx + 1}{doc_context}\n" + context_string += ( + f'{doc_context}\n' + ) context_string = context_string.strip() prompt = get_last_user_message(form_data["messages"]) @@ -1198,13 +1235,15 @@ def serialize_content_blocks(content_blocks, raw=False): ) tool_result = None + tool_result_files = None for result in results: if tool_call_id == result.get("tool_call_id", ""): tool_result = result.get("content", None) + tool_result_files = result.get("files", None) break if tool_result: - tool_calls_display_content = f'{tool_calls_display_content}\n
<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result))}">\n<summary>Tool Executed</summary>\n</details>
' +                tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}" result="{html.escape(json.dumps(tool_result))}" files="{html.escape(json.dumps(tool_result_files)) if tool_result_files else ""}">\n<summary>Tool Executed</summary>\n</details>
\n' +            else: +                tool_calls_display_content = f'{tool_calls_display_content}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{html.escape(json.dumps(tool_arguments))}">\n<summary>Executing...</summary>\n</details>
' @@ -1805,7 +1844,7 @@ async def stream_body_handler(response): await stream_body_handler(response) - MAX_TOOL_CALL_RETRIES = 5 + MAX_TOOL_CALL_RETRIES = 10 tool_call_retries = 0 while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES: @@ -1898,6 +1937,14 @@ async def stream_body_handler(response): except Exception as e: tool_result = str(e) + tool_result_files = [] + if isinstance(tool_result, list): + for item in tool_result: + # check if string + if isinstance(item, str) and item.startswith("data:"): + tool_result_files.append(item) + tool_result.remove(item) + if isinstance(tool_result, dict) or isinstance( tool_result, list ): @@ -1907,6 +1954,11 @@ async def stream_body_handler(response): { "tool_call_id": tool_call_id, "content": tool_result, + **( + {"files": tool_result_files} + if tool_result_files + else {} + ), } ) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index ab50247d8bd..9ebe0e6dcb5 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -326,40 +326,45 @@ async def handle_callback(self, request, provider, response): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM - picture_url = user_data.get( - picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "") - ) - if picture_url: - # Download the profile image into a base64 string - try: - access_token = token.get("access_token") - get_kwargs = {} - if access_token: - get_kwargs["headers"] = { - "Authorization": f"Bearer {access_token}", - } - async with aiohttp.ClientSession() as session: - async with session.get(picture_url, **get_kwargs) as resp: - if resp.ok: - picture = await resp.read() - base64_encoded_picture = base64.b64encode( - picture - ).decode("utf-8") - guessed_mime_type = mimetypes.guess_type( - picture_url - )[0] - if guessed_mime_type is None: - # assume JPG, browsers are tolerant enough of image formats - guessed_mime_type = "image/jpeg" - picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" - else: - picture_url = "/user.png" - except Exception as e: - log.error( - f"Error downloading profile image '{picture_url}': {e}" - ) + if picture_claim: + picture_url = user_data.get( + picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "") + ) + if picture_url: + # Download the profile image into a base64 string + try: + access_token = token.get("access_token") + get_kwargs = {} + if access_token: + get_kwargs["headers"] = { + "Authorization": f"Bearer {access_token}", + } + async with aiohttp.ClientSession() as session: + async with session.get( + picture_url, **get_kwargs + ) as resp: + if resp.ok: + picture = await resp.read() + base64_encoded_picture = base64.b64encode( + picture + ).decode("utf-8") + guessed_mime_type = mimetypes.guess_type( + picture_url + )[0] + if guessed_mime_type is None: + # assume JPG, browsers are tolerant enough of image formats + guessed_mime_type = "image/jpeg" + picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" + else: + picture_url = "/user.png" + except Exception as e: + log.error( + f"Error downloading profile image '{picture_url}': {e}" + ) + picture_url = "/user.png" + if not picture_url: picture_url = "/user.png" - if not picture_url: + else: picture_url = "/user.png" username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index 29a4d0cceb9..f0746da7792 100644 --- 
a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -68,23 +68,23 @@ def replace_imports(content): return content -def load_tools_module_by_id(toolkit_id, content=None): +def load_tools_module_by_id(tool_id, content=None): if content is None: - tool = Tools.get_tool_by_id(toolkit_id) + tool = Tools.get_tool_by_id(tool_id) if not tool: - raise Exception(f"Toolkit not found: {toolkit_id}") + raise Exception(f"Toolkit not found: {tool_id}") content = tool.content content = replace_imports(content) - Tools.update_tool_by_id(toolkit_id, {"content": content}) + Tools.update_tool_by_id(tool_id, {"content": content}) else: frontmatter = extract_frontmatter(content) # Install required packages found within the frontmatter install_frontmatter_requirements(frontmatter.get("requirements", "")) - module_name = f"tool_{toolkit_id}" + module_name = f"tool_{tool_id}" module = types.ModuleType(module_name) sys.modules[module_name] = module @@ -108,7 +108,7 @@ def load_tools_module_by_id(toolkit_id, content=None): else: raise Exception("No Tools class found in the module") except Exception as e: - log.error(f"Error loading module: {toolkit_id}: {e}") + log.error(f"Error loading module: {tool_id}: {e}") del sys.modules[module_name] # Clean up raise e finally: diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index bd2a731e6aa..734c23e1b04 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -2,9 +2,10 @@ import logging import re import inspect -import uuid +import aiohttp +import asyncio -from typing import Any, Awaitable, Callable, get_type_hints +from typing import Any, Awaitable, Callable, get_type_hints, Dict, List, Union, Optional from functools import update_wrapper, partial @@ -17,96 +18,162 @@ from open_webui.models.users import UserModel from open_webui.utils.plugin import load_tools_module_by_id +import copy + log = logging.getLogger(__name__) -def apply_extra_params_to_tool_function( +def get_async_tool_function_and_apply_extra_params( function: Callable, extra_params: dict ) -> Callable[..., Awaitable]: sig = inspect.signature(function) extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters} partial_func = partial(function, **extra_params) + if inspect.iscoroutinefunction(function): update_wrapper(partial_func, function) return partial_func + else: + # Make it a coroutine function + async def new_function(*args, **kwargs): + return partial_func(*args, **kwargs) - async def new_function(*args, **kwargs): - return partial_func(*args, **kwargs) - - update_wrapper(new_function, function) - return new_function + update_wrapper(new_function, function) + return new_function -# Mutation on extra_params def get_tools( request: Request, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: tools_dict = {} for tool_id in tool_ids: - tools = Tools.get_tool_by_id(tool_id) - if tools is None: - continue - - module = request.app.state.TOOLS.get(tool_id, None) - if module is None: - module, _ = load_tools_module_by_id(tool_id) - request.app.state.TOOLS[tool_id] = module - - extra_params["__id__"] = tool_id - if hasattr(module, "valves") and hasattr(module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) or {} - module.valves = module.Valves(**valves) - - if hasattr(module, "UserValves"): - extra_params["__user__"]["valves"] = module.UserValves( # type: ignore - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) - ) - - for spec in tools.specs: - # 
TODO: Fix hack for OpenAI API - # Some times breaks OpenAI but others don't. Leaving the comment - for val in spec.get("parameters", {}).get("properties", {}).values(): - if val["type"] == "str": - val["type"] = "string" - - # Remove internal parameters - spec["parameters"]["properties"] = { - key: val - for key, val in spec["parameters"]["properties"].items() - if not key.startswith("__") - } - - function_name = spec["name"] - - # convert to function that takes only model params and inserts custom params - original_func = getattr(module, function_name) - callable = apply_extra_params_to_tool_function(original_func, extra_params) - - if callable.__doc__ and callable.__doc__.strip() != "": - s = re.split(":(param|return)", callable.__doc__, 1) - spec["description"] = s[0] - else: - spec["description"] = function_name - - # TODO: This needs to be a pydantic model - tool_dict = { - "spec": spec, - "callable": callable, - "toolkit_id": tool_id, - "pydantic_model": function_to_pydantic_model(callable), - # Misc info - "file_handler": hasattr(module, "file_handler") and module.file_handler, - "citation": hasattr(module, "citation") and module.citation, - } - - # TODO: if collision, prepend toolkit name - if function_name in tools_dict: - log.warning(f"Tool {function_name} already exists in another tools!") - log.warning(f"Collision between {tools} and {tool_id}.") - log.warning(f"Discarding {tools}.{function_name}") + tool = Tools.get_tool_by_id(tool_id) + if tool is None: + if tool_id.startswith("server:"): + server_idx = int(tool_id.split(":")[1]) + tool_server_connection = ( + request.app.state.config.TOOL_SERVER_CONNECTIONS[server_idx] + ) + tool_server_data = request.app.state.TOOL_SERVERS[server_idx] + specs = tool_server_data.get("specs", []) + + for spec in specs: + function_name = spec["name"] + + auth_type = tool_server_connection.get("auth_type", "bearer") + token = None + + if auth_type == "bearer": + token = tool_server_connection.get("key", "") + elif auth_type == "session": + token = request.state.token.credentials + + def make_tool_function(function_name, token, tool_server_data): + async def tool_function(**kwargs): + print( + f"Executing tool function {function_name} with params: {kwargs}" + ) + return await execute_tool_server( + token=token, + url=tool_server_data["url"], + name=function_name, + params=kwargs, + server_data=tool_server_data, + ) + + return tool_function + + tool_function = make_tool_function( + function_name, token, tool_server_data + ) + + callable = get_async_tool_function_and_apply_extra_params( + tool_function, + {}, + ) + + tool_dict = { + "tool_id": tool_id, + "callable": callable, + "spec": spec, + } + + # TODO: if collision, prepend toolkit name + if function_name in tools_dict: + log.warning( + f"Tool {function_name} already exists in another tools!" 
+ ) + log.warning(f"Discarding {tool_id}.{function_name}") + else: + tools_dict[function_name] = tool_dict else: - tools_dict[function_name] = tool_dict + continue + else: + module = request.app.state.TOOLS.get(tool_id, None) + if module is None: + module, _ = load_tools_module_by_id(tool_id) + request.app.state.TOOLS[tool_id] = module + + extra_params["__id__"] = tool_id + + # Set valves for the tool + if hasattr(module, "valves") and hasattr(module, "Valves"): + valves = Tools.get_tool_valves_by_id(tool_id) or {} + module.valves = module.Valves(**valves) + if hasattr(module, "UserValves"): + extra_params["__user__"]["valves"] = module.UserValves( # type: ignore + **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + ) + + for spec in tool.specs: + # TODO: Fix hack for OpenAI API + # Some times breaks OpenAI but others don't. Leaving the comment + for val in spec.get("parameters", {}).get("properties", {}).values(): + if val["type"] == "str": + val["type"] = "string" + + # Remove internal reserved parameters (e.g. __id__, __user__) + spec["parameters"]["properties"] = { + key: val + for key, val in spec["parameters"]["properties"].items() + if not key.startswith("__") + } + + # convert to function that takes only model params and inserts custom params + function_name = spec["name"] + tool_function = getattr(module, function_name) + callable = get_async_tool_function_and_apply_extra_params( + tool_function, extra_params + ) + + # TODO: Support Pydantic models as parameters + if callable.__doc__ and callable.__doc__.strip() != "": + s = re.split(":(param|return)", callable.__doc__, 1) + spec["description"] = s[0] + else: + spec["description"] = function_name + + tool_dict = { + "tool_id": tool_id, + "callable": callable, + "spec": spec, + # Misc info + "metadata": { + "file_handler": hasattr(module, "file_handler") + and module.file_handler, + "citation": hasattr(module, "citation") and module.citation, + }, + } + + # TODO: if collision, prepend toolkit name + if function_name in tools_dict: + log.warning( + f"Tool {function_name} already exists in another tools!" + ) + log.warning(f"Discarding {tool_id}.{function_name}") + else: + tools_dict[function_name] = tool_dict return tools_dict @@ -214,6 +281,273 @@ def get_callable_attributes(tool: object) -> list[Callable]: def get_tools_specs(tool_class: object) -> list[dict]: - function_list = get_callable_attributes(tool_class) - models = map(function_to_pydantic_model, function_list) - return [convert_to_openai_function(tool) for tool in models] + function_model_list = map( + function_to_pydantic_model, get_callable_attributes(tool_class) + ) + return [ + convert_to_openai_function(function_model) + for function_model in function_model_list + ] + + +def resolve_schema(schema, components): + """ + Recursively resolves a JSON schema using OpenAPI components. 
+ """ + if not schema: + return {} + + if "$ref" in schema: + ref_path = schema["$ref"] + ref_parts = ref_path.strip("#/").split("/") + resolved = components + for part in ref_parts[1:]: # Skip the initial 'components' + resolved = resolved.get(part, {}) + return resolve_schema(resolved, components) + + resolved_schema = copy.deepcopy(schema) + + # Recursively resolve inner schemas + if "properties" in resolved_schema: + for prop, prop_schema in resolved_schema["properties"].items(): + resolved_schema["properties"][prop] = resolve_schema( + prop_schema, components + ) + + if "items" in resolved_schema: + resolved_schema["items"] = resolve_schema(resolved_schema["items"], components) + + return resolved_schema + + +def convert_openapi_to_tool_payload(openapi_spec): + """ + Converts an OpenAPI specification into a custom tool payload structure. + + Args: + openapi_spec (dict): The OpenAPI specification as a Python dict. + + Returns: + list: A list of tool payloads. + """ + tool_payload = [] + + for path, methods in openapi_spec.get("paths", {}).items(): + for method, operation in methods.items(): + tool = { + "type": "function", + "name": operation.get("operationId"), + "description": operation.get( + "description", operation.get("summary", "No description available.") + ), + "parameters": {"type": "object", "properties": {}, "required": []}, + } + + # Extract path and query parameters + for param in operation.get("parameters", []): + param_name = param["name"] + param_schema = param.get("schema", {}) + tool["parameters"]["properties"][param_name] = { + "type": param_schema.get("type"), + "description": param_schema.get("description", ""), + } + if param.get("required"): + tool["parameters"]["required"].append(param_name) + + # Extract and resolve requestBody if available + request_body = operation.get("requestBody") + if request_body: + content = request_body.get("content", {}) + json_schema = content.get("application/json", {}).get("schema") + if json_schema: + resolved_schema = resolve_schema( + json_schema, openapi_spec.get("components", {}) + ) + + if resolved_schema.get("properties"): + tool["parameters"]["properties"].update( + resolved_schema["properties"] + ) + if "required" in resolved_schema: + tool["parameters"]["required"] = list( + set( + tool["parameters"]["required"] + + resolved_schema["required"] + ) + ) + elif resolved_schema.get("type") == "array": + tool["parameters"] = resolved_schema # special case for array + + tool_payload.append(tool) + + return tool_payload + + +async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + if token: + headers["Authorization"] = f"Bearer {token}" + + error = None + try: + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as response: + if response.status != 200: + error_body = await response.json() + raise Exception(error_body) + res = await response.json() + except Exception as err: + print("Error:", err) + if isinstance(err, dict) and "detail" in err: + error = err["detail"] + else: + error = str(err) + raise Exception(error) + + data = { + "openapi": res, + "info": res.get("info", {}), + "specs": convert_openapi_to_tool_payload(res), + } + + print("Fetched data:", data) + return data + + +async def get_tool_servers_data( + servers: List[Dict[str, Any]], session_token: Optional[str] = None +) -> List[Dict[str, Any]]: + # Prepare list of enabled servers along with their original index + 
server_entries = [] + for idx, server in enumerate(servers): + if server.get("config", {}).get("enable"): + url_path = server.get("path", "openapi.json") + full_url = f"{server.get('url')}/{url_path}" + + auth_type = server.get("auth_type", "bearer") + token = None + + if auth_type == "bearer": + token = server.get("key", "") + elif auth_type == "session": + token = session_token + server_entries.append((idx, server, full_url, token)) + + # Create async tasks to fetch data + tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries] + + # Execute tasks concurrently + responses = await asyncio.gather(*tasks, return_exceptions=True) + + # Build final results with index and server metadata + results = [] + for (idx, server, url, _), response in zip(server_entries, responses): + if isinstance(response, Exception): + print(f"Failed to connect to {url} OpenAPI tool server") + continue + + results.append( + { + "idx": idx, + "url": server.get("url"), + "openapi": response.get("openapi"), + "info": response.get("info"), + "specs": response.get("specs"), + } + ) + + return results + + +async def execute_tool_server( + token: str, url: str, name: str, params: Dict[str, Any], server_data: Dict[str, Any] +) -> Any: + error = None + try: + openapi = server_data.get("openapi", {}) + paths = openapi.get("paths", {}) + + matching_route = None + for route_path, methods in paths.items(): + for http_method, operation in methods.items(): + if isinstance(operation, dict) and operation.get("operationId") == name: + matching_route = (route_path, methods) + break + if matching_route: + break + + if not matching_route: + raise Exception(f"No matching route found for operationId: {name}") + + route_path, methods = matching_route + + method_entry = None + for http_method, operation in methods.items(): + if operation.get("operationId") == name: + method_entry = (http_method.lower(), operation) + break + + if not method_entry: + raise Exception(f"No matching method found for operationId: {name}") + + http_method, operation = method_entry + + path_params = {} + query_params = {} + body_params = {} + + for param in operation.get("parameters", []): + param_name = param["name"] + param_in = param["in"] + if param_name in params: + if param_in == "path": + path_params[param_name] = params[param_name] + elif param_in == "query": + query_params[param_name] = params[param_name] + + final_url = f"{url}{route_path}" + for key, value in path_params.items(): + final_url = final_url.replace(f"{{{key}}}", str(value)) + + if query_params: + query_string = "&".join(f"{k}={v}" for k, v in query_params.items()) + final_url = f"{final_url}?{query_string}" + + if operation.get("requestBody", {}).get("content"): + if params: + body_params = params + else: + raise Exception( + f"Request body expected for operation '{name}' but none found." 
+ ) + + headers = {"Content-Type": "application/json"} + + if token: + headers["Authorization"] = f"Bearer {token}" + + async with aiohttp.ClientSession() as session: + request_method = getattr(session, http_method.lower()) + + if http_method in ["post", "put", "patch"]: + async with request_method( + final_url, json=body_params, headers=headers + ) as response: + if response.status >= 400: + text = await response.text() + raise Exception(f"HTTP error {response.status}: {text}") + return await response.json() + else: + async with request_method(final_url, headers=headers) as response: + if response.status >= 400: + text = await response.text() + raise Exception(f"HTTP error {response.status}: {text}") + return await response.json() + + except Exception as err: + error = str(err) + print("API Request Error:", error) + return {"error": error} diff --git a/backend/requirements.txt b/backend/requirements.txt index ca2ea50609d..499eae36dec 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,7 +1,7 @@ fastapi==0.115.7 uvicorn[standard]==0.34.0 pydantic==2.10.6 -python-multipart==0.0.18 +python-multipart==0.0.20 python-socketio==5.11.3 python-jose==3.4.0 @@ -54,8 +54,9 @@ elasticsearch==8.17.1 transformers sentence-transformers==3.3.1 +accelerate colbert-ai==0.2.21 -einops==0.8.0 +einops==0.8.1 ftfy==6.2.3 @@ -67,7 +68,7 @@ python-pptx==1.0.0 unstructured==0.16.17 nltk==3.9.1 Markdown==3.7 -pypandoc==1.13 +pypandoc==1.15 pandas==2.2.3 openpyxl==3.1.5 pyxlsb==1.0.10 @@ -83,6 +84,8 @@ opencv-python-headless==4.11.0.86 rapidocr-onnxruntime==1.3.24 rank-bm25==0.2.2 +onnxruntime==1.20.1 + faster-whisper==1.1.1 PyJWT[crypto]==2.10.1 diff --git a/package-lock.json b/package-lock.json index 6eb5064d793..360d02a39fe 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "open-webui", - "version": "0.6.0", + "version": "0.6.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "open-webui", - "version": "0.6.0", + "version": "0.6.2", "dependencies": { "@azure/msal-browser": "^4.5.0", "@codemirror/lang-javascript": "^6.2.2", diff --git a/package.json b/package.json index 465fbba0fe1..f670644df93 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "open-webui", - "version": "0.6.0", + "version": "0.6.2", "private": true, "scripts": { "dev": "npm run pyodide:fetch && vite dev --host", diff --git a/pyproject.toml b/pyproject.toml index 2e8537a7700..52260e45e22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dependencies = [ "openai", "anthropic", - "google-generativeai==0.7.2", + "google-generativeai==0.8.4", "tiktoken", "langchain==0.3.19", @@ -61,8 +61,9 @@ dependencies = [ "transformers", "sentence-transformers==3.3.1", + "accelerate", "colbert-ai==0.2.21", - "einops==0.8.0", + "einops==0.8.1", "ftfy==6.2.3", "pypdf==4.3.1", @@ -73,7 +74,7 @@ dependencies = [ "unstructured==0.16.17", "nltk==3.9.1", "Markdown==3.7", - "pypandoc==1.13", + "pypandoc==1.15", "pandas==2.2.3", "openpyxl==3.1.5", "pyxlsb==1.0.10", @@ -89,6 +90,8 @@ dependencies = [ "rapidocr-onnxruntime==1.3.24", "rank-bm25==0.2.2", + "onnxruntime==1.20.1", + "faster-whisper==1.1.1", "PyJWT[crypto]==2.10.1", diff --git a/src/app.css b/src/app.css index 4061d3b5eb3..86e8438f096 100644 --- a/src/app.css +++ b/src/app.css @@ -46,6 +46,14 @@ math { @apply rounded-lg; } +input::placeholder { + direction: auto; +} + +textarea::placeholder { + direction: auto; +} + .input-prose { @apply prose dark:prose-invert prose-headings:font-semibold 
prose-hr:my-4 prose-hr:border-gray-100 prose-hr:dark:border-gray-800 prose-p:my-0 prose-img:my-1 prose-headings:my-1 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-ul:-my-0 prose-ol:-my-0 prose-li:-my-0 whitespace-pre-line; } diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index f7f02c74054..5872303f6aa 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -115,6 +115,93 @@ export const setDirectConnectionsConfig = async (token: string, config: object) return res; }; +export const getToolServerConnections = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/tool_servers`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const setToolServerConnections = async (token: string, connections: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/tool_servers`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...connections + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const verifyToolServerConnection = async (token: string, connection: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/tool_servers/verify`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...connection + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getCodeExecutionConfig = async (token: string) => { let error = null; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 015e1272acc..cdd6887b2dc 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -262,7 +262,7 @@ export const stopTask = async (token: string, id: string) => { export const getToolServerData = async (token: string, url: string) => { let error = null; - const res = await fetch(`${url}/openapi.json`, { + const res = await fetch(`${url}`, { method: 'GET', headers: { Accept: 'application/json', @@ -304,10 +304,13 @@ export const getToolServersData = async (i18n, servers: object[]) => { servers .filter((server) => server?.config?.enable) .map(async (server) => { - const data = await getToolServerData(server?.key, server?.url).catch((err) => { + const data = await getToolServerData( + server?.key, + server?.url + '/' + (server?.path ?? 'openapi.json') + ).catch((err) => { toast.error( i18n.t(`Failed to connect to {{URL}} OpenAPI tool server`, { - URL: server?.url + URL: server?.url + '/' + (server?.path ?? 
'openapi.json') }) ); return null; diff --git a/src/lib/components/AddServerModal.svelte b/src/lib/components/AddServerModal.svelte index f9970e4e84d..1ce7369e44b 100644 --- a/src/lib/components/AddServerModal.svelte +++ b/src/lib/components/AddServerModal.svelte @@ -15,6 +15,9 @@ import Tooltip from '$lib/components/common/Tooltip.svelte'; import Switch from '$lib/components/common/Switch.svelte'; import Tags from './common/Tags.svelte'; + import { getToolServerData } from '$lib/apis'; + import { verifyToolServerConnection } from '$lib/apis/configs'; + import AccessControl from './workspace/common/AccessControl.svelte'; export let onSubmit: Function = () => {}; export let onDelete: Function = () => {}; @@ -22,14 +25,66 @@ export let show = false; export let edit = false; + export let direct = false; + export let connection = null; let url = ''; + let path = 'openapi.json'; + + let auth_type = 'bearer'; let key = ''; + + let accessControl = null; + let enable = true; let loading = false; + const verifyHandler = async () => { + if (url === '') { + toast.error($i18n.t('Please enter a valid URL')); + return; + } + + if (path === '') { + toast.error($i18n.t('Please enter a valid path')); + return; + } + + if (direct) { + const res = await getToolServerData( + auth_type === 'bearer' ? key : localStorage.token, + `${url}/${path}` + ).catch((err) => { + toast.error($i18n.t('Connection failed')); + }); + + if (res) { + toast.success($i18n.t('Connection successful')); + console.debug('Connection successful', res); + } + } else { + const res = await verifyToolServerConnection(localStorage.token, { + url, + path, + auth_type, + key, + config: { + enable: enable, + access_control: accessControl + } + }).catch((err) => { + toast.error($i18n.t('Connection failed')); + }); + + if (res) { + toast.success($i18n.t('Connection successful')); + console.debug('Connection successful', res); + } + } + }; + const submitHandler = async () => { loading = true; @@ -38,9 +93,12 @@ const connection = { url, + path, + auth_type, key, config: { - enable: enable + enable: enable, + access_control: accessControl } }; @@ -50,16 +108,24 @@ show = false; url = ''; + path = 'openapi.json'; key = ''; + auth_type = 'bearer'; + enable = true; + accessControl = null; }; const init = () => { if (connection) { url = connection.url; - key = connection.key; + path = connection?.path ?? 'openapi.json'; + + auth_type = connection?.auth_type ?? 'bearer'; + key = connection?.key ?? ''; enable = connection.config?.enable ?? true; + accessControl = connection.config?.access_control ?? null; } }; @@ -113,47 +179,113 @@
-
{$i18n.t('URL')}
+
+
{$i18n.t('URL')}
+
-
+
+ + + + + + + +
-
-
- - - +
+
/
+ +
- {$i18n.t(`WebUI will make requests to "{{url}}/openapi.json"`, { - url: url + {$i18n.t(`WebUI will make requests to "{{url}}"`, { + url: `${url}/${path}` })}
-
{$i18n.t('Key')}
- -
- +
{$i18n.t('Auth')}
+ +
+
+ +
+ +
+ {#if auth_type === 'bearer'} + + {:else if auth_type === 'session'} +
+ {$i18n.t('Forwards system user session credentials to authenticate')} +
+ {/if} +
+ + {#if !direct} +
+ +
+
+ +
+
+ {/if}
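The modal above now assembles a tool server connection with the new `path`, `auth_type`, and access-control fields before saving or verifying it. A minimal sketch of that object and of the two verification paths (direct from the browser vs. through the backend), using the helpers added in this diff; the server URL and key are placeholders:

```typescript
import { getToolServerData } from '$lib/apis';
import { verifyToolServerConnection } from '$lib/apis/configs';

// Connection shape assembled by AddServerModal (values are illustrative).
const connection = {
	url: 'https://tools.example.com', // placeholder server
	path: 'openapi.json', // requested as `${url}/${path}`
	auth_type: 'bearer', // or 'session' to forward the user's session token
	key: 'sk-placeholder', // bearer key; unused for 'session'
	config: {
		enable: true,
		access_control: null
	}
};

async function verify(direct: boolean, sessionToken: string) {
	if (direct) {
		// Direct (per-user) servers are probed straight from the browser.
		return await getToolServerData(
			connection.auth_type === 'bearer' ? connection.key : sessionToken,
			`${connection.url}/${connection.path}`
		);
	}
	// Global servers are verified through the backend endpoint.
	return await verifyToolServerConnection(sessionToken, connection);
}
```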
diff --git a/src/lib/components/admin/Evaluations/Feedbacks.svelte b/src/lib/components/admin/Evaluations/Feedbacks.svelte index 026755b8a32..726028664a4 100644 --- a/src/lib/components/admin/Evaluations/Feedbacks.svelte +++ b/src/lib/components/admin/Evaluations/Feedbacks.svelte @@ -115,7 +115,7 @@ {feedbacks.length}
-
+ {#if feedbacks.length > 0}
-
+ {/if}
- + if (_functions) { + let blob = new Blob([JSON.stringify(_functions)], { + type: 'application/json' + }); + saveAs(blob, `functions-export-${Date.now()}.json`); + } + }} + > +
{$i18n.t('Export Functions')}
+ +
+ + + +
+ + {/if}
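The Admin Settings panel gains a Tools section in the next hunk; the servers it manages are read by the backend, where `convert_openapi_to_tool_payload` (added to `backend/open_webui/utils/tools.py` earlier in this diff) turns each server's OpenAPI document into tool specs. Roughly, a path item like the first object below yields the spec in the second; the endpoint itself is invented for illustration:

```typescript
// Hypothetical OpenAPI fragment served by a tool server.
const openapiFragment = {
	paths: {
		'/weather/{city}': {
			get: {
				operationId: 'get_weather',
				summary: 'Get current weather',
				description: 'Returns the current weather for a city.',
				parameters: [
					{
						name: 'city',
						in: 'path',
						required: true,
						schema: { type: 'string', description: 'City name' }
					},
					{
						name: 'units',
						in: 'query',
						schema: { type: 'string', description: 'metric or imperial' }
					}
				]
			}
		}
	}
};

// Tool spec derived from it: the operationId becomes the function name, the
// endpoint description is preferred over the summary, and path/query
// parameters become the function's properties (required ones listed).
const toolSpec = {
	type: 'function',
	name: 'get_weather',
	description: 'Returns the current weather for a city.',
	parameters: {
		type: 'object',
		properties: {
			city: { type: 'string', description: 'City name' },
			units: { type: 'string', description: 'metric or imperial' }
		},
		required: ['city']
	}
};
```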
diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 76e3ae59d55..c26604d6c05 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -20,6 +20,7 @@ import DocumentChartBar from '../icons/DocumentChartBar.svelte'; import Evaluations from './Settings/Evaluations.svelte'; import CodeExecution from './Settings/CodeExecution.svelte'; + import Tools from './Settings/Tools.svelte'; const i18n = getContext('i18n'); @@ -135,6 +136,32 @@
{$i18n.t('Evaluations')}
+ + + + + +
+ {#each servers as server, idx} + { + updateHandler(); + }} + onDelete={() => { + servers = servers.filter((_, i) => i !== idx); + updateHandler(); + }} + /> + {/each} +
+ +
+
+ {$i18n.t('Connect to your own OpenAPI compatible external tool servers.')} +
+
+ + + + + + {:else} +
+
+ +
+
+ {/if} + + +
+ +
+ diff --git a/src/lib/components/admin/Users/Groups.svelte b/src/lib/components/admin/Users/Groups.svelte index 15497cb205b..e287feb1d53 100644 --- a/src/lib/components/admin/Users/Groups.svelte +++ b/src/lib/components/admin/Users/Groups.svelte @@ -64,9 +64,10 @@ delete: true, edit: true, temporary: true, - temporary_enforced: true + temporary_enforced: false }, features: { + direct_tool_servers: false, web_search: true, image_generation: true, code_interpreter: true diff --git a/src/lib/components/admin/Users/Groups/EditGroupModal.svelte b/src/lib/components/admin/Users/Groups/EditGroupModal.svelte index e492cc9b6d2..5b6bf6aabc7 100644 --- a/src/lib/components/admin/Users/Groups/EditGroupModal.svelte +++ b/src/lib/components/admin/Users/Groups/EditGroupModal.svelte @@ -38,6 +38,12 @@ prompts: false, tools: false }, + sharing: { + public_models: false, + public_knowledge: false, + public_prompts: false, + public_tools: false + }, chat: { controls: true, file_upload: true, @@ -46,6 +52,7 @@ temporary: true }, features: { + direct_tool_servers: false, web_search: true, image_generation: true, code_interpreter: true diff --git a/src/lib/components/admin/Users/Groups/Permissions.svelte b/src/lib/components/admin/Users/Groups/Permissions.svelte index e1aa73f2a25..5dac0de94c6 100644 --- a/src/lib/components/admin/Users/Groups/Permissions.svelte +++ b/src/lib/components/admin/Users/Groups/Permissions.svelte @@ -25,9 +25,10 @@ edit: true, file_upload: true, temporary: true, - temporary_enforced: true + temporary_enforced: false }, features: { + direct_tool_servers: false, web_search: true, image_generation: true, code_interpreter: true @@ -295,6 +296,14 @@
{$i18n.t('Features Permissions')}
+
+
+ {$i18n.t('Direct Tool Servers')} +
+ + +
+
{$i18n.t('Web Search')} diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index 275f76d29cf..ce2aa54f1c7 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -106,7 +106,7 @@ messages[idx] = data; } } else if (type === 'typing' && event.message_id === null) { - if (event.user.id === $user.id) { + if (event.user.id === $user?.id) { return; } diff --git a/src/lib/components/channel/MessageInput.svelte b/src/lib/components/channel/MessageInput.svelte index 9ee433e30cc..9f495a8de11 100644 --- a/src/lib/components/channel/MessageInput.svelte +++ b/src/lib/components/channel/MessageInput.svelte @@ -381,7 +381,7 @@ >
{#if files.length > 0}
diff --git a/src/lib/components/channel/Messages.svelte b/src/lib/components/channel/Messages.svelte index f8ff2f229c0..e1bc326b375 100644 --- a/src/lib/components/channel/Messages.svelte +++ b/src/lib/components/channel/Messages.svelte @@ -132,7 +132,7 @@ if ( (message?.reactions ?? []) .find((reaction) => reaction.name === name) - ?.user_ids?.includes($user.id) ?? + ?.user_ids?.includes($user?.id) ?? false ) { messages = messages.map((m) => { @@ -140,7 +140,7 @@ const reaction = m.reactions.find((reaction) => reaction.name === name); if (reaction) { - reaction.user_ids = reaction.user_ids.filter((id) => id !== $user.id); + reaction.user_ids = reaction.user_ids.filter((id) => id !== $user?.id); reaction.count = reaction.user_ids.length; if (reaction.count === 0) { @@ -167,12 +167,12 @@ const reaction = m.reactions.find((reaction) => reaction.name === name); if (reaction) { - reaction.user_ids.push($user.id); + reaction.user_ids.push($user?.id); reaction.count = reaction.user_ids.length; } else { m.reactions.push({ name: name, - user_ids: [$user.id], + user_ids: [$user?.id], count: 1 }); } diff --git a/src/lib/components/channel/Messages/Message.svelte b/src/lib/components/channel/Messages/Message.svelte index 0736a25129f..9989388060a 100644 --- a/src/lib/components/channel/Messages/Message.svelte +++ b/src/lib/components/channel/Messages/Message.svelte @@ -106,7 +106,7 @@ {/if} - {#if message.user_id === $user.id || $user.role === 'admin'} + {#if message.user_id === $user?.id || $user?.role === 'admin'}
- {#if $user.role === 'admin' || $user?.permissions.chat?.controls} + {#if $user?.role === 'admin' || $user?.permissions.chat?.controls}
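The channel components above swap direct `$user.id` reads for `$user?.id` so an unresolved user store no longer throws. Reduced to a standalone helper, the reaction bookkeeping those guards protect looks roughly like this (a sketch, not the component code itself):

```typescript
type Reaction = { name: string; user_ids: string[]; count: number };

// Toggle a reaction for a (possibly undefined) user id, mirroring the logic
// in Messages.svelte: remove the id if present, otherwise add it, and drop
// the reaction entirely once nobody is left on it.
function toggleReaction(reactions: Reaction[], name: string, userId?: string): Reaction[] {
	if (!userId) return reactions;

	const existing = reactions.find((r) => r.name === name);

	if (existing?.user_ids?.includes(userId)) {
		existing.user_ids = existing.user_ids.filter((id) => id !== userId);
		existing.count = existing.user_ids.length;
		return existing.count === 0 ? reactions.filter((r) => r !== existing) : reactions;
	}

	if (existing) {
		existing.user_ids.push(userId);
		existing.count = existing.user_ids.length;
	} else {
		reactions.push({ name, user_ids: [userId], count: 1 });
	}
	return reactions;
}
```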
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 9251ba4e4bc..0f42985a570 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -21,7 +21,12 @@ TTSWorker } from '$lib/stores'; - import { blobToFile, compressImage, createMessagesList, findWordIndices } from '$lib/utils'; + import { + blobToFile, + compressImage, + createMessagesList, + extractCurlyBraceWords + } from '$lib/utils'; import { transcribeAudio } from '$lib/apis/audio'; import { uploadFile } from '$lib/apis/files'; import { generateAutoCompletion } from '$lib/apis'; @@ -47,6 +52,7 @@ import CommandLine from '../icons/CommandLine.svelte'; import { KokoroWorker } from '$lib/workers/KokoroWorker'; import ToolServersModal from './ToolServersModal.svelte'; + import Wrench from '../icons/Wrench.svelte'; const i18n = getContext('i18n'); @@ -85,7 +91,7 @@ webSearchEnabled }); - let showToolServers = false; + let showTools = false; let loaded = false; let recording = false; @@ -348,7 +354,7 @@ - + {#if loaded}
@@ -392,38 +398,6 @@
- {#if selectedToolIds.length > 0} -
-
-
- - - - -
-
- {#each selectedToolIds.map((id) => { - return $tools ? $tools.find((t) => t.id === id) : { id: id, name: id }; - }) as tool, toolIdx (toolIdx)} - - {tool.name} - - - {#if toolIdx !== selectedToolIds.length - 1} - , - {/if} - {/each} -
-
-
- {/if} - {#if atSelectedModel !== undefined}
@@ -536,7 +510,7 @@ >
{#if files.length > 0}
@@ -631,6 +605,7 @@ {#if $settings?.richTextInput ?? true}
0) { const word = words.at(0); @@ -1057,8 +1033,8 @@ {/if}
-
-
+
+
-
+
+ {#if toolServers.length + selectedToolIds.length > 0} + + + + {/if} + {#if $_user} {#if $config?.features?.enable_web_search && ($_user.role === 'admin' || $_user?.permissions?.features?.web_search)} @@ -1140,7 +1139,7 @@ > @@ -1159,7 +1158,7 @@ > @@ -1178,7 +1177,7 @@ > @@ -1189,47 +1188,6 @@
- {#if toolServers.length > 0} - - - - {/if} - {#if !history?.currentId || history.messages[history.currentId]?.done == true}
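The message-input controls above now show a single tools dropdown whenever `toolServers.length + selectedToolIds.length > 0`. The `toolServers` list comes from `getToolServersData`, which (per the `src/lib/apis/index.ts` change earlier in this diff) fetches each enabled server's OpenAPI document at `${url}/${path ?? 'openapi.json'}`. A sketch of populating it; the server entry is illustrative and `i18n` is assumed to come from the surrounding component:

```typescript
import { getToolServersData } from '$lib/apis';

// Illustrative server entry mirroring the connection shape used in settings.
const servers = [
	{
		url: 'https://tools.example.com', // placeholder
		path: 'openapi.json', // optional; defaults to 'openapi.json'
		key: 'sk-placeholder',
		config: { enable: true }
	}
];

async function loadToolServers(i18n: any) {
	// Disabled or unreachable servers are filtered out; each surviving entry
	// carries the parsed OpenAPI document used to build the dropdown.
	return await getToolServersData(i18n, servers);
}
```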
@@ -268,7 +269,7 @@
-
+
{$i18n.t('API keys')}
diff --git a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte index 67b1f4dc107..80823f83020 100644 --- a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte +++ b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte @@ -788,25 +788,23 @@
- {$i18n.t('Repeat Penalty (Ollama)')} + {$i18n.t('Repeat Last N')}
- {#if (params?.repeat_penalty ?? null) !== null} + {#if (params?.repeat_last_n ?? null) !== null}
@@ -844,23 +842,25 @@
- {$i18n.t('Repeat Last N')} + {$i18n.t('Tfs Z')}
- {#if (params?.repeat_last_n ?? null) !== null} + {#if (params?.tfs_z ?? null) !== null}
@@ -899,24 +899,24 @@
- {$i18n.t('Tfs Z')} + {$i18n.t('Tokens To Keep On Context Refresh (num_keep)')}
- {#if (params?.tfs_z ?? null) !== null} + {#if (params?.num_keep ?? null) !== null}
-
+
@@ -954,24 +953,25 @@
- {$i18n.t('Context Length')} - {$i18n.t('(Ollama)')} + {$i18n.t('Max Tokens (num_predict)')}
- {#if (params?.num_ctx ?? null) !== null} + {#if (params?.max_tokens ?? null) !== null}
-
+
@@ -1009,24 +1009,24 @@
- {$i18n.t('Batch Size (num_batch)')} + {$i18n.t('Repeat Penalty (Ollama)')}
- {#if (params?.num_batch ?? null) !== null} + {#if (params?.repeat_penalty ?? null) !== null}
@@ -1063,25 +1064,24 @@
- {$i18n.t('Tokens To Keep On Context Refresh (num_keep)')} + {$i18n.t('Context Length')} + {$i18n.t('(Ollama)')}
- {#if (params?.num_keep ?? null) !== null} + {#if (params?.num_ctx ?? null) !== null}
- {$i18n.t('Max Tokens (num_predict)')} + {$i18n.t('Batch Size (num_batch)')}
- {#if (params?.max_tokens ?? null) !== null} + {#if (params?.num_batch ?? null) !== null}
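For reference, the reordered advanced-parameter fields above all correspond to standard Ollama sampling options; the mapping below is inferred from the field labels and `params` keys in this hunk rather than shown verbatim in the diff (the UI's `max_tokens` is surfaced to Ollama as `num_predict`):

```typescript
// Illustrative values only; keys follow the params object used by the panel above.
const advancedParams = {
	repeat_last_n: 64, // "Repeat Last N"
	tfs_z: 1, // "Tfs Z"
	num_keep: 24, // "Tokens To Keep On Context Refresh (num_keep)"
	max_tokens: 128, // "Max Tokens (num_predict)"
	repeat_penalty: 1.1, // "Repeat Penalty (Ollama)"
	num_ctx: 2048, // "Context Length (Ollama)"
	num_batch: 512 // "Batch Size (num_batch)"
};
```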
diff --git a/src/lib/components/chat/Settings/Audio.svelte b/src/lib/components/chat/Settings/Audio.svelte index 0131aaae0dd..9b896628d4b 100644 --- a/src/lib/components/chat/Settings/Audio.svelte +++ b/src/lib/components/chat/Settings/Audio.svelte @@ -293,7 +293,7 @@
@@ -330,7 +330,7 @@
diff --git a/src/lib/components/chat/Settings/Chats.svelte b/src/lib/components/chat/Settings/Chats.svelte index 2eaea240ffc..e7c424b0527 100644 --- a/src/lib/components/chat/Settings/Chats.svelte +++ b/src/lib/components/chat/Settings/Chats.svelte @@ -16,6 +16,7 @@ import { onMount, getContext } from 'svelte'; import { goto } from '$app/navigation'; import { toast } from 'svelte-sonner'; + import ArchivedChatsModal from '$lib/components/layout/Sidebar/ArchivedChatsModal.svelte'; const i18n = getContext('i18n'); @@ -26,6 +27,7 @@ let showArchiveConfirm = false; let showDeleteConfirm = false; + let showArchivedChatsModal = false; let chatImportInputElement: HTMLInputElement; @@ -95,8 +97,16 @@ await chats.set(await getChatList(localStorage.token, $currentChatPage)); scrollPaginationEnabled.set(true); }; + + const handleArchivedChatsChange = async () => { + currentChatPage.set(1); + await chats.set(await getChatList(localStorage.token, $currentChatPage)); + scrollPaginationEnabled.set(true); + }; + +
@@ -157,6 +167,32 @@
+ + {#if showArchiveConfirm}
diff --git a/src/lib/components/chat/Settings/General.svelte b/src/lib/components/chat/Settings/General.svelte index 4e7e9cd3691..928a469d58c 100644 --- a/src/lib/components/chat/Settings/General.svelte +++ b/src/lib/components/chat/Settings/General.svelte @@ -308,15 +308,16 @@
- {#if $user.role === 'admin' || $user?.permissions.chat?.controls} -
+ {#if $user?.role === 'admin' || $user?.permissions.chat?.controls} +
{$i18n.t('System Prompt')}
-