-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlanggraph_tool_backend.py
More file actions
119 lines (95 loc) · 3.24 KB
/
langgraph_tool_backend.py
File metadata and controls
119 lines (95 loc) · 3.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# langgraph_tool_backend.py - LangGraph chatbot backend: tool-calling LLM with SQLite checkpoints
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Annotated
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import tool
from dotenv import load_dotenv
import sqlite3
import requests
import os
# Load variables from a local .env file into the process environment before
# any of the clients below are constructed.
load_dotenv()
# Set the LangChain project name for this process.
# NOTE(review): this assignment unconditionally overwrites any
# LANGCHAIN_PROJECT value already present in the environment or .env file -
# confirm that is intended.
os.environ["LANGCHAIN_PROJECT"] = "Langgraph Tool Backend"
# -------------------
# 1. LLM
# -------------------
# Chat model used by the graph. It is built with library defaults: no model
# name, temperature, or API key is passed explicitly.
# NOTE(review): presumably the OpenAI credentials come from the environment
# populated by load_dotenv() - confirm.
llm = ChatOpenAI()
# -------------------
# 2. Tools
# -------------------
# Tools
# DuckDuckGo web-search tool, configured with region "us-en".
search_tool = DuckDuckGoSearchRun(region="us-en")
@tool
def calculator(first_num: float, second_num: float, operation: str) -> dict:
    """
    Perform a basic arithmetic operation on two numbers.
    Supported operations: add, sub, mul, div
    """
    # NOTE(review): the docstring is kept word-for-word because @tool
    # presumably exposes it to the model as the tool description - confirm.

    # Dispatch table: operation name -> two-argument arithmetic function.
    handlers = {
        "add": lambda a, b: a + b,
        "sub": lambda a, b: a - b,
        "mul": lambda a, b: a * b,
        "div": lambda a, b: a / b,
    }
    try:
        compute = handlers.get(operation)
        if compute is None:
            return {"error": f"Unsupported operation '{operation}'"}
        # Only division needs the zero guard; every failure is reported as an
        # error dict rather than raised.
        if operation == "div" and second_num == 0:
            return {"error": "Division by zero is not allowed"}
        return {
            "first_num": first_num,
            "second_num": second_num,
            "operation": operation,
            "result": compute(first_num, second_num),
        }
    except Exception as e:
        return {"error": str(e)}
@tool
def get_stock_price(symbol: str) -> dict:
    """
    Fetch the latest stock quote for a given symbol (e.g. 'AAPL', 'TSLA')
    from the Alpha Vantage GLOBAL_QUOTE endpoint.

    Returns the decoded JSON response as a dict, or {"error": "..."} when the
    symbol is empty or the request fails.
    """
    # Prefer a key supplied through the environment / .env file.
    # SECURITY: the literal fallback is the key that used to be hard-coded in
    # the request URL. It is kept only so existing setups keep working; it
    # has been exposed in source control and should be rotated, after which
    # this fallback can be removed.
    api_key = os.environ.get("ALPHAVANTAGE_API_KEY") or "C9PE94QUEW9VWGFM"

    ticker = symbol.strip()
    if not ticker:
        return {"error": "A non-empty stock symbol is required"}

    try:
        # Passing `params` lets requests URL-encode each value, so a symbol
        # containing '&' or '=' cannot inject extra query parameters.
        response = requests.get(
            "https://www.alphavantage.co/query",
            params={
                "function": "GLOBAL_QUOTE",
                "symbol": ticker,
                "apikey": api_key,
            },
            timeout=10,  # without a timeout requests can block indefinitely
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on older requests versions.
        # Only the exception type is reported: requests' error messages embed
        # the full URL, which would leak the API key into the chat history.
        return {"error": f"Stock price request failed: {type(e).__name__}"}
# Every tool made available to the model; the same list is also handed to
# ToolNode further down.
tools = [search_tool, get_stock_price, calculator]
# Model handle with the tools bound via bind_tools; chat_node invokes this
# handle rather than the bare `llm`.
llm_with_tools = llm.bind_tools(tools)
# -------------------
# 3. State
# -------------------
class ChatState(TypedDict):
    """State carried through the graph: the conversation's message list."""

    # Annotated with the add_messages reducer from langgraph.graph.message.
    # NOTE(review): presumably this makes node outputs get merged into the
    # existing list instead of replacing it - confirm against LangGraph docs.
    messages: Annotated[list[BaseMessage], add_messages]
# -------------------
# 4. Nodes
# -------------------
def chat_node(state: ChatState):
    """Run the tool-aware LLM over the conversation so far.

    Returns a state update whose "messages" entry holds the single new model
    reply, which may be a direct answer or a request to call a tool.
    """
    reply = llm_with_tools.invoke(state["messages"])
    return {"messages": [reply]}
# Prebuilt LangGraph node constructed from the same tool list that is bound
# to the model.
tool_node = ToolNode(tools)
# -------------------
# 5. Checkpointer
# -------------------
# SQLite connection to chatbot.db (a path relative to the working directory).
# check_same_thread=False disables sqlite3's same-thread check so the
# connection may be used from threads other than the one that created it.
# NOTE(review): the connection is not closed anywhere in the visible code -
# presumably it is meant to live for the whole process; confirm.
conn = sqlite3.connect(database="chatbot.db", check_same_thread=False)
# Checkpoint saver backed by that connection; it is passed to graph.compile()
# and queried by retrieve_all_threads().
checkpointer = SqliteSaver(conn=conn)
# -------------------
# 6. Graph
# -------------------
# Build the graph over the ChatState schema.
graph = StateGraph(ChatState)
# Two nodes: the LLM step and the tool-execution step.
graph.add_node("chat_node", chat_node)
graph.add_node("tools", tool_node)
# Every run enters at the LLM node.
graph.add_edge(START, "chat_node")
# After the LLM node, the prebuilt tools_condition chooses the next step.
# NOTE(review): per the LangGraph docs it routes to "tools" when the last AI
# message carries tool calls and to END otherwise - confirm for the version
# in use.
graph.add_conditional_edges("chat_node",tools_condition)
# Tool output always flows back to the LLM node.
graph.add_edge('tools', 'chat_node')
# Compile the graph with the SQLite checkpointer attached.
chatbot = graph.compile(checkpointer=checkpointer)
# -------------------
# 7. Helper
# -------------------
def retrieve_all_threads():
    """Return the distinct thread ids found among the stored checkpoints.

    The order of the returned list is unspecified, because the ids are
    de-duplicated through a set.
    """
    # list(None) is called without a config filter.
    thread_ids = {
        entry.config["configurable"]["thread_id"]
        for entry in checkpointer.list(None)
    }
    return list(thread_ids)