-
Notifications
You must be signed in to change notification settings - Fork 934
Expand file tree
/
Copy pathquery_rewriter.py
More file actions
81 lines (76 loc) · 3.4 KB
/
query_rewriter.py
File metadata and controls
81 lines (76 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import json
from openai.types.responses import Response, ResponseFunctionToolCall
def build_search_function() -> dict:
    """Build the tool definition for the ``search_database`` function.

    The returned dict describes one function tool. Its only required argument
    is ``search_query``. Two optional object arguments, ``price_filter`` and
    ``brand_filter``, each carry a ``comparison_operator`` string and a
    ``value``. A fresh dict is built on every call, so callers may mutate it.
    """

    def _filter_schema(description, operator_hint, value_type, value_hint):
        # price_filter and brand_filter share the same {operator, value} shape;
        # only their descriptions and the JSON type of "value" differ.
        return {
            "type": "object",
            "description": description,
            "properties": {
                "comparison_operator": {"type": "string", "description": operator_hint},
                "value": {"type": value_type, "description": value_hint},
            },
        }

    search_query_schema = {
        "type": "string",
        "description": "Query string to use for full text search, e.g. 'red shoes'",
    }
    price_schema = _filter_schema(
        "Filter search results based on price of the product",
        "Operator to compare the column value, either '>', '<', '>=', '<=', '='",
        "number",
        "Value to compare against, e.g. 30",
    )
    brand_schema = _filter_schema(
        "Filter search results based on brand of the product",
        "Operator to compare the column value, either '=' or '!='",
        "string",
        "Value to compare against, e.g. AirStrider",
    )

    parameters = {
        "type": "object",
        "properties": {
            "search_query": search_query_schema,
            "price_filter": price_schema,
            "brand_filter": brand_schema,
        },
        "required": ["search_query"],
    }
    return {
        "type": "function",
        "name": "search_database",
        "description": "Search PostgreSQL database for relevant products based on user query",
        "parameters": parameters,
    }
# Comparison operators the search_database tool schema advertises per column.
# Model output is untrusted: any other operator is dropped instead of being
# forwarded. NOTE(review): presumably these reach SQL construction downstream,
# where an operator cannot be bound as a parameter.
_ALLOWED_OPERATORS = {
    "price": frozenset({">", "<", ">=", "<=", "="}),
    "brand": frozenset({"=", "!="}),
}


def _parse_filter(arg, key, column):
    """Turn ``arg[key]`` into a ``{"column", "comparison_operator", "value"}`` dict.

    Return None when the filter is absent or not an object. Also return None
    when it lacks a string ``comparison_operator`` or a non-null ``value``, or
    when it uses an operator not allowed for ``column``.
    """
    raw = arg.get(key)
    if not isinstance(raw, dict):
        return None
    operator = raw.get("comparison_operator")
    # isinstance check first: an unhashable value (e.g. a list) would make
    # the frozenset membership test raise TypeError.
    if not isinstance(operator, str) or operator not in _ALLOWED_OPERATORS[column]:
        return None
    value = raw.get("value")
    if value is None:
        return None
    return {"column": column, "comparison_operator": operator, "value": value}


def extract_search_arguments(original_user_query: str, response: Response):
    """Extract the search query and filters from a model response.

    For every ``search_database`` function call in ``response.output``, the
    JSON arguments are parsed. ``search_query`` is taken from them, falling
    back to ``original_user_query`` when the arguments are malformed or the
    query is missing or blank. Valid ``price_filter`` / ``brand_filter``
    objects become filter dicts, and incomplete or disallowed ones are
    skipped. If there are several such calls, the last query wins and the
    filters from all of them accumulate.

    If the response contains no function calls at all, the stripped
    ``response.output_text`` (when non-empty) is used as the query.

    Args:
        original_user_query: The user's raw query, used as the fallback.
        response: The model response to inspect.

    Returns:
        A ``(search_query, filters)`` tuple. ``search_query`` is None when
        there were function calls but none named ``search_database``. It is
        also None when there were no calls and no output text. ``filters`` is
        a list of ``{"column", "comparison_operator", "value"}`` dicts, in
        price-then-brand order for each call.
    """
    search_query = None
    filters = []
    tool_calls = [item for item in response.output if isinstance(item, ResponseFunctionToolCall)]
    if tool_calls:
        for tool_call in tool_calls:
            if tool_call.name != "search_database":
                continue
            try:
                arg = json.loads(tool_call.arguments)
            except json.JSONDecodeError:
                # Malformed arguments from the model: degrade to searching
                # the raw user query with no filters instead of crashing.
                arg = {}
            if not isinstance(arg, dict):
                arg = {}
            query = arg.get("search_query")
            if isinstance(query, str) and query.strip():
                search_query = query.strip()
            else:
                search_query = original_user_query
            for key, column in (("price_filter", "price"), ("brand_filter", "brand")):
                parsed = _parse_filter(arg, key, column)
                if parsed is not None:
                    filters.append(parsed)
    elif response.output_text:
        search_query = response.output_text.strip()
    return search_query, filters