-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquery_data.py
More file actions
104 lines (93 loc) · 4.21 KB
/
query_data.py
File metadata and controls
104 lines (93 loc) · 4.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
from typing import List
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
# Path to the persisted Chroma database (must match populate_database.py)
CHROMA_PATH = "chroma"
# Embedding model name. Must match the model used when populating the index,
# otherwise query vectors won't live in the same space as the stored ones.
EMBED_MODEL = "text-embedding-3-small"
# Chat model used to generate the final answer.
CHAT_MODEL = "gpt-4o-mini" # You can also use gpt-4o or gpt-4.1-mini
def ensure_api_key() -> str:
    """Return the OpenAI API key from the environment.

    Raises:
        RuntimeError: if OPENAI_API_KEY is not set, with a hint on where
            to configure it.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        return key
    # Fail fast with an actionable message instead of a later auth error.
    raise RuntimeError(
        "OPENAI_API_KEY not set. Please add it in your Run configuration or Conda environment."
    )
def load_db(embeddings: OpenAIEmbeddings) -> Chroma:
    """Open the Chroma vector store persisted by populate_database.py.

    Args:
        embeddings: embedding function used to embed queries at search time.

    Returns:
        The Chroma store backed by CHROMA_PATH.

    Raises:
        RuntimeError: if the persistence directory does not exist yet.
    """
    # Success path first: the directory is present, just attach to it.
    if os.path.isdir(CHROMA_PATH):
        return Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
    raise RuntimeError(
        f"Vector database directory not found: {CHROMA_PATH}\n"
        f"Please run populate_database.py first to build the index."
    )
def format_docs_for_context(docs: List) -> str:
    """Format retrieved document chunks as a numbered context string for the LLM.

    Each chunk is rendered as a "[i] <source>, p.<page>" header followed by
    its text flattened to one line and truncated, so the prompt stays concise
    and the model can cite chunks by number.

    Args:
        docs: retrieved documents (objects exposing .metadata and .page_content).

    Returns:
        All chunks joined with blank lines, numbered from 1.
    """
    lines = []
    for i, d in enumerate(docs, start=1):
        meta = d.metadata or {}
        title = meta.get("source") or meta.get("file_path") or meta.get("title") or "document"
        # Bug fix: use explicit None checks when picking the page. The previous
        # `meta.get("page") or meta.get("page_number") or ...` chain treated a
        # valid page number of 0 as missing (0 is falsy), dropping the citation
        # for the first page of 0-indexed documents.
        page = next(
            (meta[key] for key in ("page", "page_number", "pages") if meta.get(key) is not None),
            None,
        )
        head = f"[{i}] {title}" + (f", p.{page}" if page is not None else "")
        # Flatten newlines and truncate long text blocks to keep context concise.
        text = (d.page_content or "").strip().replace("\n", " ")
        if len(text) > 1200:
            text = text[:1200] + " …"
        lines.append(f"{head}\n{text}\n")
    return "\n".join(lines)
def make_messages(query: str, context: str):
    """Build the system/user message pair for the chat model.

    Args:
        query: the user's question.
        context: numbered context block produced by format_docs_for_context.

    Returns:
        A two-element list of role/content dicts: system prompt first,
        then the user message embedding the question and context.
    """
    system_prompt = (
        "You are a precise assistant. Answer ONLY using the provided context.\n"
        "If the answer is not contained in the context, say 'I don't know'.\n"
        "Keep answers concise and cite sources as [1], [2], etc., at the end of sentences."
    )
    user_prompt = (
        f"Question:\n{query}\n\n"
        f"Context (numbered):\n{context}\n\n"
        "Instructions:\n"
        "- Use the most relevant facts from the context.\n"
        "- Add inline citations like [1], [2] matching the context numbering.\n"
        "- If multiple places support a claim, you may cite multiple numbers."
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
def answer_with_sources(query: str, k: int = 5) -> None:
    """Retrieve the top-k relevant chunks and print an answer with citations.

    Args:
        query: the user's question.
        k: number of chunks to retrieve from the vector store.

    Side effects:
        Prints the generated answer and a numbered source list to stdout.
        Calls the OpenAI API for embeddings and chat completion.
    """
    # 1) Environment & models
    api_key = ensure_api_key()
    embeddings = OpenAIEmbeddings(model=EMBED_MODEL, api_key=api_key)
    db = load_db(embeddings)
    retriever = db.as_retriever(search_kwargs={"k": k})
    # 2) Retrieval
    docs = retriever.invoke(query)
    if not docs:
        print("(No relevant documents found. Please ensure the index is built or try another query.)")
        return
    # 3) Prepare context and generate answer (temperature=0 for deterministic,
    # citation-faithful output)
    context = format_docs_for_context(docs)
    llm = ChatOpenAI(model=CHAT_MODEL, temperature=0, api_key=api_key)
    messages = make_messages(query, context)
    resp = llm.invoke(messages)
    # 4) Output answer and sources
    print("\n================== Answer ==================\n")
    print(resp.content.strip())
    print("\n================== Sources =================\n")
    for i, d in enumerate(docs, start=1):
        meta = d.metadata or {}
        title = meta.get("source") or meta.get("file_path") or meta.get("title") or "document"
        # Bug fix: explicit None checks so a valid page number of 0 is not
        # dropped — the previous `or` chain treated 0 as missing.
        page = next(
            (meta[key] for key in ("page", "page_number", "pages") if meta.get(key) is not None),
            None,
        )
        print(f"[{i}] {title}" + (f", p.{page}" if page is not None else ""))
def main():
    """Entry point: read the query from argv when present, else prompt for it."""
    import sys

    cli_args = sys.argv[1:]
    # Command-line words are joined verbatim; interactive input is stripped.
    query = " ".join(cli_args) if cli_args else input("Enter your query: ").strip()
    if not query:
        print("Empty query. Exiting.")
        return
    answer_with_sources(query, k=5)
# Run the interactive/CLI entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()