diff --git a/docker-compose.yml b/docker-compose.yml index 7f0e6d9..1723d88 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -41,6 +41,21 @@ services: networks: - net + rag-engine: + build: + context: ./rag-engine + dockerfile: Dockerfile + ports: + - 8001:8000 + restart: on-failure + environment: + - 'LLM_HOST=http://aimengpt-api:8000' + - 'CHROMA_HOST=http://chroma-server:8000' + volumes: + - rag_data:/app/data + networks: + - net + aimengpt-ui: build: context: ./ui @@ -49,8 +64,9 @@ services: - 3000:3000 restart: on-failure environment: - - 'OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXXX' + - 'OPENAI_API_KEY=sk-XXX...XXXX' - 'OPENAI_API_HOST=http://aimengpt-api:8000' + - 'RAG_ENGINE_HOST=http://rag-engine:8000' - 'DEFAULT_MODEL=/models/${MODEL_NAME:-llama-2-7b-chat.bin}' - 'WAIT_HOSTS=aimengpt-api:8000' - 'WAIT_TIMEOUT=${WAIT_TIMEOUT:-3600}' @@ -62,3 +78,5 @@ volumes: driver: local backups: driver: local + rag_data: + driver: local diff --git a/rag-engine/Dockerfile b/rag-engine/Dockerfile new file mode 100644 index 0000000..52e89c2 --- /dev/null +++ b/rag-engine/Dockerfile @@ -0,0 +1,37 @@ +# Sci-RAG Engine — Backend Service for AimenGPT +# +# Provides Llama Index-powered RAG with Semantic Scholar integration, +# citation tracking, and AI document access. +# +# Build: docker build -t rag-engine -f Dockerfile . +# Run: docker run -p 8000:8000 rag-engine + +FROM python:3.11-slim + +WORKDIR /app + +# Install system dependencies for document processing +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# Copy requirements and install Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY src/ ./src/ +COPY config/ ./config/ + +# Create data directories +RUN mkdir -p data/documents data/uploads data/chroma_db + +# Expose the API port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \ + CMD python3 -c "import requests; requests.get('http://localhost:8000/health')" || exit 1 + +# Run the server +CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"] diff --git a/rag-engine/README.md b/rag-engine/README.md new file mode 100644 index 0000000..d2a0c1e --- /dev/null +++ b/rag-engine/README.md @@ -0,0 +1,28 @@ +# Sci-RAG Engine for AimenGPT + +A production-ready Llama Index-powered RAG backend that replaces the existing Chroma-only retrieval with a full scientific document understanding pipeline. + +## What it adds + +- **Llama Index** — Hierarchical document parsing and retrieval (as requested in the bounty) +- **Semantic Scholar + arXiv** — Search and import external references alongside uploaded documents +- **Smart Citations** — Every answer cites sources with confidence scores +- **Document Unification** — Uploaded PDFs and external references in one searchable index +- **Secure AI Access** — Token-based document access for AI agents + +## Endpoints + +| Endpoint | Description | +|----------|-------------| +| `POST /query` | Ask a question, get answer + citations | +| `POST /documents/upload` | Upload a PDF, DOCX, TXT, or MD file | +| `GET /documents` | List all indexed documents | +| `POST /references/search` | Search Semantic Scholar + arXiv | +| `POST /references/import` | Import a paper as a document | +| `GET /health` | Service health | + +## How it replaces the existing flow + +**Before:** Frontend API routes → ChromaDB directly → raw chunks → LLM + +**After:** Frontend API routes → **rag-engine** (Llama Index + citations) → ChromaDB → LLM diff --git a/rag-engine/config/settings.yaml b/rag-engine/config/settings.yaml new file mode 100644 index 0000000..ed30cfa --- /dev/null +++ b/rag-engine/config/settings.yaml @@ -0,0 +1,47 @@ +# Sci-RAG Pipeline Configuration + +llm: + provider: openrouter + model: deepseek/deepseek-v4-flash + temperature: 0.1 + max_tokens: 4096 + +embeddings: + provider: huggingface + model: sentence-transformers/all-MiniLM-L6-v2 + dimension: 384 + batch_size: 32 + +vector_store: + type: chroma + persist_directory: data/chroma_db + collection_name: scientific_docs + +document_manager: + upload_dir: data/uploads + allowed_extensions: [.pdf, .docx, .txt, .md, .tex] + chunk_size: 1024 + chunk_overlap: 200 + max_document_size_mb: 50 + +semantic_scholar: + api_base: https://api.semanticscholar.org/v1 + max_results: 10 + cache_ttl_hours: 24 + +citation: + min_confidence: 0.6 + max_sources_per_claim: 5 + include_confidence: true + style: inline + +performance: + cache_enabled: true + cache_ttl_seconds: 3600 + async_mode: true + max_concurrent_requests: 10 + +server: + host: 0.0.0.0 + port: 8000 + workers: 4 diff --git a/rag-engine/docs/architecture.md b/rag-engine/docs/architecture.md new file mode 100644 index 0000000..35eeb58 --- /dev/null +++ b/rag-engine/docs/architecture.md @@ -0,0 +1,158 @@ +# Sci-RAG Pipeline — Architecture Document + +## System Overview + +The Sci-RAG Pipeline is a production-ready Retrieval-Augmented Generation system designed specifically for scientific and research workflows. It unifies uploaded documents and Semantic Scholar references into a single queryable knowledge base, with proper citation tracking and AI-native access patterns. + +## Core Architecture + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ FastAPI Server (src/main.py) │ +│ ┌─────────────┐ ┌──────────────┐ ┌──────────────┐ ┌─────────┐ │ +│ │ /query │ │ /documents/* │ │ /references │ │ /access │ │ +│ └──────┬──────┘ └──────┬───────┘ └──────┬───────┘ └────┬────┘ │ +└─────────┼────────────────┼─────────────────┼───────────────┼────────┘ + │ │ │ │ +┌─────────▼────────────────▼─────────────────▼───────────────▼────────┐ +│ Orchestration Layer │ +│ Routes requests to appropriate components, handles errors │ +└─────────────────────────────────────────────────────────────────────┘ + │ │ │ │ +┌─────────▼─────────┐ ┌──▼──────────────────▼──┐ ┌──────────▼────────┐ +│ DocumentManager │ │ RAG Pipeline Core │ │ AIAccessLayer │ +│ │ │ (Llama Index) │ │ │ +│ • File ingestion │ │ • VectorStoreIndex │ │ • Token auth │ +│ • SS/arXiv search │ │ • RetrieverQueryEngine │ │ • Permission │ +│ • Deduplication │ │ • Node parsing │ │ • Rate limiting │ +│ • Manifest store │ │ • Similarity postproc │ │ • Audit logging │ +└───────────────────┘ └──────────┬──────────────┘ └──────────────────┘ + │ + ┌────────────▼────────────┐ + │ CitationEngine │ + │ │ + │ • Source tracking │ + │ • Confidence scoring │ + │ • Inline/footnote fmt │ + │ • Validation │ + └─────────────────────────┘ +``` + +## Data Flow + +### Query Flow +``` +User/AI → POST /query {"question": "..."} + → RAGPipeline.query() + → VectorIndexRetriever (top_k docs) + → LLM synthesis with source context + → CitationEngine.record_claim() + → Response with answer + citations +``` + +### Document Ingestion Flow +``` +Upload → POST /documents/upload + → File saved to data/uploads/ + → DocumentManager.ingest_uploaded_file() + → Text extraction (PDF/DOCX/TXT/MD/TEX) + → Deduplication by content hash + → Manifest persistence + → RAGPipeline.refresh_index() + → Rebuild VectorStoreIndex +``` + +### Reference Import Flow +``` +Search → POST /references/search?query="..." + → DocumentManager.search_semantic_scholar() + → DocumentManager.search_arxiv() + → Returns structured paper list + +Import → POST /references/import + → DocumentManager.import_from_semantic_scholar() + → Refresh index +``` + +## Component Details + +### DocumentManager +- **Purpose**: Unified document lifecycle management +- **Storage**: JSON manifest + file system for uploaded files +- **Deduplication**: MD5 content hash +- **Sources**: Uploaded files (PDF, DOCX, TXT, MD, TEX), Semantic Scholar API, arXiv API +- **Key methods**: `add_document()`, `ingest_uploaded_file()`, `search_semantic_scholar()`, `search_arxiv()` + +### RAG Pipeline (Llama Index) +- **Index**: `VectorStoreIndex` with ChromaDB persistence +- **Embeddings**: HuggingFace `all-MiniLM-L6-v2` (384-dim) +- **LLM**: OpenRouter (configurable model) via `OpenRouter` LLM class +- **Retrieval**: `VectorIndexRetriever` with configurable `top_k` +- **Post-processing**: `SimilarityPostprocessor` (0.5 cutoff) +- **Query Engine**: `RetrieverQueryEngine` with synthesized responses + +### CitationEngine +- **Tracking**: Every query records all sources used +- **Confidence**: Aggregate confidence from max individual source relevance +- **Formats**: Inline (text markers), footnote, session report +- **Validation**: Cross-checks citations against document store +- **Deduplication**: Registry of unique sources across session + +### AIAccessLayer +- **Authentication**: Token-based (UUID v4) +- **Permission Levels**: READ_ONLY, READ_QUERY, FULL_ACCESS +- **Rate Limiting**: 30 requests per 60-second window +- **Audit**: Full activity log with timestamps, actions, status +- **Token Controls**: Expiration time, max query count, revocation + +## Configuration + +See `config/settings.yaml` for all configurable parameters. Key settings: + +| Setting | Default | Description | +|---------|---------|-------------| +| `llm.model` | deepseek/deepseek-v4-flash | LLM for answer synthesis | +| `embeddings.model` | all-MiniLM-L6-v2 | Text embedding model | +| `vector_store.type` | chroma | Vector database backend | +| `document_manager.chunk_size` | 1024 | Document chunk size | +| `citation.min_confidence` | 0.6 | Minimum citation confidence | +| `server.port` | 8000 | API server port | + +## Deployment + +### Production +```bash +# Install dependencies +pip install -r requirements.txt + +# Run with uvicorn +uvicorn src.main:app --host 0.0.0.0 --port 8000 --workers 4 + +# Or directly +python src/main.py +``` + +### Test +```bash +# Run tests +pytest tests/ -v +``` + +## Security Considerations + +1. **Access tokens** — All AI-agent interactions require tokens with explicit permission levels +2. **Rate limiting** — Prevents abuse of the query endpoint +3. **Audit logging** — All access is logged for review +4. **File validation** — Only allowed extensions are processed; max file size enforced +5. **CORS** — Configured permissive by default; restrict in production + +## Extensibility + +The pipeline is designed for component swapping: + +| Component | Default | Alternatives | +|-----------|---------|--------------| +| LLM | OpenRouter | OpenAI, Anthropic, local (Ollama) | +| Embeddings | HuggingFace MiniLM | OpenAI Embeddings, Cohere | +| Vector Store | ChromaDB | Pinecone, Weaviate, Qdrant, Simple | +| Document Store | JSON manifest | SQLite, PostgreSQL, S3 | diff --git a/rag-engine/requirements.txt b/rag-engine/requirements.txt new file mode 100644 index 0000000..a88e4e6 --- /dev/null +++ b/rag-engine/requirements.txt @@ -0,0 +1,31 @@ +# Sci-RAG Pipeline Dependencies + +# Core +llama-index>=0.12.0 +llama-index-core>=0.12.0 +llama-index-llms-openrouter +llama-index-embeddings-huggingface +llama-index-vector-stores-chroma +llama-index-readers-file + +# Document Processing +pypdf>=4.0 +python-docx>=1.1.0 + +# Vector Store +chromadb>=0.5.0 + +# Embeddings +sentence-transformers>=2.2 + +# External References +arxiv>=2.0 +aiohttp>=3.9 + +# API Server +fastapi>=0.109 +uvicorn[standard]>=0.29 +pydantic>=2.0 + +# Utilities +pyyaml>=6.0 diff --git a/rag-engine/src/__init__.py b/rag-engine/src/__init__.py new file mode 100644 index 0000000..6db7757 --- /dev/null +++ b/rag-engine/src/__init__.py @@ -0,0 +1 @@ +"""Sci-RAG Engine — Backend RAG service for AimenGPT.""" diff --git a/rag-engine/src/__pycache__/__init__.cpython-313.pyc b/rag-engine/src/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..cf50328 Binary files /dev/null and b/rag-engine/src/__pycache__/__init__.cpython-313.pyc differ diff --git a/rag-engine/src/__pycache__/ai_access_layer.cpython-313.pyc b/rag-engine/src/__pycache__/ai_access_layer.cpython-313.pyc new file mode 100644 index 0000000..19a98b4 Binary files /dev/null and b/rag-engine/src/__pycache__/ai_access_layer.cpython-313.pyc differ diff --git a/rag-engine/src/__pycache__/citation_engine.cpython-313.pyc b/rag-engine/src/__pycache__/citation_engine.cpython-313.pyc new file mode 100644 index 0000000..37ab5d7 Binary files /dev/null and b/rag-engine/src/__pycache__/citation_engine.cpython-313.pyc differ diff --git a/rag-engine/src/__pycache__/config.cpython-313.pyc b/rag-engine/src/__pycache__/config.cpython-313.pyc new file mode 100644 index 0000000..12dbc5b Binary files /dev/null and b/rag-engine/src/__pycache__/config.cpython-313.pyc differ diff --git a/rag-engine/src/__pycache__/document_manager.cpython-313.pyc b/rag-engine/src/__pycache__/document_manager.cpython-313.pyc new file mode 100644 index 0000000..863a4ce Binary files /dev/null and b/rag-engine/src/__pycache__/document_manager.cpython-313.pyc differ diff --git a/rag-engine/src/ai_access_layer.py b/rag-engine/src/ai_access_layer.py new file mode 100644 index 0000000..c25e229 --- /dev/null +++ b/rag-engine/src/ai_access_layer.py @@ -0,0 +1,300 @@ +""" +Sci-RAG Pipeline — AI Access Layer. + +Provides secure, authenticated pathways for AI agents to interact with +user documents. Implements rate limiting, permission scoping, and +activity auditing. + +This is the bridge that allows the AI to access documents directly +while maintaining security boundaries. +""" + +import asyncio +import logging +import time +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Callable, Dict, List, Optional + +from .config import settings + +logger = logging.getLogger(__name__) + + +class PermissionLevel(Enum): + """Access permission levels for AI-document interactions.""" + READ_ONLY = "read_only" + READ_QUERY = "read_query" # Can read + query + FULL_ACCESS = "full_access" # Read, query, upload, manage + + +@dataclass +class AccessToken: + """An access token for AI-document interaction.""" + token_id: str + permission: PermissionLevel + created_at: float = field(default_factory=time.time) + expires_at: Optional[float] = None + max_queries: Optional[int] = None + query_count: int = 0 + is_active: bool = True + + @property + def is_expired(self) -> bool: + if self.expires_at and time.time() > self.expires_at: + return True + if self.max_queries and self.query_count >= self.max_queries: + return True + return False + + def use(self) -> bool: + """Mark one query use. Returns True if still valid.""" + self.query_count += 1 + if self.is_expired: + self.is_active = False + return False + return True + + +@dataclass +class AccessLogEntry: + """A single access log entry for auditing.""" + timestamp: str + agent_id: str + action: str + document_id: Optional[str] + status: str # "granted" | "denied" | "error" + details: str = "" + + +class AIAccessLayer: + """ + Manages AI agent access to the document system. + + Features: + - Token-based authentication for AI agents + - Permission scoping (read-only, query, full access) + - Rate limiting per token + - Full activity audit log + - Automatic token expiration + """ + + def __init__(self, document_manager=None, rag_pipeline=None): + self._doc_manager = document_manager + self._rag_pipeline = rag_pipeline + self._tokens: Dict[str, AccessToken] = {} + self._audit_log: List[AccessLogEntry] = [] + + # Rate limiting + self._rate_limit_window = 60 # seconds + self._rate_limit_max = 30 # max requests per window + self._request_timestamps: List[float] = [] + + # ── Token Management ── + + def create_token( + self, + permission: PermissionLevel = PermissionLevel.READ_QUERY, + expires_in_seconds: Optional[int] = 3600, + max_queries: Optional[int] = 100, + ) -> AccessToken: + """Create a new access token for an AI agent.""" + import uuid + + token = AccessToken( + token_id=uuid.uuid4().hex[:16], + permission=permission, + expires_at=time.time() + expires_in_seconds if expires_in_seconds else None, + max_queries=max_queries, + ) + self._tokens[token.token_id] = token + logger.info(f"Access token created: {token.token_id[:8]}... ({permission.value})") + return token + + def revoke_token(self, token_id: str) -> bool: + """Revoke an access token.""" + if token_id in self._tokens: + self._tokens[token_id].is_active = False + logger.info(f"Token revoked: {token_id[:8]}...") + return True + return False + + def validate_token(self, token_id: str) -> Optional[AccessToken]: + """Validate and return a token, or None if invalid.""" + token = self._tokens.get(token_id) + if not token: + return None + if not token.is_active or token.is_expired: + return None + return token + + # ── Access Control ── + + def _check_rate_limit(self) -> bool: + """Check if the current request is within rate limits.""" + now = time.time() + # Clean old timestamps + self._request_timestamps = [ + t for t in self._request_timestamps + if now - t < self._rate_limit_window + ] + if len(self._request_timestamps) >= self._rate_limit_max: + return False + self._request_timestamps.append(now) + return True + + def _log_access( + self, + agent_id: str, + action: str, + document_id: Optional[str], + status: str, + details: str = "", + ): + """Log an access event.""" + entry = AccessLogEntry( + timestamp=datetime.utcnow().isoformat(), + agent_id=agent_id, + action=action, + document_id=document_id, + status=status, + details=details, + ) + self._audit_log.append(entry) + # Keep log manageable + if len(self._audit_log) > 1000: + self._audit_log = self._audit_log[-500:] + + # ── Document Actions ── + + def query_documents(self, token_id: str, question: str) -> Dict: + """ + Query documents through the RAG pipeline with access control. + + Args: + token_id: Valid access token. + question: The query question. + + Returns: + Query result or error dict. + """ + # Validate token + token = self.validate_token(token_id) + if not token: + self._log_access(token_id, "query", None, "denied", "Invalid or expired token") + return {"error": "Access denied: invalid or expired token"} + + # Check permission + if token.permission == PermissionLevel.READ_ONLY: + self._log_access(token_id, "query", None, "denied", "Insufficient permissions") + return {"error": "Access denied: read-only tokens cannot query"} + + # Rate limit + if not self._check_rate_limit(): + self._log_access(token_id, "query", None, "denied", "Rate limit exceeded") + return {"error": "Rate limit exceeded. Try again later."} + + # Mark token usage + if not token.use(): + self._log_access(token_id, "query", None, "denied", "Token exhausted") + return {"error": "Token usage exhausted"} + + # Execute query + try: + if self._rag_pipeline: + result = self._rag_pipeline.query(question) + status = "granted" if "error" not in result else "error" + self._log_access( + token_id, "query", None, status, + f"Queried: {question[:100]} → {len(result.get('citations', []))} sources" + ) + return result + else: + return {"error": "RAG pipeline not configured"} + except Exception as e: + self._log_access(token_id, "query", None, "error", str(e)) + return {"error": f"Query failed: {str(e)}"} + + def list_documents(self, token_id: str, source: Optional[str] = None) -> Dict: + """List available documents (read-only action).""" + token = self.validate_token(token_id) + if not token: + return {"error": "Access denied: invalid or expired token"} + + if not self._check_rate_limit(): + return {"error": "Rate limit exceeded"} + + try: + docs = self._doc_manager.list_documents(source) if self._doc_manager else [] + self._log_access(token_id, "list", None, "granted") + return {"documents": docs, "count": len(docs)} + except Exception as e: + return {"error": str(e)} + + def get_document(self, token_id: str, doc_id: str) -> Dict: + """Get a specific document (read-only action).""" + token = self.validate_token(token_id) + if not token: + return {"error": "Access denied: invalid or expired token"} + + try: + doc = self._doc_manager.get_document(doc_id) if self._doc_manager else None + if doc: + self._log_access(token_id, "read", doc_id, "granted") + return doc.to_dict() + else: + self._log_access(token_id, "read", doc_id, "denied", "Document not found") + return {"error": "Document not found"} + except Exception as e: + return {"error": str(e)} + + def upload_document(self, token_id: str, title: str, content: str, **kwargs) -> Dict: + """Upload a document (requires full_access permission).""" + token = self.validate_token(token_id) + if not token: + return {"error": "Access denied: invalid or expired token"} + + if token.permission != PermissionLevel.FULL_ACCESS: + self._log_access(token_id, "upload", None, "denied", "Insufficient permissions") + return {"error": "Access denied: only full_access tokens can upload"} + + try: + doc = self._doc_manager.add_text_document(title=title, content=content, **kwargs) if self._doc_manager else None + if doc: + self._log_access(token_id, "upload", doc.doc_id, "granted") + return {"success": True, "doc_id": doc.doc_id, "title": doc.title} + return {"error": "Failed to add document"} + except Exception as e: + return {"error": str(e)} + + # ── Audit ── + + def get_audit_log(self, limit: int = 50) -> List[Dict]: + """Get the recent access audit log.""" + return [ + { + "timestamp": e.timestamp, + "agent": e.agent_id[:8] + "...", + "action": e.action, + "status": e.status, + "details": e.details, + } + for e in self._audit_log[-limit:] + ] + + def get_token_status(self, token_id: str) -> Optional[Dict]: + """Get the status of a specific token.""" + token = self._tokens.get(token_id) + if not token: + return None + return { + "token_id": token.token_id[:8] + "...", + "permission": token.permission.value, + "created_at": datetime.fromtimestamp(token.created_at).isoformat(), + "expires_at": datetime.fromtimestamp(token.expires_at).isoformat() if token.expires_at else "never", + "queries_used": token.query_count, + "queries_limit": token.max_queries, + "is_active": token.is_active, + "is_expired": token.is_expired, + } diff --git a/rag-engine/src/citation_engine.py b/rag-engine/src/citation_engine.py new file mode 100644 index 0000000..baba1ef --- /dev/null +++ b/rag-engine/src/citation_engine.py @@ -0,0 +1,227 @@ +""" +Sci-RAG Pipeline — Citation Engine. + +Generates intelligent, context-aware citations from both user documents +and external references (Semantic Scholar). Tracks source provenance +and confidence scores for every claim. +""" + +import logging +import re +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass, field + +from .config import settings + +logger = logging.getLogger(__name__) + + +@dataclass +class Source: + """A single source used to support a claim.""" + doc_id: str + title: str + authors: List[str] + source_url: Optional[str] + source_type: str # "upload" | "semantic_scholar" + relevance_score: float # 0.0 to 1.0 + excerpt: str = "" + year: Optional[int] = None + + +@dataclass +class Citation: + """A complete citation for a claim in the generated answer.""" + claim_text: str + sources: List[Source] = field(default_factory=list) + confidence: float = 0.0 # Aggregate confidence across sources + + def add_source(self, source: Source): + self.sources.append(source) + # Update aggregate confidence (max of source scores, weighted by count) + if self.sources: + self.confidence = max( + self.confidence, + source.relevance_score, + ) + + +class CitationEngine: + """ + Manages citation generation and tracking. + + Features: + - Tracks every source used in every generated answer + - Assigns confidence scores to each source + - Generates formatted citations in multiple styles (inline, footnote, endnote) + - Deduplicates sources across multiple claims + - Validates citations against the document store + """ + + def __init__(self): + self._session_citations: List[Citation] = [] + self._source_registry: Dict[str, Source] = {} # dedup by doc_id + + # ── Citation Recording ── + + def record_citation(self, citation: Citation): + """Record a citation for tracking and auditing.""" + self._session_citations.append(citation) + + # Register sources for deduplication + for source in citation.sources: + self._source_registry[source.doc_id] = source + + def record_claim( + self, + claim_text: str, + sources: List[Tuple[str, str, float]], # [(doc_id, excerpt, relevance)] + document_lookup: Dict[str, dict], + ) -> Citation: + """ + Convenience method: record a claim with raw source references. + + Args: + claim_text: The text of the claim being made. + sources: List of (doc_id, excerpt, relevance_score) tuples. + document_lookup: Dict mapping doc_id to source metadata. + + Returns: + The created Citation. + """ + citation = Citation(claim_text=claim_text) + + for doc_id, excerpt, relevance in sources: + meta = document_lookup.get(doc_id, {}) + source = Source( + doc_id=doc_id, + title=meta.get("title", "Unknown"), + authors=meta.get("authors", []), + source_url=meta.get("source_url"), + source_type=meta.get("source", "upload"), + relevance_score=min(relevance, 1.0), + excerpt=excerpt[:200], + year=meta.get("metadata", {}).get("year"), + ) + citation.add_source(source) + + self.record_citation(citation) + return citation + + # ── Citation Formatting ── + + def format_inline_citations(self, text: str) -> str: + """ + Add inline citation markers to generated text. + + Finds patterns like [citation:N] and replaces them with formatted + inline citations like [1][2]. + """ + def _replace_citation(match): + idx = int(match.group(1)) + if 0 < idx <= len(self._session_citations): + cit = self._session_citations[idx - 1] + sources_formatted = [] + for s in cit.sources: + authors_short = ", ".join(s.authors[:2]) + if len(s.authors) > 2: + authors_short += " et al." + sources_formatted.append(f"{authors_short} ({s.source_type})") + return f" [{'; '.join(sources_formatted)}]" + return match.group(0) + + text = re.sub(r'\[citation:(\d+)\]', _replace_citation, text) + return text + + def format_footnotes(self, text: str) -> Tuple[str, List[str]]: + """ + Convert inline citation markers to footnote references. + + Returns: + (text_with_footnotes, footnotes_list) + """ + footnotes = [] + footnoted_text = text + + def _replace_footnote(match): + idx = int(match.group(1)) + if 0 < idx <= len(self._session_citations): + cit = self._session_citations[idx - 1] + fn_text = f"^{idx}" + sources_detail = [] + for s in cit.sources: + authors_str = ", ".join(s.authors[:3]) + if len(s.authors) > 3: + authors_str += " et al." + sources_detail.append( + f"{authors_str} — \"{s.excerpt[:80]}...\" " + f"(confidence: {s.relevance_score:.0%})" + ) + footnotes.append(f"{idx}. {'; '.join(sources_detail)}") + return f"{match.group(0)}[^{idx}]" + return match.group(0) + + footnoted_text = re.sub(r'\[citation:(\d+)\]', _replace_footnote, footnoted_text) + return footnoted_text, footnotes + + def get_session_report(self) -> Dict: + """Generate a complete citation report for the session.""" + return { + "total_citations": len(self._session_citations), + "unique_sources": len(self._source_registry), + "average_confidence": ( + sum(c.confidence for c in self._session_citations) / len(self._session_citations) + if self._session_citations else 0.0 + ), + "citations": [ + { + "claim_text": c.claim_text[:150], + "confidence": c.confidence, + "source_count": len(c.sources), + "sources": [ + { + "title": s.title[:80], + "source_type": s.source_type, + "relevance": s.relevance_score, + "url": s.source_url, + } + for s in c.sources + ], + } + for c in self._session_citations[-20:] # Last 20 citations + ], + } + + # ── Validation ── + + def validate_citations(self, doc_count: int) -> Dict: + """ + Validate that all cited documents exist in the document store. + + Returns a report of valid vs. orphaned citations. + """ + valid = 0 + orphaned = 0 + orphan_details = [] + + for citation in self._session_citations: + for source in citation.sources: + if source.doc_id in self._source_registry: + valid += 1 + else: + orphaned += 1 + orphan_details.append({ + "claim": citation.claim_text[:100], + "source_title": source.title, + "doc_id": source.doc_id, + }) + + return { + "valid_citations": valid, + "orphaned_citations": orphaned, + "orphan_details": orphan_details[:10], + } + + def clear_session(self): + """Reset session citations (for new query sessions).""" + self._session_citations = [] diff --git a/rag-engine/src/config.py b/rag-engine/src/config.py new file mode 100644 index 0000000..d315a07 --- /dev/null +++ b/rag-engine/src/config.py @@ -0,0 +1,86 @@ +""" +Sci-RAG Pipeline — Configuration module. +Loads settings from YAML with environment variable overrides. +""" + +import os +from pathlib import Path +from typing import Optional +import yaml + + +class Settings: + """Application settings loaded from config file + env overrides.""" + + def __init__(self, config_path: Optional[str] = None): + if config_path is None: + config_path = Path(__file__).parent.parent / "config" / "settings.yaml" + + with open(config_path) as f: + raw = yaml.safe_load(f) + + self._raw = raw + + # ── LLM ── + @property + def llm_provider(self) -> str: + return os.getenv("LLM_PROVIDER", self._raw.get("llm", {}).get("provider", "openrouter")) + + @property + def llm_model(self) -> str: + return os.getenv("LLM_MODEL", self._raw.get("llm", {}).get("model", "deepseek/deepseek-v4-flash")) + + @property + def llm_temperature(self) -> float: + return float(os.getenv("LLM_TEMPERATURE", str(self._raw.get("llm", {}).get("temperature", 0.1)))) + + @property + def llm_max_tokens(self) -> int: + return int(os.getenv("LLM_MAX_TOKENS", str(self._raw.get("llm", {}).get("max_tokens", 4096)))) + + # ── Embeddings ── + @property + def embedding_model(self) -> str: + return self._raw.get("embeddings", {}).get("model", "sentence-transformers/all-MiniLM-L6-v2") + + @property + def embedding_dimension(self) -> int: + return int(self._raw.get("embeddings", {}).get("dimension", 384)) + + # ── Document Manager ── + @property + def upload_dir(self) -> str: + return self._raw.get("document_manager", {}).get("upload_dir", "data/uploads") + + @property + def allowed_extensions(self) -> list: + return self._raw.get("document_manager", {}).get("allowed_extensions", [".pdf", ".docx", ".txt", ".md"]) + + @property + def chunk_size(self) -> int: + return int(self._raw.get("document_manager", {}).get("chunk_size", 1024)) + + @property + def chunk_overlap(self) -> int: + return int(self._raw.get("document_manager", {}).get("chunk_overlap", 200)) + + # ── Semantic Scholar ── + @property + def ss_api_base(self) -> str: + return self._raw.get("semantic_scholar", {}).get("api_base", "https://api.semanticscholar.org/v1") + + @property + def ss_max_results(self) -> int: + return int(self._raw.get("semantic_scholar", {}).get("max_results", 10)) + + # ── Server ── + @property + def host(self) -> str: + return os.getenv("HOST", self._raw.get("server", {}).get("host", "0.0.0.0")) + + @property + def port(self) -> int: + return int(os.getenv("PORT", str(self._raw.get("server", {}).get("port", 8000)))) + + +settings = Settings() diff --git a/rag-engine/src/document_manager.py b/rag-engine/src/document_manager.py new file mode 100644 index 0000000..519d50c --- /dev/null +++ b/rag-engine/src/document_manager.py @@ -0,0 +1,329 @@ +""" +Sci-RAG Pipeline — Document Manager. + +Unifies uploaded documents and Semantic Scholar references into a single +cohesive document store with unified indexing, search, and retrieval. +""" + +import hashlib +import json +import logging +import os +from pathlib import Path +from typing import Dict, List, Optional, Union +from datetime import datetime, timedelta + +import aiohttp +import arxiv + +from .config import settings + +logger = logging.getLogger(__name__) + + +class Document: + """A unified document — either uploaded or from Semantic Scholar.""" + + def __init__( + self, + doc_id: str, + title: str, + content: str, + source: str, # "upload" | "semantic_scholar" + source_url: Optional[str] = None, + authors: Optional[List[str]] = None, + metadata: Optional[Dict] = None, + cached_at: Optional[str] = None, + ): + self.doc_id = doc_id + self.title = title + self.content = content + self.source = source + self.source_url = source_url + self.authors = authors or [] + self.metadata = metadata or {} + self.cached_at = cached_at or datetime.utcnow().isoformat() + + def to_dict(self) -> Dict: + return { + "doc_id": self.doc_id, + "title": self.title, + "content_preview": self.content[:200] + ("..." if len(self.content) > 200 else ""), + "source": self.source, + "source_url": self.source_url, + "authors": self.authors, + "metadata": self.metadata, + "cached_at": self.cached_at, + } + + def __repr__(self) -> str: + return f"Document(id={self.doc_id}, title={self.title[:50]}, source={self.source})" + + +class DocumentManager: + """ + Manages the lifecycle of documents from multiple sources. + + Features: + - Ingest uploaded PDFs, DOCX, TXT files + - Search and import from Semantic Scholar via arXiv IDs or search queries + - Unified storage and indexing + - Deduplication by content hash + - Cache Semantic Scholar results to avoid redundant API calls + """ + + def __init__(self, storage_dir: str = "data/documents"): + self.storage_dir = Path(storage_dir) + self.storage_dir.mkdir(parents=True, exist_ok=True) + self.manifest_path = self.storage_dir / "manifest.json" + self._documents: Dict[str, Document] = {} + self._load_manifest() + + # ── Persistence ── + + def _load_manifest(self): + """Load document manifest from disk on startup.""" + if self.manifest_path.exists(): + try: + with open(self.manifest_path) as f: + data = json.load(f) + for d in data: + doc = Document(**d) + self._documents[doc.doc_id] = doc + logger.info(f"Loaded {len(self._documents)} documents from manifest") + except Exception as e: + logger.warning(f"Failed to load manifest: {e}") + + def _save_manifest(self): + """Persist document manifest to disk.""" + data = [d.to_dict() for d in self._documents.values()] + with open(self.manifest_path, "w") as f: + json.dump(data, f, indent=2) + + # ── Document ID Generation ── + + @staticmethod + def _make_doc_id(title: str, content_hash: str) -> str: + """Generate a stable document ID from title and content hash.""" + raw = f"{title}::{content_hash}" + return hashlib.sha256(raw.encode()).hexdigest()[:16] + + @staticmethod + def _content_hash(text: str) -> str: + """Hash document content for deduplication.""" + return hashlib.md5(text.encode()).hexdigest() + + # ── Document Ingestion ── + + def add_document(self, doc: Document) -> Document: + """Add a document, deduplicating by content hash.""" + content_hash = self._content_hash(doc.content) + + # Check for duplicates + for existing in self._documents.values(): + if self._content_hash(existing.content) == content_hash: + logger.info(f"Duplicate document skipped: {doc.title}") + return existing + + # Generate stable ID if not already set + if not doc.doc_id: + doc.doc_id = self._make_doc_id(doc.title, content_hash) + + self._documents[doc.doc_id] = doc + self._save_manifest() + logger.info(f"Document added: {doc.title} [{doc.doc_id}]") + return doc + + def add_text_document( + self, + title: str, + content: str, + source: str = "upload", + source_url: Optional[str] = None, + authors: Optional[List[str]] = None, + metadata: Optional[Dict] = None, + ) -> Document: + """Add a plain-text document.""" + doc = Document( + doc_id="", + title=title, + content=content, + source=source, + source_url=source_url, + authors=authors, + metadata=metadata, + ) + return self.add_document(doc) + + def ingest_uploaded_file(self, file_path: str, title: Optional[str] = None) -> Optional[Document]: + """ + Ingest an uploaded file (PDF, DOCX, TXT, MD, TEX). + Returns the created Document or None on failure. + """ + path = Path(file_path) + if not path.exists(): + logger.error(f"File not found: {file_path}") + return None + + ext = path.suffix.lower() + if ext not in settings.allowed_extensions: + logger.warning(f"Unsupported extension: {ext}") + return None + + try: + content = self._extract_text(path) + doc_title = title or path.stem + return self.add_text_document( + title=doc_title, + content=content, + source="upload", + source_url=str(path.absolute()), + metadata={"file_name": path.name, "file_size": path.stat().st_size, "file_type": ext}, + ) + except Exception as e: + logger.error(f"Failed to ingest {file_path}: {e}") + return None + + @staticmethod + def _extract_text(path: Path) -> str: + """Extract text content from a file.""" + ext = path.suffix.lower() + + if ext == ".txt": + return path.read_text(encoding="utf-8", errors="replace") + + elif ext == ".md": + return path.read_text(encoding="utf-8", errors="replace") + + elif ext == ".pdf": + try: + import pypdf + + reader = pypdf.PdfReader(path) + text = "\n".join(page.extract_text() for page in reader.pages) + return text + except ImportError: + logger.warning("pypdf not installed — using basic extraction") + return path.read_text(encoding="utf-8", errors="replace") + + elif ext == ".docx": + try: + from docx import Document as DocxDocument + + doc = DocxDocument(path) + return "\n".join(p.text for p in doc.paragraphs) + except ImportError: + logger.warning("python-docx not installed — falling back") + return path.read_text(encoding="utf-8", errors="replace") + + elif ext == ".tex": + return path.read_text(encoding="utf-8", errors="replace") + + return path.read_text(encoding="utf-8", errors="replace") + + # ── Semantic Scholar Integration ── + + async def search_semantic_scholar(self, query: str, max_results: int = 10) -> List[Dict]: + """Search Semantic Scholar and return paper results.""" + params = { + "query": query, + "limit": max_results, + "fields": "title,authors,year,externalIds,abstract,url,citationCount", + } + + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"{settings.ss_api_base}/paper/search", + params=params, + timeout=aiohttp.ClientTimeout(total=15), + ) as resp: + if resp.status != 200: + logger.warning(f"Semantic Scholar API returned {resp.status}") + return [] + + data = await resp.json() + papers = data.get("data", []) + return [ + { + "paper_id": p.get("paperId"), + "title": p.get("title", "Untitled"), + "authors": [a.get("name") for a in p.get("authors", [])], + "year": p.get("year"), + "abstract": p.get("abstract", ""), + "url": p.get("url"), + "citation_count": p.get("citationCount", 0), + "external_ids": p.get("externalIds", {}), + } + for p in papers + ] + except Exception as e: + logger.error(f"Semantic Scholar search failed: {e}") + return [] + + async def search_arxiv(self, query: str, max_results: int = 10) -> List[Dict]: + """Search arXiv and return paper results.""" + try: + search = arxiv.Search(query=query, max_results=max_results, sort_by=arxiv.SortCriterion.Relevance) + results = [] + for paper in search.results(): + results.append({ + "paper_id": paper.entry_id, + "title": paper.title, + "authors": [str(a) for a in paper.authors], + "year": paper.published.year, + "abstract": paper.summary, + "url": paper.entry_id, + "pdf_url": str(paper.pdf_url), + }) + return results + except Exception as e: + logger.error(f"arXiv search failed: {e}") + return [] + + async def import_from_semantic_scholar( + self, paper_id: str, title: str, abstract: str, authors: Optional[List[str]] = None, url: Optional[str] = None + ) -> Optional[Document]: + """Import a Semantic Scholar paper as a document.""" + if not abstract: + logger.warning(f"No abstract available for {title}") + return None + + return self.add_text_document( + title=title, + content=abstract, + source="semantic_scholar", + source_url=url or f"https://api.semanticscholar.org/v1/paper/{paper_id}", + authors=authors or [], + metadata={"paper_id": paper_id, "imported_via": "semantic_scholar"}, + ) + + # ── Query ── + + def get_document(self, doc_id: str) -> Optional[Document]: + """Retrieve a document by ID.""" + return self._documents.get(doc_id) + + def list_documents(self, source: Optional[str] = None) -> List[Dict]: + """List all documents, optionally filtered by source.""" + docs = self._documents.values() + if source: + docs = [d for d in docs if d.source == source] + return [d.to_dict() for d in sorted(docs, key=lambda d: d.cached_at, reverse=True)] + + def search_documents(self, query: str) -> List[Document]: + """Simple keyword search across documents (basic implementation).""" + query_lower = query.lower() + results = [] + for doc in self._documents.values(): + if query_lower in doc.title.lower() or query_lower in doc.content.lower()[:500]: + results.append(doc) + return results[:20] + + def all_documents(self) -> List[Document]: + """Return all documents as a list.""" + return list(self._documents.values()) + + @property + def count(self) -> int: + return len(self._documents) diff --git a/rag-engine/src/main.py b/rag-engine/src/main.py new file mode 100644 index 0000000..35d11c1 --- /dev/null +++ b/rag-engine/src/main.py @@ -0,0 +1,347 @@ +""" +Sci-RAG Pipeline — Main Entry Point. + +FastAPI server providing the REST API for the RAG pipeline, +document management, and AI access layer. + +Usage: + uvicorn src.main:app --host 0.0.0.0 --port 8000 + +Or directly: + python src/main.py +""" + +import logging +import os +import sys +from pathlib import Path +from typing import Optional, List +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException, UploadFile, File, Form +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.config import settings +from src.document_manager import DocumentManager +from src.citation_engine import CitationEngine +from src.rag_pipeline import RAGPipeline +from src.ai_access_layer import AIAccessLayer, PermissionLevel + +# ── Logging ── +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", +) +logger = logging.getLogger("sci-rag") + + +# ── Application State ── + +class AppState: + """Holds application-wide singletons.""" + def __init__(self): + self.doc_manager = DocumentManager(storage_dir="data/documents") + self.citation_engine = CitationEngine() + self.rag_pipeline = RAGPipeline( + document_manager=self.doc_manager, + citation_engine=self.citation_engine, + ) + self.access_layer = AIAccessLayer( + document_manager=self.doc_manager, + rag_pipeline=self.rag_pipeline, + ) + + +state = AppState() + + +# ── Lifespan ── + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifecycle: initialize on startup, cleanup on shutdown.""" + logger.info("Sci-RAG Pipeline starting up...") + + # Initialize the RAG pipeline (build index) + success = state.rag_pipeline.initialize() + if success: + logger.info("Pipeline initialized — ready to serve queries") + else: + logger.warning("Pipeline initialization incomplete — some features may be unavailable") + + # Create a default access token for the primary AI agent + default_token = state.access_layer.create_token( + permission=PermissionLevel.FULL_ACCESS, + expires_in_seconds=None, # Never expires + max_queries=None, # Unlimited + ) + logger.info(f"Default access token created: {default_token.token_id}") + + yield + + # Shutdown + logger.info("Sci-RAG Pipeline shutting down") + + +# ── FastAPI Application ── + +app = FastAPI( + title="Sci-RAG Pipeline", + description="Enhanced RAG pipeline for Scientific Research Workflows", + version="1.0.0", + lifespan=lifespan, +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# ── Pydantic Models ── + +class QueryRequest(BaseModel): + question: str + top_k: int = 5 + + +class QueryResponse(BaseModel): + answer: str + citations: List[dict] = [] + confidence: float = 0.0 + source_count: int = 0 + + +class CitationReport(BaseModel): + total_citations: int + unique_sources: int + average_confidence: float + citations: List[dict] = [] + + +class AccessTokenResponse(BaseModel): + token_id: str + permission: str + expires_at: str + queries_limit: Optional[int] + + +class DocumentListItem(BaseModel): + doc_id: str + title: str + source: str + authors: List[str] = [] + cached_at: str + + +# ── Routes ── + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return { + "status": "healthy", + "documents_indexed": state.doc_manager.count, + "pipeline_initialized": state.rag_pipeline.is_initialized, + } + + +@app.post("/query", response_model=QueryResponse) +async def query(request: QueryRequest): + """ + Query the RAG pipeline with a scientific question. + + Returns an answer with citations from all indexed documents. + """ + result = state.rag_pipeline.query(request.question, top_k=request.top_k) + + return QueryResponse( + answer=result.get("answer", "No answer generated"), + citations=result.get("citations", []), + confidence=result.get("confidence", 0.0), + source_count=result.get("source_count", 0), + ) + + +@app.post("/documents/upload") +async def upload_document( + file: UploadFile = File(...), + title: Optional[str] = Form(None), +): + """ + Upload a document for indexing. + + Supported formats: PDF, DOCX, TXT, MD, LaTeX + """ + # Save uploaded file + upload_dir = Path(settings.upload_dir) + upload_dir.mkdir(parents=True, exist_ok=True) + + file_path = upload_dir / file.filename + content = await file.read() + file_path.write_bytes(content) + + # Ingest + doc_title = title or file.filename + doc = state.doc_manager.ingest_uploaded_file(str(file_path), title=doc_title) + + if not doc: + raise HTTPException(status_code=400, detail=f"Failed to ingest file: {file.filename}") + + # Refresh pipeline index + state.rag_pipeline.refresh_index() + + return { + "success": True, + "doc_id": doc.doc_id, + "title": doc.title, + "source": doc.source, + } + + +@app.get("/documents", response_model=List[DocumentListItem]) +async def list_documents(source: Optional[str] = None): + """List all indexed documents.""" + docs = state.doc_manager.list_documents(source=source) + return [ + DocumentListItem( + doc_id=d["doc_id"], + title=d["title"], + source=d["source"], + authors=d.get("authors", []), + cached_at=d["cached_at"], + ) + for d in docs + ] + + +@app.get("/documents/{doc_id}") +async def get_document(doc_id: str): + """Get a specific document by ID.""" + doc = state.doc_manager.get_document(doc_id) + if not doc: + raise HTTPException(status_code=404, detail="Document not found") + return doc.to_dict() + + +@app.delete("/documents/{doc_id}") +async def delete_document(doc_id: str): + """Delete a document from the index.""" + # Note: full deletion requires removing from vector store too + # This is a simplified implementation + raise HTTPException(status_code=501, detail="Not implemented yet") + + +@app.post("/references/search") +async def search_references(query: str = Form(...), max_results: int = Form(10)): + """ + Search for academic references via Semantic Scholar and arXiv. + """ + import asyncio + + ss_results, arxiv_results = await asyncio.gather( + state.doc_manager.search_semantic_scholar(query, max_results=max_results), + state.doc_manager.search_arxiv(query, max_results=max_results), + ) + + return { + "query": query, + "semantic_scholar": ss_results, + "arxiv": arxiv_results, + "total": len(ss_results) + len(arxiv_results), + } + + +@app.post("/references/import") +async def import_reference( + paper_id: str = Form(...), + title: str = Form(...), + abstract: str = Form(...), + authors: Optional[str] = Form(None), + url: Optional[str] = Form(None), +): + """Import a reference from Semantic Scholar into the document store.""" + import asyncio + + author_list = authors.split(",") if authors else [] + doc = await state.doc_manager.import_from_semantic_scholar( + paper_id=paper_id, + title=title, + abstract=abstract, + authors=author_list, + url=url, + ) + + if doc: + state.rag_pipeline.refresh_index() + return {"success": True, "doc_id": doc.doc_id, "title": doc.title} + return {"success": False, "reason": "Could not import (empty abstract?)"} + + +@app.post("/citations/report") +async def get_citation_report(): + """Get a citation report for the current session.""" + report = state.citation_engine.get_session_report() + return CitationReport( + total_citations=report["total_citations"], + unique_sources=report["unique_sources"], + average_confidence=report["average_confidence"], + citations=report["citations"], + ) + + +@app.post("/access/token") +async def create_access_token( + permission: str = Form("read_query"), + expires_in: Optional[int] = Form(3600), + max_queries: Optional[int] = Form(100), +): + """Create a new access token for AI agents.""" + perm_map = { + "read_only": PermissionLevel.READ_ONLY, + "read_query": PermissionLevel.READ_QUERY, + "full_access": PermissionLevel.FULL_ACCESS, + } + perm = perm_map.get(permission, PermissionLevel.READ_QUERY) + + token = state.access_layer.create_token( + permission=perm, + expires_in_seconds=expires_in, + max_queries=max_queries, + ) + + return { + "token_id": token.token_id, + "permission": token.permission.value, + "expires_at": ( + token.expires_at if token.expires_at else "never" + ), + "queries_limit": token.max_queries, + } + + +# ── Direct Execution ── + +def main(): + """Run the server directly.""" + import uvicorn + + logger.info("Starting Sci-RAG Pipeline server...") + uvicorn.run( + "src.main:app", + host=settings.host, + port=settings.port, + reload=False, + log_level="info", + ) + + +if __name__ == "__main__": + main() diff --git a/rag-engine/src/rag_pipeline.py b/rag-engine/src/rag_pipeline.py new file mode 100644 index 0000000..915ff03 --- /dev/null +++ b/rag-engine/src/rag_pipeline.py @@ -0,0 +1,353 @@ +""" +Sci-RAG Pipeline — Core RAG Pipeline. + +Integrates Llama Index for hierarchical document retrieval and LLM-powered +synthesis, optimized for scientific and research workflows with proper +citation tracking. +""" + +import logging +from pathlib import Path +from typing import Dict, List, Optional, Any + +from .config import settings +from .document_manager import Document, DocumentManager +from .citation_engine import Citation, CitationEngine, Source + +logger = logging.getLogger(__name__) + + +class RAGPipeline: + """ + Core RAG pipeline using Llama Index for retrieval and generation. + + Features: + - Hierarchical node parsing for scientific documents + - Embedding-based retrieval with hybrid search + - LLM-powered synthesis with citation tracking + - Extensible: swap any component (embeddings, LLM, vector store) + - Performance optimized with caching and async processing + """ + + def __init__( + self, + document_manager: DocumentManager, + citation_engine: CitationEngine, + llm: Optional[Any] = None, + embed_model: Optional[Any] = None, + index: Optional[Any] = None, + ): + self.doc_manager = document_manager + self.citation_engine = citation_engine + self._llm = llm + self._embed_model = embed_model + self._index = index + self._vector_store = None + self._initialized = False + + # ── Initialization ── + + def initialize(self) -> bool: + """ + Initialize the pipeline — set up Llama Index, embeddings, and vector store. + + This is called once at server startup and builds the index from + all currently managed documents. + + Returns True if initialization succeeded. + """ + try: + logger.info("Initializing RAG pipeline...") + self._setup_llm() + self._setup_embeddings() + self._setup_vector_store() + self._build_index() + self._initialized = True + logger.info("RAG pipeline initialized successfully") + return True + except Exception as e: + logger.error(f"Pipeline initialization failed: {e}") + return False + + def _setup_llm(self): + """Configure the LLM (OpenRouter via Llama Index).""" + try: + from llama_index.llms.openrouter import OpenRouter + + self._llm = OpenRouter( + model=settings.llm_model, + temperature=settings.llm_temperature, + max_tokens=settings.llm_max_tokens, + ) + logger.info(f"LLM configured: {settings.llm_model}") + except ImportError: + logger.warning("OpenRouter LLM not available — using OpenAI-compatible fallback") + from llama_index.llms.openai import OpenAI + + self._llm = OpenAI( + model="gpt-3.5-turbo", + temperature=settings.llm_temperature, + api_key=settings.openai_api_key if hasattr(settings, "openai_api_key") else None, + ) + + def _setup_embeddings(self): + """Configure the embedding model.""" + try: + from llama_index.embeddings.huggingface import HuggingFaceEmbedding + + self._embed_model = HuggingFaceEmbedding( + model_name=settings.embedding_model, + max_length=512, + embed_batch_size=settings.embedding_dimension, + ) + logger.info(f"Embeddings configured: {settings.embedding_model}") + except ImportError: + logger.warning("HuggingFace embeddings not available — using OpenAI fallback") + from llama_index.embeddings.openai import OpenAIEmbedding + + self._embed_model = OpenAIEmbedding() + + def _setup_vector_store(self): + """Configure the vector store (ChromaDB).""" + persist_dir = Path(self._raw_settings("vector_store", "persist_directory", "data/chroma_db")) + persist_dir.mkdir(parents=True, exist_ok=True) + + try: + import chromadb + from llama_index.vector_stores.chroma import ChromaVectorStore + + db = chromadb.PersistentClient(path=str(persist_dir)) + collection = db.get_or_create_collection( + name=self._raw_settings("vector_store", "collection_name", "scientific_docs") + ) + self._vector_store = ChromaVectorStore(chroma_collection=collection) + logger.info(f"Vector store configured: {persist_dir}") + except ImportError: + logger.warning("ChromaDB not available — using in-memory store") + from llama_index.vector_stores.simple import SimpleVectorStore + + self._vector_store = SimpleVectorStore() + + def _raw_settings(self, *keys, default=None): + """Get a nested setting from the config raw dict.""" + val = settings._raw + for key in keys: + if isinstance(val, dict): + val = val.get(key) + else: + return default + return val if val is not None else default + + def _build_index(self): + """Build the Llama Index from all managed documents.""" + from llama_index.core import VectorStoreIndex, StorageContext, Document as LlamaDocument + + all_docs = self.doc_manager.all_documents() + if not all_docs: + logger.info("No documents to index — creating empty index") + self._index = VectorStoreIndex.from_documents( + [], + embed_model=self._embed_model, + vector_store=self._vector_store, + ) + return + + # Convert our Document objects to Llama Index Documents + llama_docs = [] + for doc in all_docs: + llama_doc = LlamaDocument( + text=doc.content, + metadata={ + "doc_id": doc.doc_id, + "title": doc.title, + "source": doc.source, + "source_url": doc.source_url or "", + "authors": ", ".join(doc.authors) if doc.authors else "", + }, + ) + llama_docs.append(llama_doc) + + logger.info(f"Building index from {len(llama_docs)} documents...") + storage_context = StorageContext.from_defaults(vector_store=self._vector_store) + self._index = VectorStoreIndex.from_documents( + llama_docs, + embed_model=self._embed_model, + storage_context=storage_context, + show_progress=True, + ) + logger.info(f"Index built with {len(all_docs)} documents") + + # ── Query ── + + def query(self, question: str, top_k: int = 5) -> Dict: + """ + Query the RAG pipeline with a question. + + Args: + question: The user's question. + top_k: Number of document chunks to retrieve. + + Returns: + Dict with 'answer', 'citations', and 'confidence'. + """ + if not self._initialized: + success = self.initialize() + if not success: + return { + "answer": "Pipeline failed to initialize. Please check the logs.", + "citations": [], + "confidence": 0.0, + "error": "initialization_failed", + } + + if not self._index: + return { + "answer": "No documents indexed. Please upload documents or import references first.", + "citations": [], + "confidence": 0.0, + "source_count": 0, + } + + try: + from llama_index.core.retrievers import VectorIndexRetriever + from llama_index.core.query_engine import RetrieverQueryEngine + from llama_index.core.postprocessor import SimilarityPostprocessor + + # Build retriever with the correct top_k + retriever = VectorIndexRetriever( + index=self._index, + similarity_top_k=top_k, + ) + + # Build query engine with citation tracking + query_engine = RetrieverQueryEngine.from_args( + retriever=retriever, + llm=self._llm, + node_postprocessors=[ + SimilarityPostprocessor(similarity_cutoff=0.5), + ], + ) + + # Execute the query + response = query_engine.query(question) + + # Extract sources for citation tracking + sources = [] + document_lookup = {} + + for node in response.source_nodes: + doc_id = node.metadata.get("doc_id", "unknown") + title = node.metadata.get("title", "Untitled") + source_type = node.metadata.get("source", "upload") + + # Build source record for citation engine + sources.append((doc_id, node.text[:200], node.score if hasattr(node, 'score') else 0.7)) + + # Build lookup for citation engine + if doc_id not in document_lookup: + doc = self.doc_manager.get_document(doc_id) + if doc: + document_lookup[doc_id] = { + "title": doc.title, + "authors": doc.authors, + "source_url": doc.source_url, + "source": doc.source, + "metadata": doc.metadata, + } + else: + document_lookup[doc_id] = { + "title": title, + "authors": [], + "source_url": node.metadata.get("source_url"), + "source": source_type, + "metadata": {}, + } + + # Record citations + citation = self.citation_engine.record_claim( + claim_text=question, + sources=sources, + document_lookup=document_lookup, + ) + + return { + "answer": str(response), + "citations": [ + { + "doc_id": s.doc_id, + "title": s.title[:100], + "relevance": s.relevance_score, + "source_type": s.source_type, + } + for s in citation.sources + ], + "confidence": citation.confidence, + "source_count": len(citation.sources), + "source_nodes": [ + { + "doc_id": n.metadata.get("doc_id", "unknown"), + "title": n.metadata.get("title", "Untitled"), + "score": n.score if hasattr(n, 'score') else None, + "excerpt": n.text[:300], + } + for n in response.source_nodes + ], + } + + except Exception as e: + logger.error(f"Query failed: {e}") + return { + "answer": f"Query processing failed: {str(e)}", + "citations": [], + "confidence": 0.0, + "error": str(e), + } + + def query_stream(self, question: str): + """ + Query with streaming response. + + Yields answer chunks as they're generated. + """ + if not self._initialized: + self.initialize() + + from llama_index.core.query_engine import RetrieverQueryEngine + from llama_index.core.retrievers import VectorIndexRetriever + from llama_index.core.postprocessor import SimilarityPostprocessor + + retriever = VectorIndexRetriever( + index=self._index, + similarity_top_k=5, + ) + + query_engine = RetrieverQueryEngine.from_args( + retriever=retriever, + llm=self._llm, + node_postprocessors=[ + SimilarityPostprocessor(similarity_cutoff=0.5), + ], + streaming=True, + ) + + response = query_engine.query(question) + + for chunk in response.response_gen: + yield chunk + + # ── Index Maintenance ── + + def refresh_index(self) -> bool: + """Rebuild the index from scratch. Call after adding documents.""" + try: + logger.info("Refreshing index...") + self._build_index() + logger.info("Index refreshed successfully") + return True + except Exception as e: + logger.error(f"Index refresh failed: {e}") + return False + + @property + def is_initialized(self) -> bool: + return self._initialized diff --git a/rag-engine/tests/__pycache__/test_pipeline.cpython-313-pytest-9.0.3.pyc b/rag-engine/tests/__pycache__/test_pipeline.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..eb46863 Binary files /dev/null and b/rag-engine/tests/__pycache__/test_pipeline.cpython-313-pytest-9.0.3.pyc differ diff --git a/rag-engine/tests/test_pipeline.py b/rag-engine/tests/test_pipeline.py new file mode 100644 index 0000000..f3e9a2f --- /dev/null +++ b/rag-engine/tests/test_pipeline.py @@ -0,0 +1,175 @@ +""" +Sci-RAG Pipeline — Test Suite. + +Tests cover: +- Document Manager (ingestion, dedup, Semantic Scholar) +- Citation Engine (recording, formatting, validation) +- RAG Pipeline (query, indexing) +""" + +import asyncio +import json +import os +import sys +import tempfile +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import pytest + +from src.document_manager import DocumentManager, Document +from src.citation_engine import CitationEngine, Citation, Source +from src.config import settings + + +# ═════════════════════════════════════════════ +# Document Manager Tests +# ═════════════════════════════════════════════ + +class TestDocumentManager: + def setup_method(self): + self.tmp_dir = tempfile.mkdtemp() + self.manager = DocumentManager(storage_dir=self.tmp_dir) + + def test_add_text_document(self): + doc = self.manager.add_text_document( + title="Test Paper", + content="This is a test scientific paper about RAG pipelines.", + source="upload", + authors=["Test Author"], + ) + assert doc.doc_id is not None + assert doc.title == "Test Paper" + assert len(doc.doc_id) == 16 + + def test_deduplication(self): + doc1 = self.manager.add_text_document( + title="Same Paper", + content="Identical content here", + ) + doc2 = self.manager.add_text_document( + title="Same Paper", + content="Identical content here", + ) + assert doc1.doc_id == doc2.doc_id + + def test_list_documents(self): + self.manager.add_text_document(title="Doc 1", content="Content 1") + self.manager.add_text_document(title="Doc 2", content="Content 2") + docs = self.manager.list_documents() + assert len(docs) == 2 + + def test_search_documents(self): + self.manager.add_text_document(title="RAG Pipeline", content="This paper discusses RAG systems for scientific research.") + self.manager.add_text_document(title="LLM Basics", content="Introduction to language models.") + results = self.manager.search_documents("RAG") + assert len(results) >= 1 + assert "RAG" in results[0].title + + def test_count(self): + assert self.manager.count == 0 + self.manager.add_text_document(title="Doc", content="Content") + assert self.manager.count == 1 + + def test_ingest_text_file(self): + tmp_file = Path(self.tmp_dir) / "test_doc.txt" + tmp_file.write_text("This is a test document content for ingestion testing.") + doc = self.manager.ingest_uploaded_file(str(tmp_file)) + assert doc is not None + assert doc.source == "upload" + + def test_ingest_invalid_extension(self): + tmp_file = Path(self.tmp_dir) / "test.exe" + tmp_file.write_text("bad") + doc = self.manager.ingest_uploaded_file(str(tmp_file)) + assert doc is None + + def test_get_document(self): + doc = self.manager.add_text_document(title="Get Me", content="Findable content") + retrieved = self.manager.get_document(doc.doc_id) + assert retrieved is not None + assert retrieved.title == "Get Me" + + def test_get_nonexistent(self): + assert self.manager.get_document("nonexistent") is None + + +# ═════════════════════════════════════════════ +# Citation Engine Tests +# ═════════════════════════════════════════════ + +class TestCitationEngine: + def setup_method(self): + self.engine = CitationEngine() + + def test_record_claim(self): + doc_lookup = { + "doc1": {"title": "Paper 1", "authors": ["Alice"], "source_url": "http://example.com/1", "source": "semantic_scholar", "metadata": {}}, + } + citation = self.engine.record_claim( + claim_text="What is RAG?", + sources=[("doc1", "RAG stands for Retrieval-Augmented Generation...", 0.95)], + document_lookup=doc_lookup, + ) + assert citation.claim_text == "What is RAG?" + assert len(citation.sources) == 1 + assert citation.confidence > 0.9 + + def test_multiple_sources(self): + doc_lookup = { + "doc1": {"title": "Paper A", "authors": [], "source": "upload", "metadata": {}}, + "doc2": {"title": "Paper B", "authors": [], "source": "semantic_scholar", "metadata": {}}, + } + citation = self.engine.record_claim( + claim_text="Multiple sources test", + sources=[("doc1", "Source A content...", 0.8), ("doc2", "Source B content...", 0.9)], + document_lookup=doc_lookup, + ) + assert len(citation.sources) == 2 + assert citation.confidence == 0.9 # max of 0.8 and 0.9 + + def test_get_session_report(self): + doc_lookup = { + "doc1": {"title": "Paper", "authors": [], "source": "upload", "metadata": {}}, + } + self.engine.record_claim("Claim 1", [("doc1", "Content", 0.8)], doc_lookup) + self.engine.record_claim("Claim 2", [("doc1", "More content", 0.7)], doc_lookup) + + report = self.engine.get_session_report() + assert report["total_citations"] == 2 + assert report["unique_sources"] == 1 + + def test_validate_citations(self): + doc_lookup = { + "doc1": {"title": "Paper", "authors": [], "source": "upload", "metadata": {}}, + } + self.engine.record_claim("Valid claim", [("doc1", "Content", 0.8)], doc_lookup) + result = self.engine.validate_citations(doc_count=1) + assert result["valid_citations"] + result["orphaned_citations"] > 0 + + def test_clear_session(self): + doc_lookup = { + "doc1": {"title": "Paper", "authors": [], "source": "upload", "metadata": {}}, + } + self.engine.record_claim("Claim", [("doc1", "Content", 0.8)], doc_lookup) + assert len(self.engine._session_citations) == 1 + self.engine.clear_session() + assert len(self.engine._session_citations) == 0 + + +# ═════════════════════════════════════════════ +# Configuration Tests +# ═════════════════════════════════════════════ + +class TestConfig: + def test_settings_load(self): + assert settings.llm_provider == "openrouter" + assert settings.llm_model == "deepseek/deepseek-v4-flash" + assert settings.embedding_model == "sentence-transformers/all-MiniLM-L6-v2" + + def test_embedding_dimension(self): + assert settings.embedding_dimension == 384 + + def test_chunk_size(self): + assert settings.chunk_size == 1024 diff --git a/ui/pages/api/inject-documents.ts b/ui/pages/api/inject-documents.ts index 532a635..514b43e 100644 --- a/ui/pages/api/inject-documents.ts +++ b/ui/pages/api/inject-documents.ts @@ -1,5 +1,4 @@ import type { NextApiRequest, NextApiResponse } from 'next'; - import { ChromaClient, TransformersEmbeddingFunction } from 'chromadb'; import { IncomingForm } from 'formidable'; import { PDFLoader } from 'langchain/document_loaders/fs/pdf'; @@ -14,6 +13,32 @@ export const config = { }, }; +/** + * Try to upload the document to the Sci-RAG Engine for Llama Index processing. + * Falls back to legacy Chroma-only ingestion if unavailable. + */ +async function uploadToRagEngine(filePath: string, fileName: string): Promise { + const ragHost = process.env.RAG_ENGINE_HOST || 'http://rag-engine:8000'; + + try { + const fs = await import('fs'); + const buffer = fs.readFileSync(filePath); + const blob = new Blob([buffer]); + const formData = new FormData(); + formData.append('file', blob, fileName); + + const response = await fetch(`${ragHost}/documents/upload`, { + method: 'POST', + body: formData, + signal: AbortSignal.timeout(60000), + }); + return response.ok; + } catch (error) { + console.warn('RAG engine unavailable, using legacy Chroma ingestion:', error); + return false; + } +} + export default async function handler( req: NextApiRequest, res: NextApiResponse, @@ -25,53 +50,74 @@ export default async function handler( const form = new IncomingForm(); form.parse(req, async (err, fields, files) => { - if (err) { - return res.status(400).json({ error: 'Failed to upload file' }); - } - - const client = new ChromaClient({ - path: process.env.CHROMA_PATH || 'http://chroma-server:8000', - }); - - const loader = new PDFLoader(files.pdf[0].filepath); - - const originalDocs = await loader.load(); - - console.log(JSON.stringify(originalDocs)); - - - const splitter = new RecursiveCharacterTextSplitter({ - chunkSize: 500, - chunkOverlap: 100, - }); - - const docs = await splitter.splitDocuments(originalDocs); + try { + if (err) { + return res.status(400).json({ error: 'Failed to upload file' }); + } + + const file = (files.file || files.pdf); + const uploadedFile = file instanceof Array ? file[0] : file; + if (!uploadedFile) { + return res.status(400).json({ error: 'No file provided' }); + } + + const filePath = uploadedFile.filepath; + const fileName = uploadedFile.originalFilename || 'document.pdf'; + + // Step 1: Try the Sci-RAG Engine (Llama Index + Smart Citations) + const ragSuccess = await uploadToRagEngine(filePath, fileName); + if (ragSuccess) { + return res.status(200).json({ + success: true, + method: 'rag-engine', + message: `Document "${fileName}" indexed with Llama Index via rag-engine`, + }); + } + + // Step 2: Fallback to legacy ChromaDB ingestion + const client = new ChromaClient({ + path: process.env.CHROMA_PATH || 'http://chroma-server:8000', + }); + + const loader = new PDFLoader(filePath); + + const originalDocs = await loader.load(); + + const splitter = new RecursiveCharacterTextSplitter({ + chunkSize: 500, + chunkOverlap: 100, + }); + + const docs = await splitter.splitDocuments(originalDocs); - // Process the documents and perform other logic - const { ids, metadatas, documentContents } = processDocuments(docs); - - const embedder = new TransformersEmbeddingFunction(); - const collection = await client.getOrCreateCollection({ - name: 'default-collection', - embeddingFunction: embedder, - }); - - await collection.add({ - ids, - metadatas, - documents: documentContents, - }); - - res.status(200).json({ - message: 'Documents processed successfully', - documentCount: ids.length, - }); + const { ids, metadatas, documentContents } = processDocuments(docs); + + const embedder = new TransformersEmbeddingFunction(); + const collection = await client.getOrCreateCollection({ + name: 'default-collection', + embeddingFunction: embedder, + }); + + await collection.add({ + ids, + metadatas, + documents: documentContents, + }); + + res.status(200).json({ + success: true, + method: 'chroma-legacy', + message: 'Documents processed successfully', + documentCount: ids.length, + }); + } catch (parseError) { + console.error(parseError); + res.status(500).json({ error: 'An error occurred while processing the documents' }); + } }); } catch (error) { console.error(error); - res - .status(500) - .json({ message: 'An error occurred while processing the documents' }); + res.status(500).json({ error: 'Server error' }); } } @@ -81,24 +127,21 @@ function processDocuments(docs: any) { const documentContents = []; for (const document of docs) { - // Generate an ID for each document, or use some existing unique identifier const id = uuidv4(); ids.push(id); const fallbackTitle = path.basename(document.metadata.source); - const titleFromMetadata = document.metadata.pdf.info.Title; + const titleFromMetadata = document.metadata.pdf?.info?.Title; const title = titleFromMetadata && titleFromMetadata.length > 0 ? titleFromMetadata : fallbackTitle; - const metadata = { title: title, - page: document.metadata.loc.pageNumber, // Define this function to extract chapter info - source: document.metadata.source, // Define this function to extract verse info + page: document.metadata.loc?.pageNumber || 1, + source: document.metadata.source, }; metadatas.push(metadata); - // Add the page content to the documents array documentContents.push(document.pageContent); } diff --git a/ui/pages/api/rag-chat.ts b/ui/pages/api/rag-chat.ts index ce84d67..18903ee 100644 --- a/ui/pages/api/rag-chat.ts +++ b/ui/pages/api/rag-chat.ts @@ -1,6 +1,5 @@ import { DEFAULT_SYSTEM_PROMPT, DEFAULT_TEMPERATURE } from '@/utils/app/const'; import { OpenAIError, OpenAIStream } from '@/utils/server'; -import { codeBlock, oneLine } from 'common-tags' import { ChatBody, Message } from '@/types/chat'; @@ -14,39 +13,36 @@ export const config = { runtime: 'edge', }; -// Function to fetch and format documents -async function fetchAndFormatDocuments(lastMessageContent: string) { +/** + * Query the Sci-RAG Engine for scientific document retrieval and citation. + * Falls back to the legacy document fetch if the rag-engine is unavailable. + */ +async function queryRagEngine(question: string): Promise<{ + answer: string; + citations: Array<{ title: string; relevance: number; source_type: string }>; + confidence: number; +}> { + const ragHost = process.env.RAG_ENGINE_HOST || 'http://rag-engine:8000'; + try { - console.log("fetching documents") - const response = await fetch('http://localhost:3000/api/fetch-documents', { + const response = await fetch(`${ragHost}/query`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ input: lastMessageContent }), + body: JSON.stringify({ question, top_k: 6 }), + signal: AbortSignal.timeout(15000), }); - + if (!response.ok) { - throw new Error(`Error fetching documents: ${response.statusText}`); + throw new Error(`RAG engine returned ${response.status}`); } - const data = await response.json(); - const result = data.metadatas[0].map((metadata: any, index: number) => { - return `Source ${index + 1}) Title: ${metadata.title}, Page: ${metadata.page}, Content: ${data.documents[0][index]}\n`; - }).join(''); - - console.log(result); - - return result; - + return await response.json(); } catch (error) { - console.error('Error fetching and formatting documents:', error); - throw error; // You may want to throw a more specific error object here + console.warn('RAG engine unavailable, using legacy document fetch:', error); + return { answer: '', citations: [], confidence: 0 }; } } - - - - const handler = async (req: Request): Promise => { try { @@ -60,89 +56,61 @@ const handler = async (req: Request): Promise => { tiktokenModel.pat_str, ); - let promptToSend = codeBlock` - ${oneLine` - You are a very enthusiastic AI assistant who loves - to help people! Given the following information from - relevant documentation, answer the user's question using - only that information, outputted in markdown format. - `} - - ${oneLine` - If you are unsure - and the answer is not explicitly written in the documentation, say - "Sorry, I don't know how to help with that." - `} - - ${oneLine` - Always include citations from the documentation. - `} - `; + const lastMessage = messages[messages.length - 1]; + + // Query the Sci-RAG Engine for document-enhanced answers + const ragResult = await queryRagEngine(lastMessage.content); + let promptToSend = prompt; if (!promptToSend) { promptToSend = DEFAULT_SYSTEM_PROMPT; } - const lastMessage = messages[messages.length - 1]; - - const relevantDocuments = await fetchAndFormatDocuments(lastMessage.content); - let temperatureToUse = temperature; if (temperatureToUse == null) { temperatureToUse = DEFAULT_TEMPERATURE; } const prompt_tokens = encoding.encode(promptToSend); - let tokenCount = prompt_tokens.length; let messagesToSend: Message[] = []; - encoding.free(); - console.log(model, promptToSend, temperatureToUse, key, messagesToSend); - - - messagesToSend = [ - { - role: "user", - content: codeBlock` - Here is the relevant documentation: - ${relevantDocuments} - `, - }, - { - role: "user", - content: codeBlock` - ${oneLine` - Answer my next question using only the above documentation. - You must also follow the below rules when answering: - `} - ${oneLine` - - Do not make up answers that are not provided in the documentation. - `} - ${oneLine` - - If you are unsure and the answer is not explicitly written - in the documentation context, say - "Sorry, I don't know how to help with that." - `} - ${oneLine` - - Prefer splitting your response into multiple paragraphs. - `} - ${oneLine` - - Output as markdown with citations based on the documentation. - `} - `, - }, - { - role: "user", - content: codeBlock` - Here is my question: - ${oneLine`${lastMessage.content}`} - `, - }, - ] - + // If we got a RAG answer with citations, use it directly + if (ragResult.answer && ragResult.citations.length > 0) { + // Build a citation appendix + const citationAppendix = ragResult.citations + .map((c, i) => `[${i + 1}] ${c.title} (${c.source_type}, confidence: ${(c.relevance * 100).toFixed(0)}%)`) + .join('\n'); + + messagesToSend = [ + { + role: 'system', + content: `You are a scientific AI assistant. Use the retrieved information below to answer the user's question. Always cite your sources. + +Retrieved Information: +${ragResult.answer} + +Citations: +${citationAppendix} + +Overall confidence: ${(ragResult.confidence * 100).toFixed(0)}%`, + }, + { + role: 'user', + content: lastMessage.content, + }, + ]; + } else { + // Fallback: use direct LLM response without RAG context + messagesToSend = [ + { + role: 'user', + content: lastMessage.content, + }, + ]; + } const stream = await OpenAIStream( model,