-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
339 lines (289 loc) · 12 KB
/
main.py
File metadata and controls
339 lines (289 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
"""
FastAPI backend for the Research Assistant application
Following the architecture outlined in the README
"""
import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
import os
from pathlib import Path
from datetime import datetime
import json
import uuid
# Import our custom modules
from scripts.eps_reasoning_agent import EPSReasoningAgent
from scripts.eps_retriever import EPSRetriever
from scripts.ingest_eps_data import EPSGraphBuilder
# from evaluation.run_evaluation import Evaluator
# Initialize components
app = FastAPI(title="Research Assistant API", version="1.0.0")
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"], # Allow Next.js dev server
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global instances (will be initialized on startup)
reasoning_agent = None
retriever = None
evaluator = None
# Simple in-memory session storage for chat history
# In production, this should be stored in Redis or a database
chat_sessions: Dict[str, List[Dict[str, str]]] = {}
class QueryRequest(BaseModel):
query: str
chat_history: Optional[List[Dict[str, str]]] = []
persona_override: Optional[str] = None
session_id: Optional[str] = None
class QueryResponse(BaseModel):
response: str
context_used: List[Dict[str, Any]]
quality_grade: float
retrieval_method: Optional[str]
retrieval_performed: bool
sources: List[Dict[str, str]]
session_id: Optional[str] = None
class IngestionRequest(BaseModel):
directory: str = "data/research_papers"
recreate_indexes: bool = False
class EvaluationRequest(BaseModel):
dataset_path: Optional[str] = None
output_path: Optional[str] = "evaluation/results/api_evaluation.json"
class SystemStatus(BaseModel):
neo4j_connected: bool
ollama_ready: bool
redis_connected: bool
eps_count: int
evaluation_count: int
@app.on_event("startup")
async def startup_event():
"""Initialize components on startup"""
global reasoning_agent, retriever, evaluator
try:
reasoning_agent = EPSReasoningAgent()
retriever = EPSRetriever()
# evaluator = Evaluator(
# test_dataset_path=Path("evaluation/datasets/research_assistant_v1.json"),
# trace_db_path=Path("evaluation/trace.db")
# )
print("✓ All components initialized")
except Exception as e:
print(f"✗ Component initialization failed: {e}")
@app.get("/api/health")
async def health_check():
"""Basic health check"""
return {"status": "healthy", "timestamp": datetime.now().isoformat()}
@app.post("/api/chat")
async def chat(request: QueryRequest) -> QueryResponse:
"""Main chat endpoint with GraphRAG"""
if not reasoning_agent:
raise HTTPException(status_code=503, detail="Reasoning agent not initialized")
try:
# Session management: Get or create session
session_id = request.session_id
if not session_id or session_id not in chat_sessions:
session_id = str(uuid.uuid4())
chat_sessions[session_id] = []
# Use provided chat_history for backward compatibility, or build from session
if request.chat_history and len(request.chat_history) > len(chat_sessions[session_id]):
# If user provided more complete history, use it and update session
chat_history_to_use = request.chat_history
chat_sessions[session_id] = request.chat_history.copy()
else:
# Use session history
chat_history_to_use = chat_sessions[session_id].copy()
# Add current user message if not already present
user_message = {"role": "user", "content": request.query}
if not chat_history_to_use or chat_history_to_use[-1] != user_message:
chat_history_to_use.append(user_message)
# Generate response
result = reasoning_agent.generate_response(
request.query,
chat_history_to_use
)
# Add both user and assistant messages to session history
chat_sessions[session_id].append(user_message)
chat_sessions[session_id].append({"role": "assistant", "content": result['response']})
# Format sources for frontend
sources = []
for doc in result['context_used']:
sources.append({
'title': doc.get('filename', 'Unknown Document'),
'authors': doc.get('document_type', 'Unknown'),
'year': doc.get('filename', 'Unknown')[:10] if doc.get('filename') else 'Unknown',
'relevance_score': f"{doc.get('relevance_score', 0.0):.3f}",
'retrieval_method': str(doc.get('retrieval_method', 'unknown'))
})
return QueryResponse(
response=result['response'],
context_used=result['context_used'],
quality_grade=result['quality_grade'],
retrieval_method=result.get('retrieval_method'),
retrieval_performed=result.get('retrieval_performed', False),
sources=sources,
session_id=session_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Chat processing failed: {str(e)}")
@app.get("/api/search")
async def search_papers(query: str, limit: int = 10):
"""Direct paper search endpoint"""
if not retriever:
raise HTTPException(status_code=503, detail="Retriever not initialized")
try:
results = retriever.retrieve_context(query, limit=limit)
return {"results": results, "query": query}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
@app.post("/api/ingest")
async def ingest_papers(request: IngestionRequest, background_tasks: BackgroundTasks):
"""Ingest research papers"""
try:
# Run ingestion in background
background_tasks.add_task(run_ingestion, request.directory, request.recreate_indexes)
return {"message": f"Started ingestion from {request.directory}"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Ingestion failed: {str(e)}")
@app.post("/api/evaluate")
async def run_evaluation_endpoint(request: EvaluationRequest, background_tasks: BackgroundTasks):
"""Run evaluation"""
if not evaluator:
raise HTTPException(status_code=503, detail="Evaluator not initialized")
try:
# Run evaluation in background
background_tasks.add_task(run_evaluation_task, request.dataset_path, request.output_path)
return {"message": "Started evaluation"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Evaluation failed: {str(e)}")
@app.get("/api/status")
async def get_system_status() -> SystemStatus:
"""Get comprehensive system status"""
try:
# Check Neo4j connection
neo4j_connected = False
eps_count = 0
try:
if retriever and retriever.driver:
with retriever.driver.session() as session:
result = session.run("MATCH (d:EPSDocument) RETURN count(d) as count")
eps_count = result.single()["count"]
neo4j_connected = True
except:
pass
# Check Redis
redis_connected = False
try:
import redis
r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379"))
r.ping()
redis_connected = True
except:
pass
# Check Ollama (simplified)
ollama_ready = True # Assume ready if service started
# Get evaluation count
evaluation_count = 0
try:
trace_file = Path("evaluation/trace.db")
if trace_file.exists():
with open(trace_file, 'r') as f:
data = json.load(f)
evaluation_count = len(data)
except:
pass
return SystemStatus(
neo4j_connected=neo4j_connected,
ollama_ready=ollama_ready,
redis_connected=redis_connected,
eps_count=eps_count,
evaluation_count=evaluation_count
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
@app.get("/api/evaluation-results")
async def get_evaluation_results():
"""Get latest evaluation results"""
try:
results_path = Path("evaluation/results/evaluation_output.json")
if results_path.exists():
with open(results_path, 'r') as f:
return json.load(f)
else:
return {"error": "No evaluation results found"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load results: {str(e)}")
# Background tasks
def run_ingestion(directory: str, recreate_indexes: bool):
"""Run EPS data ingestion in background"""
print(f"Starting background EPS ingestion from {directory}")
try:
builder = EPSGraphBuilder()
if recreate_indexes:
# Drop existing indexes first
try:
with builder.driver.session() as session:
session.run("DROP INDEX eps_document_embeddings IF EXISTS")
except:
pass
builder.create_vector_indexes()
# Since we're ingesting CSV, use ingest_eps_csv instead of ingest_reddit_directory
if directory.endswith('.csv'):
builder.ingest_eps_csv(Path(directory))
if recreate_indexes:
builder.create_vector_indexes()
# Create similarity relationships
builder.create_similarity_relationships()
else:
print(f"✗ Directory {directory} is not a CSV file. EPS ingestion expects a CSV file path.")
print("✓ Background EPS ingestion completed")
except Exception as e:
print(f"✗ Background EPS ingestion failed: {e}")
def run_evaluation_task(dataset_path: str = None, output_path: str = "evaluation/results/api_evaluation.json"):
"""Run evaluation in background"""
print("Starting background evaluation")
try:
if not evaluator:
print("✗ Evaluator not initialized")
return
# Use default sample queries if no dataset provided
if dataset_path and Path(dataset_path).exists():
with open(dataset_path, 'r') as f:
dataset = json.load(f)
queries = dataset.get('queries', [])
else:
queries = [
{
'query': 'What are the main approaches to attention mechanisms in deep learning?',
'persona': 'researcher',
'ground_truth_chunk_ids': ['attention paper'],
'reference_answer': 'Attention mechanisms in deep learning...',
'complexity_score': 0.7
},
{
'query': 'How do transformer models handle long-range dependencies?',
'persona': 'student',
'ground_truth_chunk_ids': ['transformer paper'],
'reference_answer': 'Transformer models use...',
'complexity_score': 0.6
}
]
results = evaluator.run_evaluation(
queries=queries,
output_path=Path(output_path)
)
print("✓ Background evaluation completed")
except Exception as e:
print(f"✗ Background evaluation failed: {e}")
# Mount static files if they exist (for production)
frontend_path = Path("frontend/build")
if frontend_path.exists():
app.mount("/", StaticFiles(directory=frontend_path, html=True), name="frontend")
if __name__ == "__main__":
port = int(os.getenv("PORT", 8000))
uvicorn.run(app, host="0.0.0.0", port=port)