From 550bcf725ce7e4c191364c403142c3f20a1bd098 Mon Sep 17 00:00:00 2001
From: Paul Shippy
Date: Sat, 4 Apr 2026 06:12:40 +0000
Subject: [PATCH] Try to get web search working

---
 back/bots/models/chat.py | 104 +++++++++++++++++++++++++++++----------
 back/test_integration.py |   9 ++--
 2 files changed, 83 insertions(+), 30 deletions(-)

diff --git a/back/bots/models/chat.py b/back/bots/models/chat.py
index be1a18b..a3eddfb 100644
--- a/back/bots/models/chat.py
+++ b/back/bots/models/chat.py
@@ -2,10 +2,9 @@
 from django.db import models
 import uuid
 from langchain_aws import ChatBedrock
-from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
+from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
 from langchain_core.tools import tool
 from langchain_core.callbacks.base import BaseCallbackHandler
-from langchain.agents import create_agent
 from tavily import TavilyClient
 import logging
 import base64
@@ -116,18 +115,12 @@ def web_search(query: str) -> str:
                 return f"Error during search: {str(e)}"
 
-        # Create chat model
+        # Create chat model with tool binding
         chat_model = ChatBedrock(model_id=self.ai.model_id)
         tools = [web_search]
 
-        # Create modern agent with tool calling support
-        # This is the recommended approach per LangChain docs
-        agent = create_agent(
-            model=chat_model,
-            tools=tools,
-            system_prompt=self.get_system_message(),
-            debug=settings.DEBUG
-        )
+        # Bind tools to the model for proper tool calling
+        model_with_tools = chat_model.bind_tools(tools)
 
         # Extract text input from message_list for agent
         agent_input = self._extract_agent_input(message_list)
 
@@ -135,28 +128,87 @@
         logger.info(f"Invoking agent with input: {agent_input[:100]}...")
         logger.info("🤖 AGENT_INVOKE_START: web_search tool available")
 
-        # Invoke agent - the CompiledStateGraph handles tool loop internally
-        response = agent.invoke({"messages": [HumanMessage(content=agent_input)]})
+        # Build and run the agent loop manually
+        messages = [SystemMessage(content=self.get_system_message()), HumanMessage(content=agent_input)]
+
+        # Agent loop - keep invoking until no more tool calls
+        max_iterations = 5
+        iteration = 0
+
+        while iteration < max_iterations:
+            iteration += 1
+            logger.info(f"🤖 AGENT_LOOP_ITERATION: {iteration}")
+
+            # Call the model
+            response = model_with_tools.invoke(messages)
+            messages.append(response)
+
+            # Check if model wants to call a tool
+            if not response.tool_calls:
+                logger.info(f"🤖 AGENT_LOOP_COMPLETE: no more tool calls after {iteration} iterations")
+                break
+
+            # Process tool calls
+            for tool_call in response.tool_calls:
+                tool_name = tool_call["name"]
+                tool_args = tool_call["args"]
+                logger.info(f"🔍 AGENT_TOOL_CALL: {tool_name} with args: {tool_args}")
+
+                # Execute the tool
+                if tool_name == "web_search":
+                    tool_result = web_search.invoke(tool_args)
+                else:
+                    tool_result = f"Unknown tool: {tool_name}"
+
+                logger.info(f"🔍 AGENT_TOOL_RESULT: {tool_result[:100]}")
+
+                # Add tool result to messages
+                messages.append(ToolMessage(
+                    content=tool_result,
+                    tool_call_id=tool_call["id"],
+                    name=tool_name
+                ))
 
-        logger.info("🤖 AGENT_INVOKE_COMPLETE: got response")
+        logger.info("🤖 AGENT_LOOP_COMPLETE: extracting final response")
 
-        # Extract response text from the agent result
-        # The response is a dict with 'messages' key containing final messages
+        # Extract response text from the final message
+        # The messages list now contains: System, HumanMessage, AIMessage (with tool call), ToolMessage, AIMessage (final response)
         response_text = ""
         usage_metadata = {"input_tokens": 0, "output_tokens": 0}
 
-        if isinstance(response, dict) and "messages" in response:
-            for msg in reversed(response["messages"]):
-                if isinstance(msg, AIMessage):
+        # Find the last AIMessage (should be the final response)
+        for msg in reversed(messages):
+            if isinstance(msg, AIMessage) and not msg.tool_calls:
+                # This is the final response (no tool calls)
+                if isinstance(msg.content, str):
                     response_text = msg.content
-                    # Extract token usage from the message metadata
-                    if hasattr(msg, 'usage_metadata') and msg.usage_metadata:
-                        usage_metadata = msg.usage_metadata
+                elif isinstance(msg.content, list):
+                    # Extract text from content list
+                    text_parts = []
+                    for item in msg.content:
+                        if isinstance(item, dict) and item.get('type') == 'text':
+                            text_parts.append(item.get('text', ''))
+                        elif isinstance(item, str):
+                            text_parts.append(item)
+                    response_text = "".join(text_parts).strip()
+
+                # Extract token usage
+                if hasattr(msg, 'usage_metadata') and msg.usage_metadata:
+                    usage_metadata = msg.usage_metadata
+
+                logger.info(f"🤖 FINAL_RESPONSE: {len(response_text)} chars")
+                break
+
+        if not response_text:
+            # Fallback: get the last message
+            for msg in reversed(messages):
+                if isinstance(msg, AIMessage):
+                    if isinstance(msg.content, str):
+                        response_text = msg.content
+                    elif isinstance(msg.content, list):
+                        text_parts = [item.get('text', '') for item in msg.content if isinstance(item, dict) and item.get('type') == 'text']
+                        response_text = "".join(text_parts).strip()
                     break
-        elif isinstance(response, dict) and "output" in response:
-            response_text = response["output"]
-        else:
-            response_text = str(response)
 
         message_order = self.messages.count()
 
diff --git a/back/test_integration.py b/back/test_integration.py
index a98b497..5785c7d 100644
--- a/back/test_integration.py
+++ b/back/test_integration.py
@@ -6,16 +6,17 @@
 import os
 import sys
 import django
-import requests
-from django.contrib.auth.models import User
-from rest_framework_simplejwt.tokens import RefreshToken
-from bots.models import Bot, AiModel
 
 os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'server.settings')
 sys.path.insert(0, '/home/ubuntu/repos/bots/back')
 django.setup()
 
+import requests
+from django.contrib.auth.models import User
+from rest_framework_simplejwt.tokens import RefreshToken
+from bots.models import Bot, AiModel
+
 
 def run_integration_test():
     print("=" * 60)