diff --git a/api/config.py b/api/config.py index 49dfcf7b0..c09302a0e 100644 --- a/api/config.py +++ b/api/config.py @@ -152,7 +152,7 @@ def load_embedder_config(): embedder_config = load_json_config("embedder.json") # Process client classes - for key in ["embedder", "embedder_ollama", "embedder_google", "embedder_bedrock"]: + for key in ["embedder", "embedder_ollama", "embedder_google", "embedder_bedrock", "embedder_azure"]: if key in embedder_config and "client_class" in embedder_config[key]: class_name = embedder_config[key]["client_class"] if class_name in CLIENT_CLASSES: @@ -174,6 +174,8 @@ def get_embedder_config(): return configs.get("embedder_google", {}) elif embedder_type == 'ollama' and 'embedder_ollama' in configs: return configs.get("embedder_ollama", {}) + elif embedder_type == 'azure' and 'embedder_azure' in configs: + return configs.get("embedder_azure", {}) else: return configs.get("embedder", {}) @@ -235,21 +237,41 @@ def is_bedrock_embedder(): client_class = embedder_config.get("client_class", "") return client_class == "BedrockClient" +def is_azure_embedder(): + """ + Check if the current embedder configuration uses AzureAIClient. + + Returns: + bool: True if using AzureAIClient, False otherwise + """ + embedder_config = get_embedder_config() + if not embedder_config: + return False + + model_client = embedder_config.get("model_client") + if model_client: + return model_client.__name__ == "AzureAIClient" + + client_class = embedder_config.get("client_class", "") + return client_class == "AzureAIClient" + def get_embedder_type(): """ Get the current embedder type based on configuration. - + Returns: - str: 'bedrock', 'ollama', 'google', or 'openai' (default) + str: 'bedrock', 'ollama', 'google', 'azure', or 'openai' (default) """ - if is_bedrock_embedder(): - return 'bedrock' - elif is_ollama_embedder(): - return 'ollama' - elif is_google_embedder(): - return 'google' - else: - return 'openai' + embedder_checks = { + 'bedrock': is_bedrock_embedder, + 'ollama': is_ollama_embedder, + 'google': is_google_embedder, + 'azure': is_azure_embedder, + } + for embedder_type, check_func in embedder_checks.items(): + if check_func(): + return embedder_type + return 'openai' # Load repository and file filters configuration def load_repo_config(): @@ -341,7 +363,7 @@ def load_lang_config(): # Update embedder configuration if embedder_config: - for key in ["embedder", "embedder_ollama", "embedder_google", "embedder_bedrock", "retriever", "text_splitter"]: + for key in ["embedder", "embedder_ollama", "embedder_google", "embedder_bedrock", "embedder_azure", "retriever", "text_splitter"]: if key in embedder_config: configs[key] = embedder_config[key] diff --git a/api/config/embedder.json b/api/config/embedder.json index 0101ac083..c8ce231a7 100644 --- a/api/config/embedder.json +++ b/api/config/embedder.json @@ -30,6 +30,14 @@ "dimensions": 256 } }, + "embedder_azure": { + "client_class": "AzureAIClient", + "batch_size": 100, + "model_kwargs": { + "model": "text-embedding-3-small", + "dimensions": 256 + } + }, "retriever": { "top_k": 20 }, diff --git a/api/main.py b/api/main.py index fe083f550..cc9af3de4 100644 --- a/api/main.py +++ b/api/main.py @@ -43,21 +43,17 @@ def patched_watch(*args, **kwargs): import uvicorn -# Check for required environment variables -required_env_vars = ['GOOGLE_API_KEY', 'OPENAI_API_KEY'] -missing_vars = [var for var in required_env_vars if not os.environ.get(var)] -if missing_vars: - logger.warning(f"Missing environment variables: {', '.join(missing_vars)}") - logger.warning("Some functionality may not work correctly without these variables.") - -# Configure Google Generative AI +# Configure providers based on settings +from api.config import configs import google.generativeai as genai -from api.config import GOOGLE_API_KEY -if GOOGLE_API_KEY: - genai.configure(api_key=GOOGLE_API_KEY) -else: - logger.warning("GOOGLE_API_KEY not configured") +# Only configure Google if it's being used as a provider +if configs.get("default_provider") == "google": + from api.config import GOOGLE_API_KEY + if GOOGLE_API_KEY: + genai.configure(api_key=GOOGLE_API_KEY) + else: + logger.warning("GOOGLE_API_KEY not configured but Google is the default provider") if __name__ == "__main__": # Get port from environment variable or use default diff --git a/api/tools/embedder.py b/api/tools/embedder.py index 050d63547..9c0b8c030 100644 --- a/api/tools/embedder.py +++ b/api/tools/embedder.py @@ -22,6 +22,8 @@ def get_embedder(is_local_ollama: bool = False, use_google_embedder: bool = Fals embedder_config = configs["embedder_google"] elif embedder_type == 'bedrock': embedder_config = configs["embedder_bedrock"] + elif embedder_type == 'azure': + embedder_config = configs["embedder_azure"] else: # default to openai embedder_config = configs["embedder"] elif is_local_ollama: @@ -37,6 +39,8 @@ def get_embedder(is_local_ollama: bool = False, use_google_embedder: bool = Fals embedder_config = configs["embedder_ollama"] elif current_type == 'google': embedder_config = configs["embedder_google"] + elif current_type == 'azure': + embedder_config = configs["embedder_azure"] else: embedder_config = configs["embedder"] diff --git a/api/websocket_wiki.py b/api/websocket_wiki.py index 5bd0c9ff2..087210efd 100644 --- a/api/websocket_wiki.py +++ b/api/websocket_wiki.py @@ -79,8 +79,8 @@ async def handle_websocket_chat(websocket: WebSocket): if hasattr(last_message, 'content') and last_message.content: tokens = count_tokens(last_message.content, request.provider == "ollama") logger.info(f"Request size: {tokens} tokens") - if tokens > 8000: - logger.warning(f"Request exceeds recommended token limit ({tokens} > 7500)") + if tokens > 9000: + logger.warning(f"Request exceeds recommended token limit ({tokens} > 9000)") input_too_large = True # Create a new RAG instance for this request