From 2461d81a7a7ac1c973cd49350dadfeb54648302f Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Wed, 29 Apr 2026 17:49:51 +0530 Subject: [PATCH 1/4] feat: add LiteLLM as AI gateway provider --- api/config.py | 9 +- api/config/generator.json | 15 ++ api/litellm_client.py | 234 ++++++++++++++++++++++++++++ api/pyproject.toml | 1 + tests/unit/test_litellm_client.py | 251 ++++++++++++++++++++++++++++++ 5 files changed, 507 insertions(+), 3 deletions(-) create mode 100644 api/litellm_client.py create mode 100644 tests/unit/test_litellm_client.py diff --git a/api/config.py b/api/config.py index 49dfcf7b0..2ca7b7007 100644 --- a/api/config.py +++ b/api/config.py @@ -13,6 +13,7 @@ from api.google_embedder_client import GoogleEmbedderClient from api.azureai_client import AzureAIClient from api.dashscope_client import DashscopeClient +from api.litellm_client import LiteLLMClient from adalflow import GoogleGenAIClient, OllamaClient # Get API keys from environment variables @@ -63,7 +64,8 @@ "OllamaClient": OllamaClient, "BedrockClient": BedrockClient, "AzureAIClient": AzureAIClient, - "DashscopeClient": DashscopeClient + "DashscopeClient": DashscopeClient, + "LiteLLMClient": LiteLLMClient, } def replace_env_placeholders(config: Union[Dict[str, Any], List[Any], str, Any]) -> Union[Dict[str, Any], List[Any], str, Any]: @@ -131,7 +133,7 @@ def load_generator_config(): if provider_config.get("client_class") in CLIENT_CLASSES: provider_config["model_client"] = CLIENT_CLASSES[provider_config["client_class"]] # Fall back to default mapping based on provider_id - elif provider_id in ["google", "openai", "openrouter", "ollama", "bedrock", "azure", "dashscope"]: + elif provider_id in ["google", "openai", "openrouter", "ollama", "bedrock", "azure", "dashscope", "litellm"]: default_map = { "google": GoogleGenAIClient, "openai": OpenAIClient, @@ -139,7 +141,8 @@ def load_generator_config(): "ollama": OllamaClient, "bedrock": BedrockClient, "azure": AzureAIClient, - "dashscope": DashscopeClient + "dashscope": DashscopeClient, + "litellm": LiteLLMClient, } provider_config["model_client"] = default_map[provider_id] else: diff --git a/api/config/generator.json b/api/config/generator.json index f88179098..0e6b774e8 100644 --- a/api/config/generator.json +++ b/api/config/generator.json @@ -193,6 +193,21 @@ "top_p": 0.8 } } + }, + "litellm": { + "client_class": "LiteLLMClient", + "default_model": "openai/gpt-4o", + "supportsCustomModel": true, + "models": { + "openai/gpt-4o": { + "temperature": 0.7, + "top_p": 0.8 + }, + "anthropic/claude-sonnet-4-20250514": { + "temperature": 0.7, + "top_p": 0.8 + } + } } } } diff --git a/api/litellm_client.py b/api/litellm_client.py new file mode 100644 index 000000000..83e1f38ea --- /dev/null +++ b/api/litellm_client.py @@ -0,0 +1,234 @@ +"""LiteLLM ModelClient integration. + +Routes to 100+ LLM providers via litellm.completion(). +Provider API keys are read from environment variables automatically +(OPENAI_API_KEY, ANTHROPIC_API_KEY, AWS_ACCESS_KEY_ID, GEMINI_API_KEY, etc.). + +Model names use LiteLLM format: "provider/model-name", e.g.: + anthropic/claude-sonnet-4-20250514, openai/gpt-4o, + bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0 + +See https://docs.litellm.ai/docs/providers for the full list. 
+""" + +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + TypeVar, + Union, +) + +import backoff + +from adalflow.core.model_client import ModelClient +from adalflow.core.types import ( + CompletionUsage, + EmbedderOutput, + GeneratorOutput, + ModelType, +) +from adalflow.components.model_client.utils import parse_embedding_response + +log = logging.getLogger(__name__) +T = TypeVar("T") + + +def get_first_message_content(completion) -> str: + return completion.choices[0].message.content + + +class LiteLLMClient(ModelClient): + __doc__ = r"""A component wrapper for the LiteLLM AI gateway. + + LiteLLM routes to 100+ LLM providers (OpenAI, Anthropic, Google, AWS Bedrock, + Azure, Ollama, etc.) through a single unified interface. Provider API keys are + read from environment variables automatically. + + Model names use LiteLLM format: ``provider/model-name``. + + Example: + ```python + from api.litellm_client import LiteLLMClient + import adalflow as adal + + client = LiteLLMClient() + generator = adal.Generator( + model_client=client, + model_kwargs={"model": "anthropic/claude-sonnet-4-20250514"} + ) + response = generator(prompt_kwargs={"input_str": "What is LLM?"}) + ``` + + Args: + api_key (Optional[str]): API key for the provider. If not provided, + LiteLLM reads from the provider's standard env var (e.g. ANTHROPIC_API_KEY). + base_url (Optional[str]): Custom API base URL (e.g. for LiteLLM proxy server). + chat_completion_parser: A function to parse the chat completion response. + Defaults to extracting the first message content. + """ + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + chat_completion_parser: Callable = None, + ): + super().__init__() + self._api_key = api_key + self._base_url = base_url + self.chat_completion_parser = ( + chat_completion_parser or get_first_message_content + ) + self.sync_client = self.init_sync_client() + self.async_client = None + + def init_sync_client(self): + return {"api_key": self._api_key, "base_url": self._base_url} + + def init_async_client(self): + return {"api_key": self._api_key, "base_url": self._base_url} + + def convert_inputs_to_api_kwargs( + self, + input: Optional[Any] = None, + model_kwargs: Dict = {}, + model_type: ModelType = ModelType.UNDEFINED, + ) -> Dict: + final_model_kwargs = model_kwargs.copy() + + if model_type == ModelType.EMBEDDER: + if isinstance(input, str): + input = [input] + if not isinstance(input, Sequence): + raise TypeError("input must be a sequence of text") + final_model_kwargs["input"] = input + elif model_type == ModelType.LLM: + messages: List[Dict[str, str]] = [] + if isinstance(input, str): + messages.append({"role": "user", "content": input}) + elif isinstance(input, list) and all(isinstance(m, dict) for m in input): + messages = input + else: + messages.append({"role": "user", "content": str(input)}) + final_model_kwargs["messages"] = messages + else: + raise ValueError(f"model_type {model_type} is not supported") + + return final_model_kwargs + + def parse_chat_completion(self, completion) -> GeneratorOutput: + try: + data = self.chat_completion_parser(completion) + except Exception as e: + log.error(f"Error parsing the completion: {e}") + return GeneratorOutput(data=None, error=str(e), raw_response=completion) + + try: + usage = self.track_completion_usage(completion) + return GeneratorOutput( + data=None, error=None, raw_response=data, usage=usage + ) + except Exception as e: + log.error(f"Error tracking the 
completion usage: {e}") + return GeneratorOutput(data=None, error=str(e), raw_response=data) + + def track_completion_usage(self, completion) -> CompletionUsage: + try: + return CompletionUsage( + completion_tokens=completion.usage.completion_tokens, + prompt_tokens=completion.usage.prompt_tokens, + total_tokens=completion.usage.total_tokens, + ) + except Exception as e: + log.error(f"Error tracking the completion usage: {e}") + return CompletionUsage( + completion_tokens=None, prompt_tokens=None, total_tokens=None + ) + + def parse_embedding_response(self, response) -> EmbedderOutput: + try: + return parse_embedding_response(response) + except Exception as e: + log.error(f"Error parsing the embedding response: {e}") + return EmbedderOutput(data=[], error=str(e), raw_response=response) + + @backoff.on_exception(backoff.expo, Exception, max_time=5) + def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): + import litellm + + log.info(f"api_kwargs: {api_kwargs}") + + extra = {} + if self._api_key: + extra["api_key"] = self._api_key + if self._base_url: + extra["api_base"] = self._base_url + + if model_type == ModelType.EMBEDDER: + return litellm.embedding( + drop_params=True, + **api_kwargs, + **extra, + ) + elif model_type == ModelType.LLM: + if api_kwargs.get("stream", False): + self.chat_completion_parser = _handle_streaming_response + return litellm.completion( + drop_params=True, + **api_kwargs, + **extra, + ) + else: + return litellm.completion( + drop_params=True, + **api_kwargs, + **extra, + ) + else: + raise ValueError(f"model_type {model_type} is not supported") + + async def acall( + self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED + ): + import litellm + + extra = {} + if self._api_key: + extra["api_key"] = self._api_key + if self._base_url: + extra["api_base"] = self._base_url + + if model_type == ModelType.EMBEDDER: + return await litellm.aembedding( + drop_params=True, + **api_kwargs, + **extra, + ) + elif model_type == ModelType.LLM: + return await litellm.acompletion( + drop_params=True, + **api_kwargs, + **extra, + ) + else: + raise ValueError(f"model_type {model_type} is not supported") + + @classmethod + def from_dict(cls, data: Dict[str, Any]): + return cls(**data) + + def to_dict(self) -> Dict[str, Any]: + exclude = ["sync_client", "async_client"] + output = super().to_dict(exclude=exclude) + return output + + +def _handle_streaming_response(generator): + for completion in generator: + if completion.choices and completion.choices[0].delta.content: + yield completion.choices[0].delta.content diff --git a/api/pyproject.toml b/api/pyproject.toml index 09760f8b1..3f8328db5 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -28,6 +28,7 @@ boto3 = ">=1.34.0" websockets = ">=11.0.3" azure-identity = ">=1.12.0" azure-core = ">=1.24.0" +litellm = {version = ">=1.60.0,<2.0", optional = true} [build-system] diff --git a/tests/unit/test_litellm_client.py b/tests/unit/test_litellm_client.py new file mode 100644 index 000000000..7dc2732a4 --- /dev/null +++ b/tests/unit/test_litellm_client.py @@ -0,0 +1,251 @@ +"""Tests for the LiteLLM ModelClient integration.""" + +import sys +import types +from unittest import mock + +import pytest + +sys.path.insert(0, ".") +from api.litellm_client import LiteLLMClient +from adalflow.core.types import ModelType, CompletionUsage + + +def _make_mock_response(content="OK", prompt_tokens=10, completion_tokens=5): + """Build a mock LiteLLM ModelResponse.""" + msg = mock.MagicMock() + 
msg.content = content + msg.role = "assistant" + + choice = mock.MagicMock() + choice.message = msg + choice.index = 0 + choice.finish_reason = "stop" + + usage = mock.MagicMock() + usage.prompt_tokens = prompt_tokens + usage.completion_tokens = completion_tokens + usage.total_tokens = prompt_tokens + completion_tokens + + resp = mock.MagicMock() + resp.choices = [choice] + resp.usage = usage + return resp + + +def _make_mock_embedding_response(dims=1536): + """Build a mock LiteLLM EmbeddingResponse.""" + emb_data = mock.MagicMock() + emb_data.embedding = [0.1] * dims + emb_data.index = 0 + + resp = mock.MagicMock() + resp.data = [emb_data] + return resp + + +class TestLiteLLMClientInit: + def test_default_init(self): + client = LiteLLMClient() + assert client._api_key is None + assert client._base_url is None + assert client.sync_client is not None + + def test_init_with_params(self): + client = LiteLLMClient(api_key="test-key", base_url="https://proxy.example.com") + assert client._api_key == "test-key" + assert client._base_url == "https://proxy.example.com" + + +class TestConvertInputs: + def test_llm_string_input(self): + client = LiteLLMClient() + kwargs = client.convert_inputs_to_api_kwargs( + input="Hello", + model_kwargs={"model": "openai/gpt-4o"}, + model_type=ModelType.LLM, + ) + assert kwargs["messages"] == [{"role": "user", "content": "Hello"}] + assert kwargs["model"] == "openai/gpt-4o" + + def test_llm_message_list_input(self): + client = LiteLLMClient() + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + ] + kwargs = client.convert_inputs_to_api_kwargs( + input=msgs, + model_kwargs={"model": "anthropic/claude-sonnet-4-20250514"}, + model_type=ModelType.LLM, + ) + assert kwargs["messages"] == msgs + assert kwargs["model"] == "anthropic/claude-sonnet-4-20250514" + + def test_embedder_string_input(self): + client = LiteLLMClient() + kwargs = client.convert_inputs_to_api_kwargs( + input="hello", + model_kwargs={"model": "text-embedding-3-small"}, + model_type=ModelType.EMBEDDER, + ) + assert kwargs["input"] == ["hello"] + + def test_embedder_list_input(self): + client = LiteLLMClient() + kwargs = client.convert_inputs_to_api_kwargs( + input=["hello", "world"], + model_kwargs={"model": "text-embedding-3-small"}, + model_type=ModelType.EMBEDDER, + ) + assert kwargs["input"] == ["hello", "world"] + + def test_unsupported_model_type(self): + client = LiteLLMClient() + with pytest.raises(ValueError, match="not supported"): + client.convert_inputs_to_api_kwargs( + input="x", model_kwargs={}, model_type=ModelType.IMAGE_GENERATION + ) + + +class TestCallMocked: + def test_completion_dispatches_correctly(self): + client = LiteLLMClient() + mock_resp = _make_mock_response("test response") + + fake_litellm = types.ModuleType("litellm") + fake_litellm.completion = mock.MagicMock(return_value=mock_resp) + fake_litellm.embedding = mock.MagicMock() + fake_litellm.acompletion = mock.AsyncMock() + fake_litellm.aembedding = mock.AsyncMock() + + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + kwargs = { + "model": "openai/gpt-4o", + "messages": [{"role": "user", "content": "hi"}], + } + result = client.call(api_kwargs=kwargs, model_type=ModelType.LLM) + + fake_litellm.completion.assert_called_once() + call_kwargs = fake_litellm.completion.call_args + assert call_kwargs.kwargs["drop_params"] is True + assert call_kwargs.kwargs["model"] == "openai/gpt-4o" + assert result.choices[0].message.content == "test response" + + def 
test_embedding_dispatches_correctly(self): + client = LiteLLMClient() + mock_resp = _make_mock_embedding_response() + + fake_litellm = types.ModuleType("litellm") + fake_litellm.completion = mock.MagicMock() + fake_litellm.embedding = mock.MagicMock(return_value=mock_resp) + + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + kwargs = { + "model": "text-embedding-3-small", + "input": ["hello"], + } + result = client.call(api_kwargs=kwargs, model_type=ModelType.EMBEDDER) + + fake_litellm.embedding.assert_called_once() + emb_call_kwargs = fake_litellm.embedding.call_args + assert emb_call_kwargs.kwargs["drop_params"] is True + + def test_api_key_forwarded_when_set(self): + client = LiteLLMClient(api_key="sk-test123") + mock_resp = _make_mock_response() + + fake_litellm = types.ModuleType("litellm") + fake_litellm.completion = mock.MagicMock(return_value=mock_resp) + + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + kwargs = { + "model": "openai/gpt-4o", + "messages": [{"role": "user", "content": "hi"}], + } + client.call(api_kwargs=kwargs, model_type=ModelType.LLM) + call_kwargs = fake_litellm.completion.call_args + assert call_kwargs.kwargs["api_key"] == "sk-test123" + + def test_api_key_omitted_when_blank(self): + client = LiteLLMClient() + mock_resp = _make_mock_response() + + fake_litellm = types.ModuleType("litellm") + fake_litellm.completion = mock.MagicMock(return_value=mock_resp) + + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + kwargs = { + "model": "openai/gpt-4o", + "messages": [{"role": "user", "content": "hi"}], + } + client.call(api_kwargs=kwargs, model_type=ModelType.LLM) + call_kwargs = fake_litellm.completion.call_args + assert "api_key" not in call_kwargs.kwargs + + def test_base_url_forwarded_when_set(self): + client = LiteLLMClient(base_url="https://proxy.local") + mock_resp = _make_mock_response() + + fake_litellm = types.ModuleType("litellm") + fake_litellm.completion = mock.MagicMock(return_value=mock_resp) + + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + kwargs = { + "model": "openai/gpt-4o", + "messages": [{"role": "user", "content": "hi"}], + } + client.call(api_kwargs=kwargs, model_type=ModelType.LLM) + call_kwargs = fake_litellm.completion.call_args + assert call_kwargs.kwargs["api_base"] == "https://proxy.local" + + +class TestParseCompletion: + def test_parse_chat_completion(self): + client = LiteLLMClient() + mock_resp = _make_mock_response("Hello world", 10, 5) + + output = client.parse_chat_completion(mock_resp) + assert output.raw_response == "Hello world" + assert output.usage.completion_tokens == 5 + assert output.usage.prompt_tokens == 10 + + def test_track_usage(self): + client = LiteLLMClient() + mock_resp = _make_mock_response("x", 20, 30) + + usage = client.track_completion_usage(mock_resp) + assert usage.prompt_tokens == 20 + assert usage.completion_tokens == 30 + assert usage.total_tokens == 50 + + +class TestSerialization: + def test_from_dict(self): + client = LiteLLMClient.from_dict({"api_key": "test", "base_url": "http://x"}) + assert client._api_key == "test" + assert client._base_url == "http://x" + + def test_to_dict_excludes_clients(self): + client = LiteLLMClient() + d = client.to_dict() + assert "sync_client" not in str(d) + + +class TestConfigRegistration: + def test_litellm_in_client_classes(self): + pytest.importorskip("boto3") + from api.config import CLIENT_CLASSES + + assert "LiteLLMClient" in CLIENT_CLASSES + assert CLIENT_CLASSES["LiteLLMClient"] is LiteLLMClient + 
+ def test_litellm_provider_in_generator_config(self): + import json + from pathlib import Path + + config_path = Path("api/config/generator.json") + config = json.loads(config_path.read_text()) + assert "litellm" in config["providers"] + assert config["providers"]["litellm"]["client_class"] == "LiteLLMClient" + assert config["providers"]["litellm"]["supportsCustomModel"] is True From fcf1fa8d609c6df94e62f1dc8782bb9b3cdc75cb Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Wed, 29 Apr 2026 19:55:19 +0530 Subject: [PATCH 2/4] fix: wire LiteLLM into serving paths, fix retry logic, add async/streaming tests --- api/litellm_client.py | 115 ++++++++------ api/simple_chat.py | 57 +++++++ api/websocket_wiki.py | 61 ++++++- tests/unit/test_litellm_client.py | 256 +++++++++++++++++++++++++----- 4 files changed, 399 insertions(+), 90 deletions(-) diff --git a/api/litellm_client.py b/api/litellm_client.py index 83e1f38ea..586ac5203 100644 --- a/api/litellm_client.py +++ b/api/litellm_client.py @@ -12,6 +12,7 @@ """ import logging +import re from typing import ( Any, Callable, @@ -38,10 +39,32 @@ T = TypeVar("T") +def _is_retryable(exc: BaseException) -> bool: + qualname = f"{type(exc).__module__}.{type(exc).__name__}" + return qualname in { + "litellm.exceptions.RateLimitError", + "litellm.exceptions.ServiceUnavailableError", + "litellm.exceptions.Timeout", + "litellm.exceptions.APIConnectionError", + "litellm.exceptions.InternalServerError", + } + + def get_first_message_content(completion) -> str: return completion.choices[0].message.content +def handle_streaming_response(generator): + for completion in generator: + choices = getattr(completion, "choices", []) + if choices: + delta = getattr(choices[0], "delta", None) + if delta is not None: + text = getattr(delta, "content", None) + if text is not None: + yield text + + class LiteLLMClient(ModelClient): __doc__ = r"""A component wrapper for the LiteLLM AI gateway. @@ -70,6 +93,8 @@ class LiteLLMClient(ModelClient): base_url (Optional[str]): Custom API base URL (e.g. for LiteLLM proxy server). chat_completion_parser: A function to parse the chat completion response. Defaults to extracting the first message content. + input_type: How the input prompt is formatted. Use "messages" when the + adalflow Generator sends tagged system/user prompts. 
""" def __init__( @@ -77,10 +102,12 @@ def __init__( api_key: Optional[str] = None, base_url: Optional[str] = None, chat_completion_parser: Callable = None, + input_type: str = "text", ): super().__init__() self._api_key = api_key self._base_url = base_url + self._input_type = input_type self.chat_completion_parser = ( chat_completion_parser or get_first_message_content ) @@ -96,10 +123,10 @@ def init_async_client(self): def convert_inputs_to_api_kwargs( self, input: Optional[Any] = None, - model_kwargs: Dict = {}, + model_kwargs: Optional[Dict] = None, model_type: ModelType = ModelType.UNDEFINED, ) -> Dict: - final_model_kwargs = model_kwargs.copy() + final_model_kwargs = (model_kwargs or {}).copy() if model_type == ModelType.EMBEDDER: if isinstance(input, str): @@ -109,12 +136,30 @@ def convert_inputs_to_api_kwargs( final_model_kwargs["input"] = input elif model_type == ModelType.LLM: messages: List[Dict[str, str]] = [] - if isinstance(input, str): - messages.append({"role": "user", "content": input}) - elif isinstance(input, list) and all(isinstance(m, dict) for m in input): - messages = input - else: - messages.append({"role": "user", "content": str(input)}) + + if self._input_type == "messages" and isinstance(input, str): + system_start_tag = "" + system_end_tag = "" + user_start_tag = "" + user_end_tag = "" + + pattern = ( + rf"{system_start_tag}\s*(.*?)\s*{system_end_tag}\s*" + rf"{user_start_tag}\s*(.*?)\s*{user_end_tag}" + ) + match = re.compile(pattern, re.DOTALL).match(input) + if match: + messages.append({"role": "system", "content": match.group(1)}) + messages.append({"role": "user", "content": match.group(2)}) + + if not messages: + if isinstance(input, str): + messages.append({"role": "user", "content": input}) + elif isinstance(input, list) and all(isinstance(m, dict) for m in input): + messages = input + else: + messages.append({"role": "user", "content": str(input)}) + final_model_kwargs["messages"] = messages else: raise ValueError(f"model_type {model_type} is not supported") @@ -157,64 +202,44 @@ def parse_embedding_response(self, response) -> EmbedderOutput: log.error(f"Error parsing the embedding response: {e}") return EmbedderOutput(data=[], error=str(e), raw_response=response) - @backoff.on_exception(backoff.expo, Exception, max_time=5) - def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): + @backoff.on_exception(backoff.expo, Exception, max_time=5, giveup=lambda e: not _is_retryable(e)) + def call(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = ModelType.UNDEFINED): import litellm + api_kwargs = api_kwargs or {} log.info(f"api_kwargs: {api_kwargs}") - extra = {} + extra: Dict[str, Any] = {} if self._api_key: extra["api_key"] = self._api_key if self._base_url: extra["api_base"] = self._base_url if model_type == ModelType.EMBEDDER: - return litellm.embedding( - drop_params=True, - **api_kwargs, - **extra, - ) + return litellm.embedding(drop_params=True, **api_kwargs, **extra) elif model_type == ModelType.LLM: if api_kwargs.get("stream", False): - self.chat_completion_parser = _handle_streaming_response - return litellm.completion( - drop_params=True, - **api_kwargs, - **extra, - ) - else: - return litellm.completion( - drop_params=True, - **api_kwargs, - **extra, - ) + self.chat_completion_parser = handle_streaming_response + return litellm.completion(drop_params=True, **api_kwargs, **extra) else: raise ValueError(f"model_type {model_type} is not supported") - async def acall( - self, api_kwargs: Dict = {}, model_type: 
ModelType = ModelType.UNDEFINED - ): + @backoff.on_exception(backoff.expo, Exception, max_time=5, giveup=lambda e: not _is_retryable(e)) + async def acall(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = ModelType.UNDEFINED): import litellm - extra = {} + api_kwargs = api_kwargs or {} + + extra: Dict[str, Any] = {} if self._api_key: extra["api_key"] = self._api_key if self._base_url: extra["api_base"] = self._base_url if model_type == ModelType.EMBEDDER: - return await litellm.aembedding( - drop_params=True, - **api_kwargs, - **extra, - ) + return await litellm.aembedding(drop_params=True, **api_kwargs, **extra) elif model_type == ModelType.LLM: - return await litellm.acompletion( - drop_params=True, - **api_kwargs, - **extra, - ) + return await litellm.acompletion(drop_params=True, **api_kwargs, **extra) else: raise ValueError(f"model_type {model_type} is not supported") @@ -226,9 +251,3 @@ def to_dict(self) -> Dict[str, Any]: exclude = ["sync_client", "async_client"] output = super().to_dict(exclude=exclude) return output - - -def _handle_streaming_response(generator): - for completion in generator: - if completion.choices and completion.choices[0].delta.content: - yield completion.choices[0].delta.content diff --git a/api/simple_chat.py b/api/simple_chat.py index 41a184ed8..dbe00ac97 100644 --- a/api/simple_chat.py +++ b/api/simple_chat.py @@ -18,6 +18,7 @@ from api.bedrock_client import BedrockClient from api.azureai_client import AzureAIClient from api.dashscope_client import DashscopeClient +from api.litellm_client import LiteLLMClient from api.rag import RAG from api.prompts import ( DEEP_RESEARCH_FIRST_ITERATION_PROMPT, @@ -444,6 +445,23 @@ async def chat_completions_stream(request: ChatCompletionRequest): "top_p": model_config["top_p"], } + api_kwargs = model.convert_inputs_to_api_kwargs( + input=prompt, + model_kwargs=model_kwargs, + model_type=ModelType.LLM, + ) + elif request.provider == "litellm": + logger.info(f"Using LiteLLM with model: {request.model}") + + model = LiteLLMClient() + model_kwargs = { + "model": request.model, + "stream": True, + "temperature": model_config.get("temperature", 0.7), + } + if "top_p" in model_config: + model_kwargs["top_p"] = model_config["top_p"] + api_kwargs = model.convert_inputs_to_api_kwargs( input=prompt, model_kwargs=model_kwargs, @@ -549,6 +567,25 @@ async def response_stream(): "Please check that you have set the DASHSCOPE_API_KEY (and optionally " "DASHSCOPE_WORKSPACE_ID) environment variables with valid values." ) + elif request.provider == "litellm": + try: + logger.info("Making LiteLLM API call") + response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM) + async for chunk in response: + choices = getattr(chunk, "choices", []) + if len(choices) > 0: + delta = getattr(choices[0], "delta", None) + if delta is not None: + text = getattr(delta, "content", None) + if text is not None: + yield text + except Exception as e_litellm: + logger.error(f"Error with LiteLLM API: {str(e_litellm)}") + yield ( + f"\nError with LiteLLM API: {str(e_litellm)}\n\n" + "Please check that your provider API keys are set as environment variables " + "and that the model name uses LiteLLM format (provider/model-name)." + ) else: # Google Generative AI (default provider) response = model.generate_content(prompt, stream=True) @@ -710,6 +747,26 @@ async def response_stream(): "Please check that you have set the DASHSCOPE_API_KEY (and optionally " "DASHSCOPE_WORKSPACE_ID) environment variables with valid values." 
) + elif request.provider == "litellm": + try: + fallback_api_kwargs = model.convert_inputs_to_api_kwargs( + input=simplified_prompt, + model_kwargs=model_kwargs, + model_type=ModelType.LLM, + ) + logger.info("Making fallback LiteLLM API call") + fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM) + async for chunk in fallback_response: + choices = getattr(chunk, "choices", []) + if len(choices) > 0: + delta = getattr(choices[0], "delta", None) + if delta is not None: + text = getattr(delta, "content", None) + if text is not None: + yield text + except Exception as e_fallback: + logger.error(f"Error with LiteLLM API fallback: {str(e_fallback)}") + yield f"\nError with LiteLLM API fallback: {str(e_fallback)}" else: # Google Generative AI fallback (default provider) model_config = get_model_config(request.provider, request.model) diff --git a/api/websocket_wiki.py b/api/websocket_wiki.py index d1a6b1bd5..09d0a918d 100644 --- a/api/websocket_wiki.py +++ b/api/websocket_wiki.py @@ -23,6 +23,7 @@ from api.openrouter_client import OpenRouterClient from api.azureai_client import AzureAIClient from api.dashscope_client import DashscopeClient +from api.litellm_client import LiteLLMClient from api.rag import RAG # Configure logging @@ -560,6 +561,23 @@ async def handle_websocket_chat(websocket: WebSocket): model_kwargs=model_kwargs, model_type=ModelType.LLM ) + elif request.provider == "litellm": + logger.info(f"Using LiteLLM with model: {request.model}") + + model = LiteLLMClient() + model_kwargs = { + "model": request.model, + "stream": True, + "temperature": model_config.get("temperature", 0.7), + } + if "top_p" in model_config: + model_kwargs["top_p"] = model_config["top_p"] + + api_kwargs = model.convert_inputs_to_api_kwargs( + input=prompt, + model_kwargs=model_kwargs, + model_type=ModelType.LLM, + ) else: # Initialize Google Generative AI model model = genai.GenerativeModel( @@ -702,7 +720,28 @@ async def handle_websocket_chat(websocket: WebSocket): "DASHSCOPE_WORKSPACE_ID) environment variables with valid values." ) await websocket.send_text(error_msg) - # Close the WebSocket connection after sending the error message + await websocket.close() + elif request.provider == "litellm": + try: + logger.info("Making LiteLLM API call") + response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM) + async for chunk in response: + choices = getattr(chunk, "choices", []) + if len(choices) > 0: + delta = getattr(choices[0], "delta", None) + if delta is not None: + text = getattr(delta, "content", None) + if text is not None: + await websocket.send_text(text) + await websocket.close() + except Exception as e_litellm: + logger.error(f"Error with LiteLLM API: {str(e_litellm)}") + error_msg = ( + f"\nError with LiteLLM API: {str(e_litellm)}\n\n" + "Please check that your provider API keys are set as environment variables " + "and that the model name uses LiteLLM format (provider/model-name)." + ) + await websocket.send_text(error_msg) await websocket.close() else: # Google Generative AI (default provider) @@ -875,6 +914,26 @@ async def handle_websocket_chat(websocket: WebSocket): "DASHSCOPE_WORKSPACE_ID) environment variables with valid values." 
) await websocket.send_text(error_msg) + elif request.provider == "litellm": + try: + fallback_api_kwargs = model.convert_inputs_to_api_kwargs( + input=simplified_prompt, + model_kwargs=model_kwargs, + model_type=ModelType.LLM, + ) + logger.info("Making fallback LiteLLM API call") + fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM) + async for chunk in fallback_response: + choices = getattr(chunk, "choices", []) + if len(choices) > 0: + delta = getattr(choices[0], "delta", None) + if delta is not None: + text = getattr(delta, "content", None) + if text is not None: + await websocket.send_text(text) + except Exception as e_fallback: + logger.error(f"Error with LiteLLM API fallback: {str(e_fallback)}") + await websocket.send_text(f"\nError with LiteLLM API fallback: {str(e_fallback)}") else: # Google Generative AI fallback (default provider) model_config = get_model_config(request.provider, request.model) diff --git a/tests/unit/test_litellm_client.py b/tests/unit/test_litellm_client.py index 7dc2732a4..eae4672e8 100644 --- a/tests/unit/test_litellm_client.py +++ b/tests/unit/test_litellm_client.py @@ -1,13 +1,16 @@ """Tests for the LiteLLM ModelClient integration.""" +import asyncio +import json import sys import types +from pathlib import Path from unittest import mock import pytest sys.path.insert(0, ".") -from api.litellm_client import LiteLLMClient +from api.litellm_client import LiteLLMClient, _is_retryable, handle_streaming_response from adalflow.core.types import ModelType, CompletionUsage @@ -44,11 +47,36 @@ def _make_mock_embedding_response(dims=1536): return resp +def _make_mock_stream_chunks(text="Hello"): + """Build mock streaming chunks.""" + chunks = [] + for char in text: + delta = mock.MagicMock() + delta.content = char + choice = mock.MagicMock() + choice.delta = delta + choice.finish_reason = None + chunk = mock.MagicMock() + chunk.choices = [choice] + chunks.append(chunk) + + delta_final = mock.MagicMock() + delta_final.content = None + choice_final = mock.MagicMock() + choice_final.delta = delta_final + choice_final.finish_reason = "stop" + chunk_final = mock.MagicMock() + chunk_final.choices = [choice_final] + chunks.append(chunk_final) + return chunks + + class TestLiteLLMClientInit: def test_default_init(self): client = LiteLLMClient() assert client._api_key is None assert client._base_url is None + assert client._input_type == "text" assert client.sync_client is not None def test_init_with_params(self): @@ -56,6 +84,10 @@ def test_init_with_params(self): assert client._api_key == "test-key" assert client._base_url == "https://proxy.example.com" + def test_init_with_messages_input_type(self): + client = LiteLLMClient(input_type="messages") + assert client._input_type == "messages" + class TestConvertInputs: def test_llm_string_input(self): @@ -80,7 +112,32 @@ def test_llm_message_list_input(self): model_type=ModelType.LLM, ) assert kwargs["messages"] == msgs - assert kwargs["model"] == "anthropic/claude-sonnet-4-20250514" + + def test_llm_tagged_messages_input(self): + client = LiteLLMClient(input_type="messages") + tagged = ( + "\nYou are helpful.\n\n" + "\nWhat is 2+2?\n" + ) + kwargs = client.convert_inputs_to_api_kwargs( + input=tagged, + model_kwargs={"model": "openai/gpt-4o"}, + model_type=ModelType.LLM, + ) + assert len(kwargs["messages"]) == 2 + assert kwargs["messages"][0]["role"] == "system" + assert "helpful" in kwargs["messages"][0]["content"] + assert kwargs["messages"][1]["role"] == "user" + assert "2+2" in 
kwargs["messages"][1]["content"] + + def test_llm_tagged_messages_no_match_falls_back(self): + client = LiteLLMClient(input_type="messages") + kwargs = client.convert_inputs_to_api_kwargs( + input="plain text no tags", + model_kwargs={"model": "openai/gpt-4o"}, + model_type=ModelType.LLM, + ) + assert kwargs["messages"] == [{"role": "user", "content": "plain text no tags"}] def test_embedder_string_input(self): client = LiteLLMClient() @@ -107,6 +164,13 @@ def test_unsupported_model_type(self): input="x", model_kwargs={}, model_type=ModelType.IMAGE_GENERATION ) + def test_none_model_kwargs_handled(self): + client = LiteLLMClient() + kwargs = client.convert_inputs_to_api_kwargs( + input="hello", model_kwargs=None, model_type=ModelType.LLM + ) + assert kwargs["messages"] == [{"role": "user", "content": "hello"}] + class TestCallMocked: def test_completion_dispatches_correctly(self): @@ -116,8 +180,6 @@ def test_completion_dispatches_correctly(self): fake_litellm = types.ModuleType("litellm") fake_litellm.completion = mock.MagicMock(return_value=mock_resp) fake_litellm.embedding = mock.MagicMock() - fake_litellm.acompletion = mock.AsyncMock() - fake_litellm.aembedding = mock.AsyncMock() with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): kwargs = { @@ -137,74 +199,161 @@ def test_embedding_dispatches_correctly(self): mock_resp = _make_mock_embedding_response() fake_litellm = types.ModuleType("litellm") - fake_litellm.completion = mock.MagicMock() fake_litellm.embedding = mock.MagicMock(return_value=mock_resp) with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): - kwargs = { - "model": "text-embedding-3-small", - "input": ["hello"], - } + kwargs = {"model": "text-embedding-3-small", "input": ["hello"]} result = client.call(api_kwargs=kwargs, model_type=ModelType.EMBEDDER) fake_litellm.embedding.assert_called_once() - emb_call_kwargs = fake_litellm.embedding.call_args - assert emb_call_kwargs.kwargs["drop_params"] is True + assert fake_litellm.embedding.call_args.kwargs["drop_params"] is True def test_api_key_forwarded_when_set(self): client = LiteLLMClient(api_key="sk-test123") mock_resp = _make_mock_response() - fake_litellm = types.ModuleType("litellm") fake_litellm.completion = mock.MagicMock(return_value=mock_resp) with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): - kwargs = { - "model": "openai/gpt-4o", - "messages": [{"role": "user", "content": "hi"}], - } - client.call(api_kwargs=kwargs, model_type=ModelType.LLM) - call_kwargs = fake_litellm.completion.call_args - assert call_kwargs.kwargs["api_key"] == "sk-test123" + client.call( + api_kwargs={"model": "x", "messages": [{"role": "user", "content": "hi"}]}, + model_type=ModelType.LLM, + ) + assert fake_litellm.completion.call_args.kwargs["api_key"] == "sk-test123" def test_api_key_omitted_when_blank(self): client = LiteLLMClient() mock_resp = _make_mock_response() - fake_litellm = types.ModuleType("litellm") fake_litellm.completion = mock.MagicMock(return_value=mock_resp) with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): - kwargs = { - "model": "openai/gpt-4o", - "messages": [{"role": "user", "content": "hi"}], - } - client.call(api_kwargs=kwargs, model_type=ModelType.LLM) - call_kwargs = fake_litellm.completion.call_args - assert "api_key" not in call_kwargs.kwargs + client.call( + api_kwargs={"model": "x", "messages": [{"role": "user", "content": "hi"}]}, + model_type=ModelType.LLM, + ) + assert "api_key" not in fake_litellm.completion.call_args.kwargs def 
test_base_url_forwarded_when_set(self): client = LiteLLMClient(base_url="https://proxy.local") mock_resp = _make_mock_response() - fake_litellm = types.ModuleType("litellm") fake_litellm.completion = mock.MagicMock(return_value=mock_resp) with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): - kwargs = { - "model": "openai/gpt-4o", - "messages": [{"role": "user", "content": "hi"}], - } - client.call(api_kwargs=kwargs, model_type=ModelType.LLM) - call_kwargs = fake_litellm.completion.call_args - assert call_kwargs.kwargs["api_base"] == "https://proxy.local" + client.call( + api_kwargs={"model": "x", "messages": [{"role": "user", "content": "hi"}]}, + model_type=ModelType.LLM, + ) + assert fake_litellm.completion.call_args.kwargs["api_base"] == "https://proxy.local" + + def test_unsupported_model_type_in_call(self): + client = LiteLLMClient() + fake_litellm = types.ModuleType("litellm") + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + with pytest.raises(ValueError, match="not supported"): + client.call(api_kwargs={}, model_type=ModelType.IMAGE_GENERATION) + + def test_streaming_call(self): + client = LiteLLMClient() + chunks = _make_mock_stream_chunks("Hi") + fake_litellm = types.ModuleType("litellm") + fake_litellm.completion = mock.MagicMock(return_value=iter(chunks)) + + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + result = client.call( + api_kwargs={"model": "x", "messages": [{"role": "user", "content": "hi"}], "stream": True}, + model_type=ModelType.LLM, + ) + assert fake_litellm.completion.call_args.kwargs["stream"] is True + + +class TestAcallMocked: + def test_acall_completion(self): + client = LiteLLMClient() + mock_resp = _make_mock_response("async response") + fake_litellm = types.ModuleType("litellm") + fake_litellm.acompletion = mock.AsyncMock(return_value=mock_resp) + fake_litellm.aembedding = mock.AsyncMock() + + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + result = asyncio.get_event_loop().run_until_complete( + client.acall( + api_kwargs={"model": "x", "messages": [{"role": "user", "content": "hi"}]}, + model_type=ModelType.LLM, + ) + ) + fake_litellm.acompletion.assert_called_once() + assert fake_litellm.acompletion.call_args.kwargs["drop_params"] is True + assert result.choices[0].message.content == "async response" + + def test_acall_embedding(self): + client = LiteLLMClient() + mock_resp = _make_mock_embedding_response() + fake_litellm = types.ModuleType("litellm") + fake_litellm.aembedding = mock.AsyncMock(return_value=mock_resp) + + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + result = asyncio.get_event_loop().run_until_complete( + client.acall( + api_kwargs={"model": "text-embedding-3-small", "input": ["hello"]}, + model_type=ModelType.EMBEDDER, + ) + ) + fake_litellm.aembedding.assert_called_once() + + def test_acall_api_key_forwarded(self): + client = LiteLLMClient(api_key="sk-async-key") + mock_resp = _make_mock_response() + fake_litellm = types.ModuleType("litellm") + fake_litellm.acompletion = mock.AsyncMock(return_value=mock_resp) + + with mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + asyncio.get_event_loop().run_until_complete( + client.acall( + api_kwargs={"model": "x", "messages": [{"role": "user", "content": "hi"}]}, + model_type=ModelType.LLM, + ) + ) + assert fake_litellm.acompletion.call_args.kwargs["api_key"] == "sk-async-key" + + def test_acall_unsupported_model_type(self): + client = LiteLLMClient() + fake_litellm = types.ModuleType("litellm") + with 
mock.patch.dict(sys.modules, {"litellm": fake_litellm}): + with pytest.raises(ValueError, match="not supported"): + asyncio.get_event_loop().run_until_complete( + client.acall(api_kwargs={}, model_type=ModelType.IMAGE_GENERATION) + ) + + +class TestStreamingHandler: + def test_handle_streaming_response(self): + chunks = _make_mock_stream_chunks("OK") + result = list(handle_streaming_response(iter(chunks))) + assert "".join(result) == "OK" + + def test_handle_streaming_with_none_content(self): + delta = mock.MagicMock() + delta.content = None + choice = mock.MagicMock() + choice.delta = delta + chunk = mock.MagicMock() + chunk.choices = [choice] + result = list(handle_streaming_response(iter([chunk]))) + assert result == [] + + def test_handle_streaming_with_empty_choices(self): + chunk = mock.MagicMock() + chunk.choices = [] + result = list(handle_streaming_response(iter([chunk]))) + assert result == [] class TestParseCompletion: def test_parse_chat_completion(self): client = LiteLLMClient() mock_resp = _make_mock_response("Hello world", 10, 5) - output = client.parse_chat_completion(mock_resp) assert output.raw_response == "Hello world" assert output.usage.completion_tokens == 5 @@ -213,12 +362,41 @@ def test_parse_chat_completion(self): def test_track_usage(self): client = LiteLLMClient() mock_resp = _make_mock_response("x", 20, 30) - usage = client.track_completion_usage(mock_resp) assert usage.prompt_tokens == 20 assert usage.completion_tokens == 30 assert usage.total_tokens == 50 + def test_track_usage_missing_usage(self): + client = LiteLLMClient() + mock_resp = mock.MagicMock(spec=[]) + usage = client.track_completion_usage(mock_resp) + assert usage.completion_tokens is None + + def test_parse_error_in_parser(self): + client = LiteLLMClient(chat_completion_parser=lambda c: 1 / 0) + mock_resp = _make_mock_response() + output = client.parse_chat_completion(mock_resp) + assert output.error is not None + assert "division by zero" in output.error + + +class TestRetryPredicate: + def test_rate_limit_is_retryable(self): + exc = type("RateLimitError", (Exception,), {})() + exc.__class__.__module__ = "litellm.exceptions" + exc.__class__.__qualname__ = "RateLimitError" + assert _is_retryable(exc) + + def test_auth_error_is_not_retryable(self): + exc = type("AuthenticationError", (Exception,), {})() + exc.__class__.__module__ = "litellm.exceptions" + exc.__class__.__qualname__ = "AuthenticationError" + assert not _is_retryable(exc) + + def test_value_error_is_not_retryable(self): + assert not _is_retryable(ValueError("bad model")) + class TestSerialization: def test_from_dict(self): @@ -236,14 +414,10 @@ class TestConfigRegistration: def test_litellm_in_client_classes(self): pytest.importorskip("boto3") from api.config import CLIENT_CLASSES - assert "LiteLLMClient" in CLIENT_CLASSES assert CLIENT_CLASSES["LiteLLMClient"] is LiteLLMClient def test_litellm_provider_in_generator_config(self): - import json - from pathlib import Path - config_path = Path("api/config/generator.json") config = json.loads(config_path.read_text()) assert "litellm" in config["providers"] From 0e3789dad267e7cb1514dcbca4bcc1957b40488a Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Thu, 30 Apr 2026 00:35:45 +0530 Subject: [PATCH 3/4] fix: address review feedback --- api/config.py | 22 +++++++++++----------- api/litellm_client.py | 21 +++++++++------------ tests/unit/test_litellm_client.py | 3 ++- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/api/config.py b/api/config.py index 2ca7b7007..7c729ad9d 
100644 --- a/api/config.py +++ b/api/config.py @@ -127,23 +127,23 @@ def load_generator_config(): generator_config = load_json_config("generator.json") # Add client classes to each provider + default_map = { + "google": GoogleGenAIClient, + "openai": OpenAIClient, + "openrouter": OpenRouterClient, + "ollama": OllamaClient, + "bedrock": BedrockClient, + "azure": AzureAIClient, + "dashscope": DashscopeClient, + "litellm": LiteLLMClient, + } if "providers" in generator_config: for provider_id, provider_config in generator_config["providers"].items(): # Try to set client class from client_class if provider_config.get("client_class") in CLIENT_CLASSES: provider_config["model_client"] = CLIENT_CLASSES[provider_config["client_class"]] # Fall back to default mapping based on provider_id - elif provider_id in ["google", "openai", "openrouter", "ollama", "bedrock", "azure", "dashscope", "litellm"]: - default_map = { - "google": GoogleGenAIClient, - "openai": OpenAIClient, - "openrouter": OpenRouterClient, - "ollama": OllamaClient, - "bedrock": BedrockClient, - "azure": AzureAIClient, - "dashscope": DashscopeClient, - "litellm": LiteLLMClient, - } + elif provider_id in default_map: provider_config["model_client"] = default_map[provider_id] else: logger.warning(f"Unknown provider or client class: {provider_id}") diff --git a/api/litellm_client.py b/api/litellm_client.py index 586ac5203..03ca4a9dc 100644 --- a/api/litellm_client.py +++ b/api/litellm_client.py @@ -167,20 +167,19 @@ def convert_inputs_to_api_kwargs( return final_model_kwargs def parse_chat_completion(self, completion) -> GeneratorOutput: - try: - data = self.chat_completion_parser(completion) - except Exception as e: - log.error(f"Error parsing the completion: {e}") - return GeneratorOutput(data=None, error=str(e), raw_response=completion) + import types try: - usage = self.track_completion_usage(completion) + is_stream = isinstance(completion, types.GeneratorType) or type(completion).__name__ == "CustomStreamWrapper" + parser = handle_streaming_response if is_stream else self.chat_completion_parser + parsed_data = parser(completion) + usage = None if is_stream else self.track_completion_usage(completion) return GeneratorOutput( - data=None, error=None, raw_response=data, usage=usage + data=parsed_data, error=None, raw_response=completion, usage=usage ) except Exception as e: - log.error(f"Error tracking the completion usage: {e}") - return GeneratorOutput(data=None, error=str(e), raw_response=data) + log.error(f"Error in parse_chat_completion: {e}") + return GeneratorOutput(data=None, error=str(e), raw_response=completion) def track_completion_usage(self, completion) -> CompletionUsage: try: @@ -207,7 +206,7 @@ def call(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = ModelT import litellm api_kwargs = api_kwargs or {} - log.info(f"api_kwargs: {api_kwargs}") + log.debug(f"api_kwargs: {api_kwargs}") extra: Dict[str, Any] = {} if self._api_key: @@ -218,8 +217,6 @@ def call(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = ModelT if model_type == ModelType.EMBEDDER: return litellm.embedding(drop_params=True, **api_kwargs, **extra) elif model_type == ModelType.LLM: - if api_kwargs.get("stream", False): - self.chat_completion_parser = handle_streaming_response return litellm.completion(drop_params=True, **api_kwargs, **extra) else: raise ValueError(f"model_type {model_type} is not supported") diff --git a/tests/unit/test_litellm_client.py b/tests/unit/test_litellm_client.py index eae4672e8..0f160810e 100644 --- 
a/tests/unit/test_litellm_client.py +++ b/tests/unit/test_litellm_client.py @@ -355,7 +355,8 @@ def test_parse_chat_completion(self): client = LiteLLMClient() mock_resp = _make_mock_response("Hello world", 10, 5) output = client.parse_chat_completion(mock_resp) - assert output.raw_response == "Hello world" + assert output.data == "Hello world" + assert output.raw_response == mock_resp assert output.usage.completion_tokens == 5 assert output.usage.prompt_tokens == 10 From d543ac63e05ea7297ce72337d439cdb76cdef9d9 Mon Sep 17 00:00:00 2001 From: Aarish Alam Date: Thu, 7 May 2026 01:31:01 +0530 Subject: [PATCH 4/4] fix: guard backoff import, fix duplicate kwargs, use specific retry exceptions --- api/config/generator.json | 2 +- api/litellm_client.py | 85 +++++++++++++++++++++++++-------------- api/pyproject.toml | 1 + 3 files changed, 57 insertions(+), 31 deletions(-) diff --git a/api/config/generator.json b/api/config/generator.json index 0e6b774e8..f8e739000 100644 --- a/api/config/generator.json +++ b/api/config/generator.json @@ -196,7 +196,7 @@ }, "litellm": { "client_class": "LiteLLMClient", - "default_model": "openai/gpt-4o", + "default_model": "openai/gpt-4o-mini", "supportsCustomModel": true, "models": { "openai/gpt-4o": { diff --git a/api/litellm_client.py b/api/litellm_client.py index 03ca4a9dc..10b3aa5a8 100644 --- a/api/litellm_client.py +++ b/api/litellm_client.py @@ -13,6 +13,7 @@ import logging import re +import types from typing import ( Any, Callable, @@ -20,11 +21,13 @@ List, Optional, Sequence, - TypeVar, - Union, ) -import backoff +try: + import backoff + _BACKOFF_AVAILABLE = True +except ImportError: + _BACKOFF_AVAILABLE = False from adalflow.core.model_client import ModelClient from adalflow.core.types import ( @@ -36,18 +39,42 @@ from adalflow.components.model_client.utils import parse_embedding_response log = logging.getLogger(__name__) -T = TypeVar("T") -def _is_retryable(exc: BaseException) -> bool: - qualname = f"{type(exc).__module__}.{type(exc).__name__}" - return qualname in { - "litellm.exceptions.RateLimitError", - "litellm.exceptions.ServiceUnavailableError", - "litellm.exceptions.Timeout", - "litellm.exceptions.APIConnectionError", - "litellm.exceptions.InternalServerError", - } +try: + import litellm as _litellm_mod + _RETRYABLE = ( + _litellm_mod.exceptions.RateLimitError, + _litellm_mod.exceptions.ServiceUnavailableError, + _litellm_mod.exceptions.Timeout, + _litellm_mod.exceptions.APIConnectionError, + _litellm_mod.exceptions.InternalServerError, + ) + + def _is_retryable(exc: BaseException) -> bool: + return isinstance(exc, _RETRYABLE) +except (ImportError, AttributeError): + _RETRYABLE = (Exception,) + + def _is_retryable(exc: BaseException) -> bool: + return False + + +def _with_retry(fn): + """Apply exponential backoff retry for transient LiteLLM errors. + + Note: retry only fires for non-streaming calls. When stream=True, + litellm returns a generator immediately (no exception at the call site), + so mid-stream failures are not retried by this decorator. 
+ """ + if _BACKOFF_AVAILABLE: + return backoff.on_exception( + backoff.expo, + _RETRYABLE, + max_time=60, + giveup=lambda e: not _is_retryable(e), + )(fn) + return fn def get_first_message_content(completion) -> str: @@ -111,14 +138,6 @@ def __init__( self.chat_completion_parser = ( chat_completion_parser or get_first_message_content ) - self.sync_client = self.init_sync_client() - self.async_client = None - - def init_sync_client(self): - return {"api_key": self._api_key, "base_url": self._base_url} - - def init_async_client(self): - return {"api_key": self._api_key, "base_url": self._base_url} def convert_inputs_to_api_kwargs( self, @@ -167,10 +186,10 @@ def convert_inputs_to_api_kwargs( return final_model_kwargs def parse_chat_completion(self, completion) -> GeneratorOutput: - import types - try: - is_stream = isinstance(completion, types.GeneratorType) or type(completion).__name__ == "CustomStreamWrapper" + is_stream = isinstance( + completion, (types.GeneratorType, types.AsyncGeneratorType) + ) or type(completion).__name__ == "CustomStreamWrapper" parser = handle_streaming_response if is_stream else self.chat_completion_parser parsed_data = parser(completion) usage = None if is_stream else self.track_completion_usage(completion) @@ -201,7 +220,7 @@ def parse_embedding_response(self, response) -> EmbedderOutput: log.error(f"Error parsing the embedding response: {e}") return EmbedderOutput(data=[], error=str(e), raw_response=response) - @backoff.on_exception(backoff.expo, Exception, max_time=5, giveup=lambda e: not _is_retryable(e)) + @_with_retry def call(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = ModelType.UNDEFINED): import litellm @@ -214,14 +233,16 @@ def call(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = ModelT if self._base_url: extra["api_base"] = self._base_url + final_kwargs = {"drop_params": True, **api_kwargs, **extra} + if model_type == ModelType.EMBEDDER: - return litellm.embedding(drop_params=True, **api_kwargs, **extra) + return litellm.embedding(**final_kwargs) elif model_type == ModelType.LLM: - return litellm.completion(drop_params=True, **api_kwargs, **extra) + return litellm.completion(**final_kwargs) else: raise ValueError(f"model_type {model_type} is not supported") - @backoff.on_exception(backoff.expo, Exception, max_time=5, giveup=lambda e: not _is_retryable(e)) + @_with_retry async def acall(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = ModelType.UNDEFINED): import litellm @@ -233,15 +254,19 @@ async def acall(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = if self._base_url: extra["api_base"] = self._base_url + final_kwargs = {"drop_params": True, **api_kwargs, **extra} + if model_type == ModelType.EMBEDDER: - return await litellm.aembedding(drop_params=True, **api_kwargs, **extra) + return await litellm.aembedding(**final_kwargs) elif model_type == ModelType.LLM: - return await litellm.acompletion(drop_params=True, **api_kwargs, **extra) + return await litellm.acompletion(**final_kwargs) else: raise ValueError(f"model_type {model_type} is not supported") @classmethod def from_dict(cls, data: Dict[str, Any]): + """Deserialize from dict. 
Note: chat_completion_parser is not restored + since callables cannot be JSON-serialized.""" return cls(**data) def to_dict(self) -> Dict[str, Any]: diff --git a/api/pyproject.toml b/api/pyproject.toml index 3f8328db5..d9f358adb 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -15,6 +15,7 @@ pydantic = ">=2.0.0" google-generativeai = ">=0.3.0" tiktoken = ">=0.5.0" adalflow = ">=0.1.0" +backoff = ">=2.2.1,<3.0.0" numpy = ">=1.24.0" faiss-cpu = ">=1.7.4" langid = ">=1.1.6"