diff --git a/openmanus_rl/engines/factory.py b/openmanus_rl/engines/factory.py index 939e9ac7..99fc2a54 100644 --- a/openmanus_rl/engines/factory.py +++ b/openmanus_rl/engines/factory.py @@ -1,20 +1,29 @@ """Engine factory helpers. Exposes `create_llm_engine` returning a callable that maps prompt -> text using -the minimal `ChatOpenAI` wrapper. Keep the surface small and stable so tools -can depend on it without heavy coupling. +the minimal `ChatOpenAI` wrapper (or `ChatLiteLLM` for 100+ providers via +LiteLLM). Keep the surface small and stable so tools can depend on it without +heavy coupling. """ from typing import Callable, Optional from .openai import ChatOpenAI -def create_llm_engine(model_string: str = "gpt-4o-mini", is_multimodal: bool = False, base_url: Optional[str] = None) -> Callable[[str], str]: - chat = ChatOpenAI(model=model_string, base_url=base_url) +def create_llm_engine( + model_string: str = "gpt-4o-mini", + is_multimodal: bool = False, + base_url: Optional[str] = None, + engine: str = "openai", + api_key: Optional[str] = None, +) -> Callable[[str], str]: + if engine == "litellm": + from .litellm import ChatLiteLLM + chat = ChatLiteLLM(model=model_string, base_url=base_url, api_key=api_key) + else: + chat = ChatOpenAI(model=model_string, base_url=base_url) def _engine(prompt: str) -> str: - # Tools currently call engine(prompt) for text-only flows. - # If multimodal is needed later, extend by adding optional image args. return chat(prompt) return _engine diff --git a/openmanus_rl/engines/litellm.py b/openmanus_rl/engines/litellm.py new file mode 100644 index 00000000..399b9305 --- /dev/null +++ b/openmanus_rl/engines/litellm.py @@ -0,0 +1,125 @@ +"""LiteLLM chat wrapper. + +Provides the same callable interface as ``ChatOpenAI`` but routes through +LiteLLM (https://github.com/BerriAI/litellm), giving access to 100+ LLM +providers (Anthropic, Bedrock, Vertex, Gemini, Cohere, Mistral, etc.) +via a single unified API. +""" + +import json +import re +from typing import Any, Dict, List, Optional, Type + +try: + from pydantic import BaseModel # type: ignore +except Exception: # pragma: no cover + BaseModel = object # type: ignore + + +class ChatLiteLLM: + """Thin wrapper around LiteLLM's completion API. + + Drop-in replacement for ``ChatOpenAI`` that supports any model string + LiteLLM understands (e.g. ``anthropic/claude-sonnet-4-6``, + ``bedrock/claude-3.5-sonnet``, ``gpt-4o``, etc.). + """ + + def __init__( + self, + model: str = "gpt-4o", + base_url: Optional[str] = None, + api_key: Optional[str] = None, + temperature: float = 0.0, + ) -> None: + try: + import litellm # type: ignore + except ImportError as exc: + raise RuntimeError( + "litellm package is not installed. " + "Install it with: pip install litellm" + ) from exc + + self._litellm = litellm + self.model = model + self.temperature = temperature + self.base_url = base_url + self.api_key = api_key + + def __call__( + self, + prompt: str, + images: Optional[List[str]] = None, + system: Optional[str] = None, + response_format: Optional[Type] = None, + **_: Any, + ) -> Any: + messages: List[Dict[str, Any]] = [] + if system: + messages.append({"role": "system", "content": system}) + + if not images: + messages.append({"role": "user", "content": prompt}) + else: + content = prompt + for p in images: + content += f"\n[Image: {p}]" + messages.append({"role": "user", "content": content}) + + kwargs: Dict[str, Any] = { + "model": self.model, + "messages": messages, + "temperature": self.temperature, + "n": 1, + "drop_params": True, + } + if self.api_key: + kwargs["api_key"] = self.api_key + if self.base_url: + kwargs["base_url"] = self.base_url + + resp = self._litellm.completion(**kwargs) + text = (resp.choices[0].message.content or "").strip() + + try: + if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel): + try: + data = json.loads(text) + if isinstance(data, dict): + return response_format(**data) + if isinstance(data, list): + payload: Dict[str, Any] = {} + if hasattr(response_format, "model_fields") and "patch" in response_format.model_fields: + payload["patch"] = data + elif hasattr(response_format, "__fields__") and "patch" in getattr(response_format, "__fields__"): + payload["patch"] = data + if payload: + return response_format(**payload) + except Exception: + pass + + if getattr(response_format, "__name__", "") == "AnswerVerification": + analysis = "" + tf = False + m = re.search(r"\s*(.*?)\s*", text, re.DOTALL) + if m: + analysis = m.group(1).strip() + m2 = re.search(r"\s*(.*?)\s*", text, re.DOTALL) + if m2: + val = m2.group(1).strip().lower() + tf = val in ("true", "1", "yes") + if not analysis: + analysis = text + return response_format(analysis=analysis, true_false=tf) + + payload: Dict[str, Any] = {} + for field in ("analysis", "text"): + if (hasattr(response_format, "model_fields") and field in response_format.model_fields) or ( + hasattr(response_format, "__fields__") and field in getattr(response_format, "__fields__") + ): + payload[field] = text + if payload: + return response_format(**payload) + except Exception: + pass + + return text diff --git a/pyproject.toml b/pyproject.toml index 84babdbf..68bf5267 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ dependencies = [ test = [ "pytest", "yapf" ] +litellm = [ + "litellm>=1.80.0,<1.87.0" +] # URLs [project.urls]