diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 000000000..1dc19e8da
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,10 @@
+{
+ "python.defaultInterpreterPath": "${workspaceFolder}/api/.venv/bin/python",
+ "python.analysis.extraPaths": [
+ "${workspaceFolder}/api"
+ ],
+ "python.autoComplete.extraPaths": [
+ "${workspaceFolder}/api"
+ ]
+}
+
diff --git a/api/config.py b/api/config.py
index 49dfcf7b0..2872646ad 100644
--- a/api/config.py
+++ b/api/config.py
@@ -284,6 +284,45 @@ def load_lang_config():
return loaded_config
+# Load API keys configuration
+def load_api_keys_config():
+ """
+ Load the API keys configuration, supporting values from the configuration file with environment-variable fallback.
+
+ Returned format:
+ {
+ "google": ["key1", "key2"],
+ "openai": ["key1"],
+ ...
+ }
+ """
+ api_keys_config = load_json_config("api_keys.json")
+
+ result = {}
+ for provider_id, config in api_keys_config.items():
+ keys = config.get("keys", [])
+
+ # If no keys are present in the configuration file, try reading them from the environment variable
+ if not keys:
+ env_var_name = f"{provider_id.upper()}_API_KEYS"
+ env_value = os.environ.get(env_var_name)
+ if env_value:
+ keys = [k.strip() for k in env_value.split(',')]
+ else:
+ # If an element of the keys array contains commas, split it into multiple keys
+ expanded_keys = []
+ for key in keys:
+ if isinstance(key, str) and ',' in key:
+ # Comma-separated string: split it into individual keys
+ expanded_keys.extend([k.strip() for k in key.split(',') if k.strip()])
+ elif key: # Ignore empty strings
+ expanded_keys.append(key)
+ keys = expanded_keys
+
+ result[provider_id] = keys
+
+ return result
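+
+ # Illustrative example (assumption: no "keys" entries are set in api_keys.json):
+ #   export GOOGLE_API_KEYS="keyA,keyB"
+ #   load_api_keys_config()  ->  {"google": ["keyA", "keyB"], ...}
+ # A single comma-separated entry inside the file is split the same way.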
+
# Default excluded directories and files
DEFAULT_EXCLUDED_DIRS: List[str] = [
# Virtual environments and package managers
@@ -333,6 +372,7 @@ def load_lang_config():
embedder_config = load_embedder_config()
repo_config = load_repo_config()
lang_config = load_lang_config()
+api_keys_config = load_api_keys_config()
# Update configuration
if generator_config:
@@ -355,6 +395,10 @@ def load_lang_config():
if lang_config:
configs["lang_config"] = lang_config
+# Update API keys configuration
+if api_keys_config:
+ configs["api_keys"] = api_keys_config
+
def get_model_config(provider="google", model=None):
"""
diff --git a/api/config/api_keys.json b/api/config/api_keys.json
new file mode 100644
index 000000000..a826c1a83
--- /dev/null
+++ b/api/config/api_keys.json
@@ -0,0 +1,20 @@
+{
+ "google": {
+ "keys": ["${GOOGLE_API_KEYS}"]
+ },
+ "openai": {
+ "keys": ["${OPENAI_API_KEYS}"]
+ },
+ "openrouter": {
+ "keys": ["${OPENROUTER_API_KEYS}"]
+ },
+ "azure": {
+ "keys": ["${AZURE_OPENAI_API_KEYS}"]
+ },
+ "bedrock": {},
+ "dashscope": {
+ "keys": ["${DASHSCOPE_API_KEYS}"]
+ },
+ "ollama": {}
+}
+
diff --git a/api/config/generator.json b/api/config/generator.json
index f88179098..2245f9f4a 100644
--- a/api/config/generator.json
+++ b/api/config/generator.json
@@ -41,7 +41,7 @@
}
},
"openai": {
- "default_model": "gpt-5-nano",
+ "default_model": "gpt-4.1",
"supportsCustomModel": true,
"models": {
"gpt-5": {
diff --git a/api/llm.py b/api/llm.py
new file mode 100644
index 000000000..eacfcd571
--- /dev/null
+++ b/api/llm.py
@@ -0,0 +1,702 @@
+import os
+import logging
+import asyncio
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Dict, Any, Optional, List
+from dotenv import load_dotenv
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Load environment variables
+load_dotenv()
+
+
+class LLMService:
+ """Service Layer for managing LLM API calls with multi-provider support.
+
+ Supports load balancing across multiple API keys for each provider.
+ Compatible with all adalflow.ModelClient implementations.
+ """
+
+ # Singleton instance
+ _instance = None
+
+ def __new__(cls, default_provider: str = "google"):
+ if cls._instance is None:
+ cls._instance = super(LLMService, cls).__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self, default_provider: str = "google"):
+ """
+ Initialize the LLM service
+
+ Args:
+ default_provider: The default provider to use (google/openai/openrouter/azure/bedrock/dashscope/ollama)
+ """
+ if self._initialized:
+ return
+
+ # Load from configuration
+ from api.config import configs
+
+ self.default_provider = default_provider
+ self.configs = configs
+
+ # Initialize all API keys for all providers
+ self._init_all_provider_keys()
+
+ # Initialize the client instance cache for each provider
+ self.client_cache = {}
+
+ # Usage statistics grouped by provider
+ self.provider_key_usage = {}
+ self.provider_key_last_used = {}
+
+ for provider in self.api_keys_by_provider:
+ keys = self.api_keys_by_provider[provider]
+ self.provider_key_usage[provider] = {str(k): 0 for k in keys}
+ self.provider_key_last_used[provider] = {str(k): 0 for k in keys}
+
+ # Thread pool for concurrent requests
+ self.thread_pool = ThreadPoolExecutor(max_workers=20)
+
+ self._initialized = True
+ logger.info(f"LLMService initialized with default provider: {default_provider}")
+
+ def _init_all_provider_keys(self):
+ """Load all API keys for all providers from the configuration"""
+ self.api_keys_by_provider = self.configs.get("api_keys", {})
+
+ # For backward compatibility, keep self.api_keys pointing to the keys of the default provider
+ self.api_keys = self.api_keys_by_provider.get(self.default_provider, [])
+
+ logger.info(f"Loaded API keys for providers: {list(self.api_keys_by_provider.keys())}")
+ for provider, keys in self.api_keys_by_provider.items():
+ logger.info(f" {provider}: {len(keys)} key(s)")
+
+ def _get_client(self, provider: str, api_key: Optional[str] = None):
+ """
+ Get the client instance for the specified provider
+
+ Args:
+ provider: Provider name
+ api_key: Optional API key (allows a specific key to be used when load balancing)
+
+ Returns:
+ A ModelClient instance
+ """
+ from api.config import get_model_config
+
+ # Generate cache key
+ cache_key = f"{provider}_{api_key[:8] if api_key else 'default'}"
+
+ if cache_key in self.client_cache:
+ logger.debug(f"Using cached client for {provider}")
+ return self.client_cache[cache_key]
+
+ logger.info(f"Creating new client for provider: {provider}")
+
+ # Get model config
+ model_config = get_model_config(provider)
+ model_client_class = model_config["model_client"]
+
+ # Initialize client for different providers
+ if provider == "openai":
+ client = model_client_class(api_key=api_key)
+ elif provider == "google":
+ # Google client reads key from environment variables
+ if api_key:
+ os.environ["GOOGLE_API_KEY"] = api_key
+ client = model_client_class()
+ elif provider == "openrouter":
+ client = model_client_class(api_key=api_key)
+ elif provider == "azure":
+ client = model_client_class() # Azure reads key from environment variables
+ elif provider == "bedrock":
+ client = model_client_class() # Bedrock uses AWS credentials
+ elif provider == "dashscope":
+ if api_key:
+ os.environ["DASHSCOPE_API_KEY"] = api_key
+ client = model_client_class()
+ elif provider == "ollama":
+ client = model_client_class() # Ollama local service
+ else:
+ raise ValueError(f"Unsupported provider: {provider}")
+
+ self.client_cache[cache_key] = client
+ logger.info(f"Client created and cached for {provider}")
+ return client
+
+ def get_next_api_key(self, provider: Optional[str] = None) -> Optional[str]:
+ """
+ Get the next available API key for the specified provider (load balancing)
+
+ Args:
+ provider: Provider name, default uses self.default_provider
+
+ Returns:
+ API key or None (for providers that do not require a key)
+ """
+ provider = provider or self.default_provider
+
+ keys = self.api_keys_by_provider.get(provider, [])
+ if not keys:
+ # Some providers (e.g. ollama, bedrock) do not require an API key
+ logger.debug(f"No API keys configured for provider: {provider}")
+ return None
+
+ if len(keys) == 1:
+ return keys[0]
+
+ # Load balancing: select the least-used key, breaking ties by least recently used
+ current_time = time.time()
+ best_key = min(
+ keys,
+ key=lambda k: (
+ self.provider_key_usage[provider][str(k)],
+ self.provider_key_last_used[provider][str(k)]
+ )
+ )
+
+ # Update usage statistics
+ best_key_str = str(best_key)
+ self.provider_key_usage[provider][best_key_str] += 1
+ self.provider_key_last_used[provider][best_key_str] = current_time
+
+ logger.debug(f"Selected API key for {provider}: {best_key[:8]}...{best_key[-4:]}")
+ return best_key
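+
+ # Rough usage sketch (illustrative; assumes two keys configured for "google"):
+ #   svc = LLMService(default_provider="google")
+ #   key_a = svc.get_next_api_key()  # least-used key is returned first
+ #   key_b = svc.get_next_api_key()  # usage counts then favour the other key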
+
+ def reset_key_usage_stats(self, provider: Optional[str] = None):
+ """
+ Reset key usage statistics.
+
+ Args:
+ provider: Provider name to reset stats for, or None to reset all
+ """
+ if provider:
+ if provider in self.provider_key_usage:
+ self.provider_key_usage[provider] = {str(k): 0 for k in self.api_keys_by_provider[provider]}
+ self.provider_key_last_used[provider] = {str(k): 0 for k in self.api_keys_by_provider[provider]}
+ logger.info(f"Key usage statistics reset for provider: {provider}")
+ else:
+ for prov in self.provider_key_usage:
+ self.provider_key_usage[prov] = {str(k): 0 for k in self.api_keys_by_provider[prov]}
+ self.provider_key_last_used[prov] = {str(k): 0 for k in self.api_keys_by_provider[prov]}
+ logger.info("Key usage statistics reset for all providers")
+
+ def direct_invoke(
+ self,
+ prompt: str,
+ provider: Optional[str] = None,
+ model: Optional[str] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ stream: bool = False,
+ api_key: Optional[str] = None
+ ):
+ """
+ Unified synchronous/streaming call interface
+
+ Args:
+ prompt: User prompt
+ provider: Provider name (takes precedence over the default_provider set at initialization)
+ model: Model name
+ temperature: Temperature parameter
+ max_tokens: Maximum token count
+ stream: Whether to stream output
+ api_key: Specified API key (if not specified, use load balancing)
+
+ Returns:
+ Non-streaming: Dict[str, Any] containing the response content
+ Streaming: an iterable stream object
+ """
+ from api.config import get_model_config
+ from adalflow.core.types import ModelType
+
+ # Determine provider
+ provider = provider or self.default_provider
+
+ # Get API key
+ if not api_key:
+ api_key = self.get_next_api_key(provider)
+
+ # Get client
+ client = self._get_client(provider, api_key)
+
+ # Get model config
+ model_config = get_model_config(provider, model)
+ model_kwargs = model_config["model_kwargs"].copy()
+
+ # Override parameters
+ if temperature is not None:
+ model_kwargs["temperature"] = temperature
+ if max_tokens is not None:
+ model_kwargs["max_tokens"] = max_tokens
+ model_kwargs["stream"] = stream
+
+ # Generate request_id
+ request_id = str(uuid.uuid4())
+
+ logger.info(f"[{request_id}] Provider: {provider}, Model: {model_kwargs.get('model', 'N/A')}, Stream: {stream}")
+ if api_key:
+ logger.info(f"[{request_id}] Using API key: {api_key[:8]}...{api_key[-4:]}")
+
+ try:
+ # Convert input to API parameters
+ api_kwargs = client.convert_inputs_to_api_kwargs(
+ input=prompt,
+ model_kwargs=model_kwargs,
+ model_type=ModelType.LLM
+ )
+
+ # Call client
+ response = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
+
+ # If streaming, return directly
+ if stream:
+ logger.info(f"[{request_id}] Returning stream response")
+ return response
+
+ # Non-streaming, parse response
+ parsed_response = client.parse_chat_completion(response)
+
+ content = parsed_response.raw_response if hasattr(parsed_response, 'raw_response') else str(parsed_response)
+ logger.info(f"[{request_id}] Response received: {len(str(content))} characters")
+
+ return {
+ "content": content,
+ "model": model_kwargs.get("model", "N/A"),
+ "provider": provider,
+ "request_id": request_id,
+ "api_key_used": f"{api_key[:8]}...{api_key[-4:]}" if api_key else "N/A"
+ }
+
+ except Exception as e:
+ logger.error(f"[{request_id}] Error: {str(e)}", exc_info=True)
+ raise RuntimeError(f"API call failed for provider {provider}: {str(e)}")
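+
+ # Minimal usage sketch (illustrative; the model name is only an assumption):
+ #   svc = LLMService(default_provider="openai")
+ #   result = svc.direct_invoke("Summarize this repo", model="gpt-4.1", temperature=0.2)
+ #   print(result["content"], result["api_key_used"])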
+
+ def direct_invoke_with_system(
+ self,
+ prompt: str,
+ system_message: str,
+ provider: Optional[str] = None,
+ model: Optional[str] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ stream: bool = False,
+ api_key: Optional[str] = None
+ ):
+ """
+ Call with a system message (using adalflow's messages format)
+
+ Args:
+ prompt: User prompt
+ system_message: System message
+ provider: Provider name
+ model: Model name
+ temperature: Temperature parameter
+ max_tokens: Maximum token count
+ stream: Whether to stream output
+ api_key: Specified API key
+
+ Returns:
+ Non-streaming: Dict[str, Any]
+ Streaming: Stream object
+ """
+ from api.config import get_model_config
+ from adalflow.core.types import ModelType
+
+ # Determine provider
+ provider = provider or self.default_provider
+
+ # Get API key
+ if not api_key:
+ api_key = self.get_next_api_key(provider)
+
+ # Get client
+ client = self._get_client(provider, api_key)
+
+ # Get model config
+ model_config = get_model_config(provider, model)
+ model_kwargs = model_config["model_kwargs"].copy()
+
+ # Override parameters
+ if temperature is not None:
+ model_kwargs["temperature"] = temperature
+ if max_tokens is not None:
+ model_kwargs["max_tokens"] = max_tokens
+ model_kwargs["stream"] = stream
+
+ # Generate request_id
+ request_id = str(uuid.uuid4())
+
+ logger.info(f"[{request_id}] Provider: {provider}, Model: {model_kwargs.get('model', 'N/A')}, With system message, Stream: {stream}")
+
+ try:
+ # Build messages format input
+ messages = [
+ {"role": "system", "content": system_message},
+ {"role": "user", "content": prompt}
+ ]
+
+ # Convert input to API parameters
+ api_kwargs = client.convert_inputs_to_api_kwargs(
+ input=messages,
+ model_kwargs=model_kwargs,
+ model_type=ModelType.LLM
+ )
+
+ # Call client
+ response = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
+
+ # If streaming, return directly
+ if stream:
+ logger.info(f"[{request_id}] Returning stream response")
+ return response
+
+ # Non-streaming, parse response
+ parsed_response = client.parse_chat_completion(response)
+
+ content = parsed_response.raw_response if hasattr(parsed_response, 'raw_response') else str(parsed_response)
+ logger.info(f"[{request_id}] Response received: {len(str(content))} characters")
+
+ return {
+ "content": content,
+ "model": model_kwargs.get("model", "N/A"),
+ "provider": provider,
+ "request_id": request_id,
+ "api_key_used": f"{api_key[:8]}...{api_key[-4:]}" if api_key else "N/A"
+ }
+
+ except Exception as e:
+ logger.error(f"[{request_id}] Error: {str(e)}", exc_info=True)
+ raise RuntimeError(f"API call failed for provider {provider}: {str(e)}")
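+
+ # Sketch with a system message (illustrative, continuing the example above):
+ #   result = svc.direct_invoke_with_system(
+ #       prompt="Explain the config loader",
+ #       system_message="You are a concise code reviewer.",
+ #   )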
+
+ async def async_invoke_stream(
+ self,
+ prompt: str,
+ provider: Optional[str] = None,
+ model: Optional[str] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ api_key: Optional[str] = None
+ ):
+ """
+ Asynchronous streaming call
+
+ Args:
+ prompt: User prompt
+ provider: Provider name
+ model: Model name
+ temperature: Temperature parameter
+ max_tokens: Maximum token count
+ api_key: Specified API key
+
+ Yields:
+ str: Text content of each chunk
+ """
+ from api.config import get_model_config
+ from adalflow.core.types import ModelType
+
+ provider = provider or self.default_provider
+
+ if not api_key:
+ api_key = self.get_next_api_key(provider)
+
+ client = self._get_client(provider, api_key)
+
+ # Get model config
+ model_config = get_model_config(provider, model)
+ model_kwargs = model_config["model_kwargs"].copy()
+
+ if temperature is not None:
+ model_kwargs["temperature"] = temperature
+ if max_tokens is not None:
+ model_kwargs["max_tokens"] = max_tokens
+ model_kwargs["stream"] = True
+
+ request_id = str(uuid.uuid4())
+ logger.info(f"[{request_id}] Async stream - Provider: {provider}, Model: {model_kwargs.get('model', 'N/A')}")
+
+ try:
+ # Convert input
+ api_kwargs = client.convert_inputs_to_api_kwargs(
+ input=prompt,
+ model_kwargs=model_kwargs,
+ model_type=ModelType.LLM
+ )
+
+ # Asynchronous call
+ response = await client.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
+
+ # Handle each provider's streaming response format
+ if provider == "google":
+ for chunk in response:
+ if hasattr(chunk, 'text'):
+ yield chunk.text
+ elif provider in ["openai", "openrouter", "azure"]:
+ async for chunk in response:
+ if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
+ delta = chunk.choices[0].delta
+ if hasattr(delta, 'content') and delta.content:
+ yield delta.content
+ elif provider == "ollama":
+ async for chunk in response:
+ text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None)
+ if text:
+ yield text
+ elif provider == "bedrock":
+ # Bedrock may not support streaming; return the complete response
+ yield str(response)
+ elif provider == "dashscope":
+ async for text in response:
+ if text:
+ yield text
+ else:
+ logger.warning(f"Unknown provider {provider}, attempting generic streaming")
+ async for chunk in response:
+ yield str(chunk)
+
+ except Exception as e:
+ logger.error(f"[{request_id}] Async stream error: {str(e)}", exc_info=True)
+ raise RuntimeError(f"Async stream failed for provider {provider}: {str(e)}")
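+
+ # Streaming sketch (illustrative; assumes it runs inside an asyncio event loop):
+ #   async for text in svc.async_invoke_stream("Hello", provider="openai"):
+ #       print(text, end="", flush=True)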
+
+ def parallel_invoke(
+ self,
+ requests: List[Dict[str, Any]],
+ max_concurrent_per_key: int = 3,
+ max_total_concurrent: int = 10,
+ timeout: float = 60.0
+ ) -> List[Dict[str, Any]]:
+ """
+ Invoke multiple requests in parallel across the available API keys.
+
+ Args:
+ requests: List of request dictionaries, each containing:
+ - prompt: User prompt text
+ - system_message: (optional) System message
+ - provider: (optional) Provider name
+ - model: (optional) Model name
+ - temperature: (optional) Temperature parameter
+ - max_tokens: (optional) Maximum generation tokens
+ - api_key: (optional) Specific API key to use
+ max_concurrent_per_key: Maximum concurrent requests per API key (not enforced by this implementation; overall concurrency is capped by max_total_concurrent)
+ max_total_concurrent: Maximum total concurrent requests
+ timeout: Overall timeout in seconds for collecting all results
+
+ Returns:
+ List of response dictionaries in the same order as input requests
+ """
+ if not requests:
+ raise ValueError("Requests list cannot be empty")
+
+ logger.info(f"Starting parallel invoke for {len(requests)} requests")
+ logger.info(f"Max concurrent per key: {max_concurrent_per_key}, Max total: {max_total_concurrent}")
+
+ # Prepare request functions
+ def execute_single_request(request_data, index):
+ try:
+ # Extract parameters from request
+ prompt = request_data.get("prompt")
+ if not prompt:
+ return {
+ "index": index,
+ "error": "Missing prompt in request",
+ "request_data": request_data
+ }
+
+ system_message = request_data.get("system_message")
+ provider = request_data.get("provider")
+ model = request_data.get("model")
+ temperature = request_data.get("temperature")
+ max_tokens = request_data.get("max_tokens")
+ api_key = request_data.get("api_key")
+
+ # Call appropriate method
+ if system_message:
+ result = self.direct_invoke_with_system(
+ prompt=prompt,
+ system_message=system_message,
+ provider=provider,
+ model=model,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ api_key=api_key
+ )
+ else:
+ result = self.direct_invoke(
+ prompt=prompt,
+ provider=provider,
+ model=model,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ api_key=api_key
+ )
+
+ result["index"] = index
+ result["request_data"] = request_data
+ return result
+
+ except Exception as e:
+ logger.error(f"Error in request {index}: {str(e)}")
+ return {
+ "index": index,
+ "error": str(e),
+ "request_data": request_data
+ }
+
+ # Execute requests in parallel using ThreadPoolExecutor
+ results = [None] * len(requests)
+
+ # Limit concurrent requests
+ max_workers = min(max_total_concurrent, len(requests))
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # Submit all tasks
+ future_to_index = {
+ executor.submit(execute_single_request, req, idx): idx
+ for idx, req in enumerate(requests)
+ }
+
+ # Collect results
+ completed = 0
+ for future in as_completed(future_to_index, timeout=timeout):
+ try:
+ result = future.result()
+ index = result.get("index", future_to_index[future])
+ results[index] = result
+ completed += 1
+
+ if completed % 5 == 0 or completed == len(requests):
+ logger.info(f"Completed {completed}/{len(requests)} requests")
+
+ except Exception as e:
+ index = future_to_index[future]
+ logger.error(f"Request {index} failed: {str(e)}")
+ results[index] = {
+ "index": index,
+ "error": str(e),
+ "request_data": requests[index]
+ }
+
+ logger.info(f"Parallel invoke completed. {completed}/{len(requests)} requests finished")
+ return results
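+
+ # Parallel invocation sketch (illustrative payloads):
+ #   results = svc.parallel_invoke([
+ #       {"prompt": "Describe file A"},
+ #       {"prompt": "Describe file B", "system_message": "Be brief."},
+ #   ])
+ #   # Results keep the input order; failed entries carry an "error" field.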
+
+ def batch_invoke_same_prompt(
+ self,
+ prompt: str,
+ count: int,
+ system_message: Optional[str] = None,
+ provider: Optional[str] = None,
+ model: Optional[str] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ max_concurrent_per_key: int = 3,
+ max_total_concurrent: int = 10
+ ) -> List[Dict[str, Any]]:
+ """
+ Batch invoke the same prompt multiple times for comparison or ensemble purposes.
+
+ Args:
+ prompt: User prompt text
+ count: Number of times to invoke the same prompt
+ system_message: (optional) System message
+ provider: (optional) Provider name
+ model: (optional) Model name
+ temperature: (optional) Temperature parameter
+ max_tokens: (optional) Maximum generation tokens
+ max_concurrent_per_key: Maximum concurrent requests per API key
+ max_total_concurrent: Maximum total concurrent requests
+
+ Returns:
+ List of response dictionaries
+ """
+ if count <= 0:
+ raise ValueError("Count must be positive")
+
+ logger.info(f"Batch invoking same prompt {count} times")
+
+ # Create request list
+ requests = []
+ for i in range(count):
+ request_data = {"prompt": prompt}
+ if system_message:
+ request_data["system_message"] = system_message
+ if provider:
+ request_data["provider"] = provider
+ if model:
+ request_data["model"] = model
+ if temperature is not None:
+ request_data["temperature"] = temperature
+ if max_tokens is not None:
+ request_data["max_tokens"] = max_tokens
+
+ requests.append(request_data)
+
+ return self.parallel_invoke(
+ requests=requests,
+ max_concurrent_per_key=max_concurrent_per_key,
+ max_total_concurrent=max_total_concurrent
+ )
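+
+ # Ensemble sketch (illustrative): sample the same prompt three times,
+ # e.g. for self-consistency or majority voting:
+ #   samples = svc.batch_invoke_same_prompt("Name the design pattern used here",
+ #                                          count=3, temperature=0.9)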
+
+ def get_api_keys_status(self, provider: Optional[str] = None) -> Dict[str, Any]:
+ """
+ Get status information about API keys.
+
+ Args:
+ provider: Specific provider to get status for, or None for all providers
+
+ Returns:
+ Dictionary containing API keys status and usage statistics
+ """
+ if provider:
+ if provider not in self.api_keys_by_provider:
+ return {"error": f"Provider {provider} not found"}
+
+ keys = self.api_keys_by_provider[provider]
+ return {
+ "provider": provider,
+ "total_keys": len(keys),
+ "api_keys": [f"{key[:8]}...{key[-4:]}" if len(key) > 12 else key for key in keys],
+ "key_usage_count": {
+ f"{key[:8]}...{key[-4:]}" if len(key) > 12 else key: self.provider_key_usage[provider].get(str(key), 0)
+ for key in keys
+ },
+ "key_last_used": {
+ f"{key[:8]}...{key[-4:]}" if len(key) > 12 else key: self.provider_key_last_used[provider].get(str(key), 0)
+ for key in keys
+ }
+ }
+ else:
+ # Return status for all providers
+ status = {}
+ for prov in self.api_keys_by_provider:
+ keys = self.api_keys_by_provider[prov]
+ status[prov] = {
+ "total_keys": len(keys),
+ "key_usage_count": {
+ f"{key[:8]}...{key[-4:]}" if len(key) > 12 else key: self.provider_key_usage[prov].get(str(key), 0)
+ for key in keys
+ }
+ }
+ return status
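+
+ # Status sketch (illustrative output shape, keys masked):
+ #   svc.get_api_keys_status("google")
+ #   # -> {"provider": "google", "total_keys": 2,
+ #   #     "api_keys": ["keyA1234...89ab", ...], "key_usage_count": {...}, ...}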
+
+ def update_default_provider(self, provider: str) -> None:
+ """
+ Update the default provider.
+
+ Args:
+ provider: The new default provider name
+ """
+ if provider not in self.api_keys_by_provider:
+ raise ValueError(f"Provider {provider} not configured")
+
+ self.default_provider = provider
+ self.api_keys = self.api_keys_by_provider.get(provider, [])
+ logger.info(f"Default provider updated to {provider}")
+
diff --git a/api/openai_client.py b/api/openai_client.py
index bc75ed586..8ddc84ec2 100644
--- a/api/openai_client.py
+++ b/api/openai_client.py
@@ -242,12 +242,19 @@ def track_completion_usage(
) -> CompletionUsage:
try:
- usage: CompletionUsage = CompletionUsage(
- completion_tokens=completion.usage.completion_tokens,
- prompt_tokens=completion.usage.prompt_tokens,
- total_tokens=completion.usage.total_tokens,
- )
- return usage
+ # Check if usage information is available
+ if completion.usage is not None:
+ usage: CompletionUsage = CompletionUsage(
+ completion_tokens=completion.usage.completion_tokens,
+ prompt_tokens=completion.usage.prompt_tokens,
+ total_tokens=completion.usage.total_tokens,
+ )
+ return usage
+ else:
+ # Usage info not available; return None values
+ return CompletionUsage(
+ completion_tokens=None, prompt_tokens=None, total_tokens=None
+ )
except Exception as e:
log.error(f"Error tracking the completion usage: {e}")
return CompletionUsage(
@@ -430,15 +437,22 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
# Get streaming response
stream_response = self.sync_client.chat.completions.create(**streaming_kwargs)
- # Accumulate all content from the stream
+ # Accumulate all content from the stream and collect usage info
accumulated_content = ""
id = ""
model = ""
created = 0
+ usage_info = None
+
for chunk in stream_response:
id = getattr(chunk, "id", None) or id
model = getattr(chunk, "model", None) or model
created = getattr(chunk, "created", 0) or created
+
+ # Collect usage information if available (usually in the last chunk)
+ if hasattr(chunk, 'usage') and chunk.usage is not None:
+ usage_info = chunk.usage
+
choices = getattr(chunk, "choices", [])
if len(choices) > 0:
delta = getattr(choices[0], "delta", None)
@@ -446,9 +460,25 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
text = getattr(delta, "content", None)
if text is not None:
accumulated_content += text or ""
+
+ # Create usage object with fallback values
+ # If usage info was not provided in stream, use None for token counts
+ # This allows the system to continue working even without usage data
+ from openai.types.completion_usage import CompletionUsage
+ if usage_info is not None:
+ final_usage = usage_info
+ else:
+ # Fallback: create a CompletionUsage with None token counts because the
+ # stream did not include usage information
+ final_usage = CompletionUsage(
+ completion_tokens=None,
+ prompt_tokens=None,
+ total_tokens=None
+ )
+
# Return the mock completion object that will be processed by the chat_completion_parser
return ChatCompletion(
- id = id,
+ id=id,
model=model,
created=created,
object="chat.completion",
@@ -456,7 +486,8 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
index=0,
finish_reason="stop",
message=ChatCompletionMessage(content=accumulated_content, role="assistant")
- )]
+ )],
+ usage=final_usage
)
elif model_type == ModelType.IMAGE_GENERATION:
# Determine which image API to call based on the presence of image/mask
@@ -599,7 +630,7 @@ def _prepare_image_content(
gen = Generator(
model_client=OpenAIClient(),
- model_kwargs={"model": "gpt-4o", "stream": False},
+ model_kwargs={"model": "gpt-4.1", "stream": False},
)
gen_response = gen(prompt_kwargs)
print(f"gen_response: {gen_response}")
@@ -623,7 +654,7 @@ def _prepare_image_content(
setup_env()
openai_llm = adal.Generator(
- model_client=OpenAIClient(), model_kwargs={"model": "gpt-4o"}
+ model_client=OpenAIClient(), model_kwargs={"model": "gpt-4.1"}
)
resopnse = openai_llm(prompt_kwargs={"input_str": "What is LLM?"})
print(resopnse)
diff --git a/api/simple_chat.py b/api/simple_chat.py
index 41a184ed8..d3ce39edd 100644
--- a/api/simple_chat.py
+++ b/api/simple_chat.py
@@ -3,22 +3,15 @@
from typing import List, Optional
from urllib.parse import unquote
-import google.generativeai as genai
-from adalflow.components.model_client.ollama_client import OllamaClient
-from adalflow.core.types import ModelType
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
-from api.config import get_model_config, configs, OPENROUTER_API_KEY, OPENAI_API_KEY, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
+from api.config import configs
from api.data_pipeline import count_tokens, get_file_content
-from api.openai_client import OpenAIClient
-from api.openrouter_client import OpenRouterClient
-from api.bedrock_client import BedrockClient
-from api.azureai_client import AzureAIClient
-from api.dashscope_client import DashscopeClient
from api.rag import RAG
+from api.llm import LLMService
from api.prompts import (
DEEP_RESEARCH_FIRST_ITERATION_PROMPT,
DEEP_RESEARCH_FINAL_ITERATION_PROMPT,
@@ -327,234 +320,47 @@ async def chat_completions_stream(request: ChatCompletionRequest):
prompt += f"<query>\n{query}\n</query>\n\nAssistant: "
- model_config = get_model_config(request.provider, request.model)["model_kwargs"]
-
+ # Add /no_think suffix for Ollama
if request.provider == "ollama":
prompt += " /no_think"
- model = OllamaClient()
- model_kwargs = {
- "model": model_config["model"],
- "stream": True,
- "options": {
- "temperature": model_config["temperature"],
- "top_p": model_config["top_p"],
- "num_ctx": model_config["num_ctx"]
- }
- }
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "openrouter":
- logger.info(f"Using OpenRouter with model: {request.model}")
-
- # Check if OpenRouter API key is set
- if not OPENROUTER_API_KEY:
- logger.warning("OPENROUTER_API_KEY not configured, but continuing with request")
- # We'll let the OpenRouterClient handle this and return a friendly error message
-
- model = OpenRouterClient()
- model_kwargs = {
- "model": request.model,
- "stream": True,
- "temperature": model_config["temperature"]
- }
- # Only add top_p if it exists in the model config
- if "top_p" in model_config:
- model_kwargs["top_p"] = model_config["top_p"]
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "openai":
- logger.info(f"Using Openai protocol with model: {request.model}")
-
- # Check if an API key is set for Openai
- if not OPENAI_API_KEY:
- logger.warning("OPENAI_API_KEY not configured, but continuing with request")
- # We'll let the OpenAIClient handle this and return an error message
-
- # Initialize Openai client
- model = OpenAIClient()
- model_kwargs = {
- "model": request.model,
- "stream": True,
- "temperature": model_config["temperature"]
- }
- # Only add top_p if it exists in the model config
- if "top_p" in model_config:
- model_kwargs["top_p"] = model_config["top_p"]
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "bedrock":
- logger.info(f"Using AWS Bedrock with model: {request.model}")
-
- # Check if AWS credentials are set
- if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
- logger.warning("AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY not configured, but continuing with request")
- # We'll let the BedrockClient handle this and return an error message
-
- # Initialize Bedrock client
- model = BedrockClient()
- model_kwargs = {
- "model": request.model,
- "temperature": model_config["temperature"],
- "top_p": model_config["top_p"]
- }
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "azure":
- logger.info(f"Using Azure AI with model: {request.model}")
-
- # Initialize Azure AI client
- model = AzureAIClient()
- model_kwargs = {
- "model": request.model,
- "stream": True,
- "temperature": model_config["temperature"],
- "top_p": model_config["top_p"]
- }
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "dashscope":
- logger.info(f"Using Dashscope with model: {request.model}")
-
- model = DashscopeClient()
- model_kwargs = {
- "model": request.model,
- "stream": True,
- "temperature": model_config["temperature"],
- "top_p": model_config["top_p"],
- }
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM,
- )
- else:
- # Initialize Google Generative AI model (default provider)
- model = genai.GenerativeModel(
- model_name=model_config["model"],
- generation_config={
- "temperature": model_config["temperature"],
- "top_p": model_config["top_p"],
- "top_k": model_config["top_k"],
- },
- )
+ # Initialize LLM service
+ llm_service = LLMService(default_provider=request.provider)
# Create a streaming response
async def response_stream():
try:
- if request.provider == "ollama":
- # Get the response and handle it properly using the previously created api_kwargs
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from Ollama
- async for chunk in response:
- text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
- if text and not text.startswith('model=') and not text.startswith('created_at='):
- text = text.replace('<think>', '').replace('</think>', '')
- yield text
- elif request.provider == "openrouter":
- try:
- # Get the response and handle it properly using the previously created api_kwargs
- logger.info("Making OpenRouter API call")
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from OpenRouter
- async for chunk in response:
- yield chunk
- except Exception as e_openrouter:
- logger.error(f"Error with OpenRouter API: {str(e_openrouter)}")
- yield f"\nError with OpenRouter API: {str(e_openrouter)}\n\nPlease check that you have set the OPENROUTER_API_KEY environment variable with a valid API key."
- elif request.provider == "openai":
- try:
- # Get the response and handle it properly using the previously created api_kwargs
- logger.info("Making Openai API call")
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from Openai
- async for chunk in response:
- choices = getattr(chunk, "choices", [])
- if len(choices) > 0:
- delta = getattr(choices[0], "delta", None)
- if delta is not None:
- text = getattr(delta, "content", None)
- if text is not None:
- yield text
- except Exception as e_openai:
- logger.error(f"Error with Openai API: {str(e_openai)}")
- yield f"\nError with Openai API: {str(e_openai)}\n\nPlease check that you have set the OPENAI_API_KEY environment variable with a valid API key."
- elif request.provider == "bedrock":
- try:
- # Get the response and handle it properly using the previously created api_kwargs
- logger.info("Making AWS Bedrock API call")
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle response from Bedrock (not streaming yet)
- if isinstance(response, str):
- yield response
- else:
- # Try to extract text from the response
- yield str(response)
- except Exception as e_bedrock:
- logger.error(f"Error with AWS Bedrock API: {str(e_bedrock)}")
- yield f"\nError with AWS Bedrock API: {str(e_bedrock)}\n\nPlease check that you have set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables with valid credentials."
- elif request.provider == "azure":
- try:
- # Get the response and handle it properly using the previously created api_kwargs
- logger.info("Making Azure AI API call")
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from Azure AI
- async for chunk in response:
- choices = getattr(chunk, "choices", [])
- if len(choices) > 0:
- delta = getattr(choices[0], "delta", None)
- if delta is not None:
- text = getattr(delta, "content", None)
- if text is not None:
- yield text
- except Exception as e_azure:
- logger.error(f"Error with Azure AI API: {str(e_azure)}")
- yield f"\nError with Azure AI API: {str(e_azure)}\n\nPlease check that you have set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_VERSION environment variables with valid values."
- elif request.provider == "dashscope":
- try:
- logger.info("Making Dashscope API call")
- response = await model.acall(
- api_kwargs=api_kwargs, model_type=ModelType.LLM
- )
- # DashscopeClient.acall with stream=True returns an async
- # generator of text chunks
- async for text in response:
- if text:
- yield text
- except Exception as e_dashscope:
- logger.error(f"Error with Dashscope API: {str(e_dashscope)}")
- yield (
- f"\nError with Dashscope API: {str(e_dashscope)}\n\n"
- "Please check that you have set the DASHSCOPE_API_KEY (and optionally "
- "DASHSCOPE_WORKSPACE_ID) environment variables with valid values."
- )
- else:
- # Google Generative AI (default provider)
- response = model.generate_content(prompt, stream=True)
- for chunk in response:
- if hasattr(chunk, "text"):
- yield chunk.text
+ # Use LLMService for unified streaming across all providers
+ logger.info(f"Using LLMService for provider: {request.provider}")
+
+ try:
+ async for chunk in llm_service.async_invoke_stream(
+ prompt=prompt,
+ provider=request.provider,
+ model=request.model
+ ):
+ # Post-process for Ollama to remove thinking tags
+ if request.provider == "ollama":
+ chunk = chunk.replace('<think>', '').replace('</think>', '')
+ if not chunk.startswith('model=') and not chunk.startswith('created_at='):
+ yield chunk
+ else:
+ yield chunk
+
+ except Exception as e_provider:
+ logger.error(f"Error with {request.provider} API: {str(e_provider)}")
+
+ # Provider-specific error messages
+ error_messages = {
+ "openrouter": f"\nError with OpenRouter API: {str(e_provider)}\n\nPlease check that you have set the OPENROUTER_API_KEY environment variable with a valid API key.",
+ "openai": f"\nError with OpenAI API: {str(e_provider)}\n\nPlease check that you have set the OPENAI_API_KEY environment variable with a valid API key.",
+ "bedrock": f"\nError with AWS Bedrock API: {str(e_provider)}\n\nPlease check that you have set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables with valid credentials.",
+ "azure": f"\nError with Azure AI API: {str(e_provider)}\n\nPlease check that you have set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_VERSION environment variables with valid values.",
+ "dashscope": f"\nError with Dashscope API: {str(e_provider)}\n\nPlease check that you have set the DASHSCOPE_API_KEY (and optionally DASHSCOPE_WORKSPACE_ID) environment variables with valid values."
+ }
+
+ error_msg = error_messages.get(request.provider, f"\nError with {request.provider} API: {str(e_provider)}")
+ yield error_msg
except Exception as e_outer:
logger.error(f"Error in streaming response: {str(e_outer)}")
@@ -580,154 +386,24 @@ async def response_stream():
if request.provider == "ollama":
simplified_prompt += " /no_think"
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
-
- # Get the response using the simplified prompt
- fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)
-
- # Handle streaming fallback_response from Ollama
- async for chunk in fallback_response:
- text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
- if text and not text.startswith('model=') and not text.startswith('created_at='):
- text = text.replace('<think>', '').replace('</think>', '')
- yield text
- elif request.provider == "openrouter":
- try:
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
-
- # Get the response using the simplified prompt
- logger.info("Making fallback OpenRouter API call")
- fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)
-
- # Handle streaming fallback_response from OpenRouter
- async for chunk in fallback_response:
- yield chunk
- except Exception as e_fallback:
- logger.error(f"Error with OpenRouter API fallback: {str(e_fallback)}")
- yield f"\nError with OpenRouter API fallback: {str(e_fallback)}\n\nPlease check that you have set the OPENROUTER_API_KEY environment variable with a valid API key."
- elif request.provider == "openai":
- try:
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
-
- # Get the response using the simplified prompt
- logger.info("Making fallback Openai API call")
- fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)
-
- # Handle streaming fallback_response from Openai
- async for chunk in fallback_response:
- text = chunk if isinstance(chunk, str) else getattr(chunk, 'text', str(chunk))
- yield text
- except Exception as e_fallback:
- logger.error(f"Error with Openai API fallback: {str(e_fallback)}")
- yield f"\nError with Openai API fallback: {str(e_fallback)}\n\nPlease check that you have set the OPENAI_API_KEY environment variable with a valid API key."
- elif request.provider == "bedrock":
- try:
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
-
- # Get the response using the simplified prompt
- logger.info("Making fallback AWS Bedrock API call")
- fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)
-
- # Handle response from Bedrock
- if isinstance(fallback_response, str):
- yield fallback_response
- else:
- # Try to extract text from the response
- yield str(fallback_response)
- except Exception as e_fallback:
- logger.error(f"Error with AWS Bedrock API fallback: {str(e_fallback)}")
- yield f"\nError with AWS Bedrock API fallback: {str(e_fallback)}\n\nPlease check that you have set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables with valid credentials."
- elif request.provider == "azure":
- try:
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
-
- # Get the response using the simplified prompt
- logger.info("Making fallback Azure AI API call")
- fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)
-
- # Handle streaming fallback response from Azure AI
- async for chunk in fallback_response:
- choices = getattr(chunk, "choices", [])
- if len(choices) > 0:
- delta = getattr(choices[0], "delta", None)
- if delta is not None:
- text = getattr(delta, "content", None)
- if text is not None:
- yield text
- except Exception as e_fallback:
- logger.error(f"Error with Azure AI API fallback: {str(e_fallback)}")
- yield f"\nError with Azure AI API fallback: {str(e_fallback)}\n\nPlease check that you have set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_VERSION environment variables with valid values."
- elif request.provider == "dashscope":
+ # Use LLMService for fallback with simplified prompt
try:
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM,
- )
-
- logger.info("Making fallback Dashscope API call")
- fallback_response = await model.acall(
- api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM
- )
-
- # DashscopeClient.acall (stream=True) returns an async
- # generator of text chunks
- async for text in fallback_response:
- if text:
- yield text
+ async for chunk in llm_service.async_invoke_stream(
+ prompt=simplified_prompt,
+ provider=request.provider,
+ model=request.model
+ ):
+ # Post-process for Ollama
+ if request.provider == "ollama":
+ chunk = chunk.replace('<think>', '').replace('</think>', '')
+ if not chunk.startswith('model=') and not chunk.startswith('created_at='):
+ yield chunk
+ else:
+ yield chunk
except Exception as e_fallback:
- logger.error(
- f"Error with Dashscope API fallback: {str(e_fallback)}"
- )
- yield (
- f"\nError with Dashscope API fallback: {str(e_fallback)}\n\n"
- "Please check that you have set the DASHSCOPE_API_KEY (and optionally "
- "DASHSCOPE_WORKSPACE_ID) environment variables with valid values."
- )
- else:
- # Google Generative AI fallback (default provider)
- model_config = get_model_config(request.provider, request.model)
- fallback_model = genai.GenerativeModel(
- model_name=model_config["model_kwargs"]["model"],
- generation_config={
- "temperature": model_config["model_kwargs"].get("temperature", 0.7),
- "top_p": model_config["model_kwargs"].get("top_p", 0.8),
- "top_k": model_config["model_kwargs"].get("top_k", 40),
- },
- )
-
- fallback_response = fallback_model.generate_content(
- simplified_prompt, stream=True
- )
- for chunk in fallback_response:
- if hasattr(chunk, "text"):
- yield chunk.text
+ logger.error(f"Error in fallback with LLMService: {str(e_fallback)}")
+ yield f"\nError in fallback: {str(e_fallback)}"
+
except Exception as e2:
logger.error(f"Error in fallback streaming response: {str(e2)}")
yield f"\nI apologize, but your request is too large for me to process. Please try a shorter query or break it into smaller parts."
diff --git a/api/websocket_wiki.py b/api/websocket_wiki.py
index 5bd0c9ff2..36b94938d 100644
--- a/api/websocket_wiki.py
+++ b/api/websocket_wiki.py
@@ -3,27 +3,13 @@
from typing import List, Optional, Dict, Any
from urllib.parse import unquote
-import google.generativeai as genai
-from adalflow.components.model_client.ollama_client import OllamaClient
-from adalflow.core.types import ModelType
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from pydantic import BaseModel, Field
-from api.config import (
- get_model_config,
- configs,
- OPENROUTER_API_KEY,
- OPENAI_API_KEY,
- AWS_ACCESS_KEY_ID,
- AWS_SECRET_ACCESS_KEY,
-)
+from api.config import configs
from api.data_pipeline import count_tokens, get_file_content
-from api.bedrock_client import BedrockClient
-from api.openai_client import OpenAIClient
-from api.openrouter_client import OpenRouterClient
-from api.azureai_client import AzureAIClient
-from api.dashscope_client import DashscopeClient
from api.rag import RAG
+from api.llm import LLMService
# Configure logging
from api.logging_config import setup_logging
@@ -437,261 +423,50 @@ async def handle_websocket_chat(websocket: WebSocket):
prompt += f"<query>\n{query}\n</query>\n\nAssistant: "
- model_config = get_model_config(request.provider, request.model)["model_kwargs"]
-
+ # Add /no_think suffix for Ollama
if request.provider == "ollama":
prompt += " /no_think"
- model = OllamaClient()
- model_kwargs = {
- "model": model_config["model"],
- "stream": True,
- "options": {
- "temperature": model_config["temperature"],
- "top_p": model_config["top_p"],
- "num_ctx": model_config["num_ctx"]
- }
- }
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "openrouter":
- logger.info(f"Using OpenRouter with model: {request.model}")
-
- # Check if OpenRouter API key is set
- if not OPENROUTER_API_KEY:
- logger.warning("OPENROUTER_API_KEY not configured, but continuing with request")
- # We'll let the OpenRouterClient handle this and return a friendly error message
-
- model = OpenRouterClient()
- model_kwargs = {
- "model": request.model,
- "stream": True,
- "temperature": model_config["temperature"]
- }
- # Only add top_p if it exists in the model config
- if "top_p" in model_config:
- model_kwargs["top_p"] = model_config["top_p"]
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "openai":
- logger.info(f"Using Openai protocol with model: {request.model}")
-
- # Check if an API key is set for Openai
- if not OPENAI_API_KEY:
- logger.warning("OPENAI_API_KEY not configured, but continuing with request")
- # We'll let the OpenAIClient handle this and return an error message
-
- # Initialize Openai client
- model = OpenAIClient()
- model_kwargs = {
- "model": request.model,
- "stream": True,
- "temperature": model_config["temperature"]
- }
- # Only add top_p if it exists in the model config
- if "top_p" in model_config:
- model_kwargs["top_p"] = model_config["top_p"]
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "bedrock":
- logger.info(f"Using AWS Bedrock with model: {request.model}")
-
- if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
- logger.warning(
- "AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY not configured, but continuing with request")
-
- model = BedrockClient()
- model_kwargs = {
- "model": request.model,
- }
-
- for key in ["temperature", "top_p"]:
- if key in model_config:
- model_kwargs[key] = model_config[key]
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "azure":
- logger.info(f"Using Azure AI with model: {request.model}")
-
- # Initialize Azure AI client
- model = AzureAIClient()
- model_kwargs = {
- "model": request.model,
- "stream": True,
- "temperature": model_config["temperature"],
- "top_p": model_config["top_p"]
- }
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- elif request.provider == "dashscope":
- logger.info(f"Using Dashscope with model: {request.model}")
-
- # Initialize Dashscope client
- model = DashscopeClient()
- model_kwargs = {
- "model": request.model,
- "stream": True,
- "temperature": model_config["temperature"],
- "top_p": model_config["top_p"]
- }
-
- api_kwargs = model.convert_inputs_to_api_kwargs(
- input=prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
- else:
- # Initialize Google Generative AI model
- model = genai.GenerativeModel(
- model_name=model_config["model"],
- generation_config={
- "temperature": model_config["temperature"],
- "top_p": model_config["top_p"],
- "top_k": model_config["top_k"]
- }
- )
+ # Initialize LLM service
+ llm_service = LLMService(default_provider=request.provider)
# Process the response based on the provider
try:
- if request.provider == "ollama":
- # Get the response and handle it properly using the previously created api_kwargs
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from Ollama
- async for chunk in response:
- text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
- if text and not text.startswith('model=') and not text.startswith('created_at='):
- text = text.replace('<think>', '').replace('</think>', '')
- await websocket.send_text(text)
+ # Use LLMService for unified streaming across all providers
+ logger.info(f"Using LLMService for provider: {request.provider}")
+
+ try:
+ async for chunk in llm_service.async_invoke_stream(
+ prompt=prompt,
+ provider=request.provider,
+ model=request.model
+ ):
+ # Post-process for Ollama to remove thinking tags
+ if request.provider == "ollama":
+ chunk = chunk.replace('<think>', '').replace('</think>', '')
+ if not chunk.startswith('model=') and not chunk.startswith('created_at='):
+ await websocket.send_text(chunk)
+ else:
+ await websocket.send_text(chunk)
+
# Explicitly close the WebSocket connection after the response is complete
await websocket.close()
- elif request.provider == "openrouter":
- try:
- # Get the response and handle it properly using the previously created api_kwargs
- logger.info("Making OpenRouter API call")
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from OpenRouter
- async for chunk in response:
- await websocket.send_text(chunk)
- # Explicitly close the WebSocket connection after the response is complete
- await websocket.close()
- except Exception as e_openrouter:
- logger.error(f"Error with OpenRouter API: {str(e_openrouter)}")
- error_msg = f"\nError with OpenRouter API: {str(e_openrouter)}\n\nPlease check that you have set the OPENROUTER_API_KEY environment variable with a valid API key."
- await websocket.send_text(error_msg)
- # Close the WebSocket connection after sending the error message
- await websocket.close()
- elif request.provider == "openai":
- try:
- # Get the response and handle it properly using the previously created api_kwargs
- logger.info("Making Openai API call")
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from Openai
- async for chunk in response:
- choices = getattr(chunk, "choices", [])
- if len(choices) > 0:
- delta = getattr(choices[0], "delta", None)
- if delta is not None:
- text = getattr(delta, "content", None)
- if text is not None:
- await websocket.send_text(text)
- # Explicitly close the WebSocket connection after the response is complete
- await websocket.close()
- except Exception as e_openai:
- logger.error(f"Error with Openai API: {str(e_openai)}")
- error_msg = f"\nError with Openai API: {str(e_openai)}\n\nPlease check that you have set the OPENAI_API_KEY environment variable with a valid API key."
- await websocket.send_text(error_msg)
- # Close the WebSocket connection after sending the error message
- await websocket.close()
- elif request.provider == "bedrock":
- try:
- logger.info("Making AWS Bedrock API call")
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- if isinstance(response, str):
- await websocket.send_text(response)
- else:
- await websocket.send_text(str(response))
- await websocket.close()
- except Exception as e_bedrock:
- logger.error(f"Error with AWS Bedrock API: {str(e_bedrock)}")
- error_msg = (
- f"\nError with AWS Bedrock API: {str(e_bedrock)}\n\n"
- "Please check that you have set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY "
- "environment variables with valid credentials."
- )
- await websocket.send_text(error_msg)
- await websocket.close()
- elif request.provider == "azure":
- try:
- # Get the response and handle it properly using the previously created api_kwargs
- logger.info("Making Azure AI API call")
- response = await model.acall(api_kwargs=api_kwargs, model_type=ModelType.LLM)
- # Handle streaming response from Azure AI
- async for chunk in response:
- choices = getattr(chunk, "choices", [])
- if len(choices) > 0:
- delta = getattr(choices[0], "delta", None)
- if delta is not None:
- text = getattr(delta, "content", None)
- if text is not None:
- await websocket.send_text(text)
- # Explicitly close the WebSocket connection after the response is complete
- await websocket.close()
- except Exception as e_azure:
- logger.error(f"Error with Azure AI API: {str(e_azure)}")
- error_msg = f"\nError with Azure AI API: {str(e_azure)}\n\nPlease check that you have set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_VERSION environment variables with valid values."
- await websocket.send_text(error_msg)
- # Close the WebSocket connection after sending the error message
- await websocket.close()
- elif request.provider == "dashscope":
- try:
- # Get the response and handle it properly using the previously created api_kwargs
- logger.info("Making Dashscope API call")
- response = await model.acall(
- api_kwargs=api_kwargs, model_type=ModelType.LLM
- )
- # DashscopeClient.acall with stream=True returns an async
- # generator of plain text chunks
- async for text in response:
- if text:
- await websocket.send_text(text)
- # Explicitly close the WebSocket connection after the response is complete
- await websocket.close()
- except Exception as e_dashscope:
- logger.error(f"Error with Dashscope API: {str(e_dashscope)}")
- error_msg = (
- f"\nError with Dashscope API: {str(e_dashscope)}\n\n"
- "Please check that you have set the DASHSCOPE_API_KEY (and optionally "
- "DASHSCOPE_WORKSPACE_ID) environment variables with valid values."
- )
- await websocket.send_text(error_msg)
- # Close the WebSocket connection after sending the error message
- await websocket.close()
- else:
- # Google Generative AI (default provider)
- response = model.generate_content(prompt, stream=True)
- for chunk in response:
- if hasattr(chunk, 'text'):
- await websocket.send_text(chunk.text)
+
+ except Exception as e_provider:
+ logger.error(f"Error with {request.provider} API: {str(e_provider)}")
+
+ # Provider-specific error messages
+ error_messages = {
+ "openrouter": f"\nError with OpenRouter API: {str(e_provider)}\n\nPlease check that you have set the OPENROUTER_API_KEY environment variable with a valid API key.",
+ "openai": f"\nError with OpenAI API: {str(e_provider)}\n\nPlease check that you have set the OPENAI_API_KEY environment variable with a valid API key.",
+ "bedrock": f"\nError with AWS Bedrock API: {str(e_provider)}\n\nPlease check that you have set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables with valid credentials.",
+ "azure": f"\nError with Azure AI API: {str(e_provider)}\n\nPlease check that you have set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_VERSION environment variables with valid values.",
+ "dashscope": f"\nError with Dashscope API: {str(e_provider)}\n\nPlease check that you have set the DASHSCOPE_API_KEY (and optionally DASHSCOPE_WORKSPACE_ID) environment variables with valid values."
+ }
+
+ error_msg = error_messages.get(request.provider, f"\nError with {request.provider} API: {str(e_provider)}")
+ await websocket.send_text(error_msg)
+ # Close the WebSocket connection after sending the error message
await websocket.close()
except Exception as e_outer:
@@ -718,163 +493,23 @@ async def handle_websocket_chat(websocket: WebSocket):
if request.provider == "ollama":
simplified_prompt += " /no_think"
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
-
- # Get the response using the simplified prompt
- fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)
-
- # Handle streaming fallback_response from Ollama
- async for chunk in fallback_response:
- text = getattr(chunk, 'response', None) or getattr(chunk, 'text', None) or str(chunk)
- if text and not text.startswith('model=') and not text.startswith('created_at='):
- text = text.replace('<think>', '').replace('</think>', '')
- await websocket.send_text(text)
- elif request.provider == "openrouter":
- try:
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
-
- # Get the response using the simplified prompt
- logger.info("Making fallback OpenRouter API call")
- fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)
-
- # Handle streaming fallback_response from OpenRouter
- async for chunk in fallback_response:
- await websocket.send_text(chunk)
- except Exception as e_fallback:
- logger.error(f"Error with OpenRouter API fallback: {str(e_fallback)}")
- error_msg = f"\nError with OpenRouter API fallback: {str(e_fallback)}\n\nPlease check that you have set the OPENROUTER_API_KEY environment variable with a valid API key."
- await websocket.send_text(error_msg)
- elif request.provider == "openai":
- try:
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
-
- # Get the response using the simplified prompt
- logger.info("Making fallback Openai API call")
- fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)
-
- # Handle streaming fallback_response from Openai
- async for chunk in fallback_response:
- text = chunk if isinstance(chunk, str) else getattr(chunk, 'text', str(chunk))
- await websocket.send_text(text)
- except Exception as e_fallback:
- logger.error(f"Error with Openai API fallback: {str(e_fallback)}")
- error_msg = f"\nError with Openai API fallback: {str(e_fallback)}\n\nPlease check that you have set the OPENAI_API_KEY environment variable with a valid API key."
- await websocket.send_text(error_msg)
- elif request.provider == "bedrock":
- try:
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM,
- )
-
- logger.info("Making fallback AWS Bedrock API call")
- fallback_response = await model.acall(
- api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM
- )
-
- if isinstance(fallback_response, str):
- await websocket.send_text(fallback_response)
+ # Use LLMService for fallback with simplified prompt
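+ # async_invoke_stream is assumed (based on how it is consumed below) to yield
+ # plain text chunks, so no provider-specific choices/delta unpacking is needed.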
+ try:
+ async for chunk in llm_service.async_invoke_stream(
+ prompt=simplified_prompt,
+ provider=request.provider,
+ model=request.model
+ ):
+ # Post-process for Ollama
+ if request.provider == "ollama":
+ chunk = chunk.replace('<think>', '').replace('</think>', '')
+ if not chunk.startswith('model=') and not chunk.startswith('created_at='):
+ await websocket.send_text(chunk)
else:
- await websocket.send_text(str(fallback_response))
- except Exception as e_fallback:
- logger.error(
- f"Error with AWS Bedrock API fallback: {str(e_fallback)}"
- )
- error_msg = (
- f"\nError with AWS Bedrock API fallback: {str(e_fallback)}\n\n"
- "Please check that you have set the AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY "
- "environment variables with valid credentials."
- )
- await websocket.send_text(error_msg)
- elif request.provider == "azure":
- try:
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM
- )
-
- # Get the response using the simplified prompt
- logger.info("Making fallback Azure AI API call")
- fallback_response = await model.acall(api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM)
-
- # Handle streaming fallback response from Azure AI
- async for chunk in fallback_response:
- choices = getattr(chunk, "choices", [])
- if len(choices) > 0:
- delta = getattr(choices[0], "delta", None)
- if delta is not None:
- text = getattr(delta, "content", None)
- if text is not None:
- await websocket.send_text(text)
- except Exception as e_fallback:
- logger.error(f"Error with Azure AI API fallback: {str(e_fallback)}")
- error_msg = f"\nError with Azure AI API fallback: {str(e_fallback)}\n\nPlease check that you have set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_VERSION environment variables with valid values."
- await websocket.send_text(error_msg)
- elif request.provider == "dashscope":
- try:
- # Create new api_kwargs with the simplified prompt
- fallback_api_kwargs = model.convert_inputs_to_api_kwargs(
- input=simplified_prompt,
- model_kwargs=model_kwargs,
- model_type=ModelType.LLM,
- )
-
- logger.info("Making fallback Dashscope API call")
- fallback_response = await model.acall(
- api_kwargs=fallback_api_kwargs, model_type=ModelType.LLM
- )
-
- # DashscopeClient.acall (stream=True) returns an async
- # generator of text chunks
- async for text in fallback_response:
- if text:
- await websocket.send_text(text)
- except Exception as e_fallback:
- logger.error(
- f"Error with Dashscope API fallback: {str(e_fallback)}"
- )
- error_msg = (
- f"\nError with Dashscope API fallback: {str(e_fallback)}\n\n"
- "Please check that you have set the DASHSCOPE_API_KEY (and optionally "
- "DASHSCOPE_WORKSPACE_ID) environment variables with valid values."
- )
- await websocket.send_text(error_msg)
- else:
- # Google Generative AI fallback (default provider)
- model_config = get_model_config(request.provider, request.model)
- fallback_model = genai.GenerativeModel(
- model_name=model_config["model_kwargs"]["model"],
- generation_config={
- "temperature": model_config["model_kwargs"].get("temperature", 0.7),
- "top_p": model_config["model_kwargs"].get("top_p", 0.8),
- "top_k": model_config["model_kwargs"].get("top_k", 40),
- },
- )
-
- fallback_response = fallback_model.generate_content(
- simplified_prompt, stream=True
- )
- for chunk in fallback_response:
- if hasattr(chunk, "text"):
- await websocket.send_text(chunk.text)
+ await websocket.send_text(chunk)
+ except Exception as e_fallback:
+ logger.error(f"Error in fallback with LLMService: {str(e_fallback)}")
+ await websocket.send_text(f"\nError in fallback: {str(e_fallback)}")
except Exception as e2:
logger.error(f"Error in fallback streaming response: {str(e2)}")
await websocket.send_text(f"\nI apologize, but your request is too large for me to process. Please try a shorter query or break it into smaller parts.")
diff --git a/tests/unit/test_balance_loading.py b/tests/unit/test_balance_loading.py
new file mode 100755
index 000000000..d27213a49
--- /dev/null
+++ b/tests/unit/test_balance_loading.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python3
+"""
+Test configuration loading logic, in particular API key placeholder replacement and comma-separated key handling
+"""
+
+import os
+import sys
+from pathlib import Path
+
+# Add project root to Python path
+# __file__ is in tests/unit/, so we need to go up 2 levels to reach project root
+project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+print(f"project_root: {project_root}, __file__: {__file__}")
+sys.path.insert(0, project_root)
+
+# Load environment variables from .env file
+from dotenv import load_dotenv
+
+env_path = Path(project_root) / '.env'
+if env_path.exists():
+ load_dotenv(dotenv_path=env_path)
+ print(f"✅ Loaded environment variables from: {env_path}")
+else:
+ print(f"⚠️ Warning: .env file not found at {env_path}")
+ print(" Please create a .env file with OPENAI_API_KEYS and OPENAI_BASE_URL")
+ sys.exit(1)
+
+from api.config import load_api_keys_config, configs
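+
+# Hedged illustration (sample values are made up, not read from .env): the
+# comma-separated key handling mentioned in the docstring above, reproduced
+# with plain string operations so the expected expansion is easy to see.
+_sample_entry = "sk-aaa, sk-bbb ,sk-ccc"
+_expanded_keys = [k.strip() for k in _sample_entry.split(',') if k.strip()]
+assert _expanded_keys == ["sk-aaa", "sk-bbb", "sk-ccc"]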
+
+print("=" * 60)
+print("Test configuration loading logic")
+print("=" * 60)
+
+# Verify required environment variables are loaded
+required_env_vars = ['OPENAI_API_KEYS', 'OPENAI_BASE_URL']
+missing_vars = [var for var in required_env_vars if not os.environ.get(var)]
+
+if missing_vars:
+ print(f"\n❌ Error: Missing required environment variables: {', '.join(missing_vars)}")
+ print(" Please add them to your .env file")
+ sys.exit(1)
+
+print(f"\n✅ Environment variables loaded:")
+print(f" OPENAI_API_KEYS: {len(os.environ.get('OPENAI_API_KEYS', '').split(','))} key(s)")
+print(f" OPENAI_BASE_URL: {os.environ.get('OPENAI_BASE_URL')}")
+
+# Test 1: Directly load API keys configuration
+print("\n1️⃣ Test load_api_keys_config():")
+print("-" * 60)
+api_keys = load_api_keys_config()
+
+print(f"\nLoaded providers: {list(api_keys.keys())}")
+
+for provider, keys in api_keys.items():
+ if keys:
+ print(f"\n{provider}:")
+ print(f" Total {len(keys)} key(s)")
+ for i, key in enumerate(keys, 1):
+ masked_key = f"{key[:8]}...{key[-4:]}" if len(key) > 12 else key
+ print(f" [{i}] {masked_key}")
+
+# Test 2: Check api_keys in configs
+print("\n" + "=" * 60)
+print("2️⃣ Test configs['api_keys']:")
+print("-" * 60)
+
+if 'api_keys' in configs:
+ openai_keys = configs['api_keys'].get('openai', [])
+ print(f"\nOpenAI keys count: {len(openai_keys)}")
+
+ if len(openai_keys) > 4:
+ print("✅ Success! Correctly loaded at least 5 OpenAI API keys")
+ else:
+ print(f"❌ Failure! Expected at least 5 keys, actual {len(openai_keys)} keys")
+else:
+ print("❌ 'api_keys' not found in configs")
+
+# Test 3: Verify LLMService can correctly use these keys
+print("\n" + "=" * 60)
+print("3️⃣ Test LLMService loading:")
+print("-" * 60)
+
+try:
+ from api.llm import LLMService
+
+ llm_service = LLMService(default_provider="openai")
+ status = llm_service.get_api_keys_status("openai")
+
+ print(f"\nOpenAI keys count in LLMService: {status['total_keys']}")
+ print(f"API Keys (masked):")
+ for key in status['api_keys']:
+ print(f" - {key}")
+
+ if status['total_keys'] > 4:
+ print("\n✅ Success! LLMService correctly loaded at least 5 OpenAI API keys")
+ print("✅ Load balancing functionality is ready")
+ else:
+ print(f"\n❌ Failure! Expected at least 5 keys, actual {status['total_keys']} keys")
+
+except Exception as e:
+ print(f"❌ Error loading LLMService: {e}")
+ import traceback
+ traceback.print_exc()
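+
+# For reference, a hedged sketch of the status dict shape this script relies on
+# (fields inferred from the accesses above and below, not a documented contract):
+#   {"total_keys": 5,
+#    "api_keys": ["sk-aaa...xxxx", ...],            # masked keys
+#    "key_usage_count": {"sk-aaa...xxxx": 2, ...}}  # per-key call counts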
+
+# Test 4: Test real LLM calls and load balancing
+print("\n" + "=" * 60)
+print("4️⃣ Test LLM Service Call & Load Balancing:")
+print("-" * 60)
+
+try:
+ print("\n🚀 Sending 10 parallel requests to test load balancing...")
+ print(" Question: 奶油扇贝意大利面的做法")
+ print("")
+
+ # Reset usage stats before testing
+ llm_service.reset_key_usage_stats("openai")
+
+ # Batch invoke same prompt 10 times
+ results = llm_service.batch_invoke_same_prompt(
+ prompt="请简要介绍奶油扇贝意大利面的做法(200字以内)",
+ count=10,
+ provider="openai",
+ model="gpt-4.1", # Use a faster/cheaper model for testing
+ temperature=0.7,
+ max_tokens=200,
+ max_concurrent_per_key=2, # Limit concurrent per key to better see load balancing
+ max_total_concurrent=5
+ )
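+
+ # Each result is assumed (from the checks below) to be a dict carrying either
+ # a "content" field on success or an "error" field on failure.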
+
+ # Check results
+ successful_count = sum(1 for r in results if r and "error" not in r and "content" in r)
+ failed_count = len(results) - successful_count
+
+ print(f"\n📊 Request Results:")
+ print(f" Total: {len(results)}")
+ print(f" ✅ Successful: {successful_count}")
+ print(f" ❌ Failed: {failed_count}")
+
+ # Get and display API key usage statistics
+ print(f"\n📈 API Key Usage Statistics (Load Balancing):")
+ print("-" * 60)
+ status = llm_service.get_api_keys_status("openai")
+
+ usage_data = status.get("key_usage_count", {})
+ total_usage = sum(usage_data.values())
+
+ # Sort by usage count
+ sorted_usage = sorted(usage_data.items(), key=lambda x: x[1], reverse=True)
+
+ for i, (key_masked, count) in enumerate(sorted_usage, 1):
+ percentage = (count / total_usage * 100) if total_usage > 0 else 0
+ bar = "█" * int(count * 2) # Visual bar
+ print(f" Key {i}: {key_masked}")
+ print(f" 使用次数: {count:2d} ({percentage:5.1f}%) {bar}")
+
+ print(f"\n Total API Calls: {total_usage}")
+
+ # Check if load balancing is working
+ if total_usage > 0:
+ max_usage = max(usage_data.values())
+ min_usage = min(usage_data.values())
+ balance_ratio = min_usage / max_usage if max_usage > 0 else 0
+
+ print(f" Max usage per key: {max_usage}")
+ print(f" Min usage per key: {min_usage}")
+ print(f" Balance ratio: {balance_ratio:.2f} (1.0 = perfect balance)")
+
+ if balance_ratio >= 0.5:
+ print(" ✅ Load balancing is working well!")
+ else:
+ print(" ⚠️ Load balancing could be improved")
+
+ # Display formatted responses
+ print(f"\n📝 Response Content:")
+ print("=" * 60)
+
+ for i, result in enumerate(results, 1):
+ print(f"\n【Response {i}】")
+ print("-" * 60)
+
+ if result and "error" not in result:
+ response_text = result.get("content", "")
+ if response_text:
+ # Clean up response text
+ response_text = response_text.strip()
+ print(f"{response_text}")
+ else:
+ print("⚠️ Empty response")
+ else:
+ error_msg = result.get("error", "Unknown error") if result else "No result"
+ print(f"❌ Error: {error_msg}")
+
+ print("\n" + "=" * 60)
+
+ if successful_count == 10:
+ print("✅ Test 4 PASSED: All 10 requests successful!")
+ elif successful_count > 0:
+ print(f"⚠️ Test 4 PARTIAL: {successful_count}/10 requests successful")
+ else:
+ print("❌ Test 4 FAILED: No successful requests")
+
+except Exception as e:
+ print(f"❌ Error during LLM service call test: {e}")
+ import traceback
+ traceback.print_exc()
+
+print("\n" + "=" * 60)
+print("All Tests Completed")
+print("=" * 60)