-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathraptor_api.py
More file actions
2151 lines (1807 loc) · 78.9 KB
/
raptor_api.py
File metadata and controls
2151 lines (1807 loc) · 78.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import json
import random
import time
import multiprocessing
import logging
import re
import numpy as np
import pandas as pd
import umap
import warnings
import torch
import tiktoken
import requests
import os
import hmac
import threading
import gc
from typing import List, Dict, Optional, Tuple, Union, Any
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from pydantic import BaseModel, Field, ConfigDict
from pydantic_settings import BaseSettings
from numba.core.errors import NumbaWarning
from sklearn.mixture import GaussianMixture
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from tqdm import tqdm
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler
from functools import lru_cache
# Define Settings class for configuration
class Settings(BaseSettings):
    """Application configuration loaded from environment variables.

    Field names map case-insensitively onto environment variables via
    pydantic-settings (e.g. OLLAMA_BASE_URL -> ollama_base_url); unknown
    variables are ignored (see inner Config).
    """

    # Ollama settings
    ollama_base_url: str = "http://localhost:11434"
    ollama_api_key: str = ""
    # Model settings
    llm_model: str = "gemma3:4b"
    embedder_model: str = "BAAI/bge-m3"
    # Generation settings
    temperature: float = 0.1
    context_window: int = 16384
    # Ollama inference optimization settings
    ollama_num_thread: int = 8  # CPU threads for inference
    ollama_num_gpu: int = 99  # GPU layers (99 = all on GPU)
    ollama_num_predict: int = 512  # Max output tokens
    # Parallel LLM settings
    llm_max_workers: int = 2  # Concurrent LLM requests
    # Retry settings
    llm_max_retries: int = 3
    llm_base_delay: float = 1.0
    llm_timeout: int = 600
    # Performance settings
    random_seed: int = 224
    # Leave at least one core free, capped at 75% of the machine.
    max_workers: int = max(
        1, min(multiprocessing.cpu_count() - 1, int(multiprocessing.cpu_count() * 0.75))
    )
    # File settings
    allowed_extensions: set = {"json"}

    class Config:
        case_sensitive = False
        extra = "ignore"

    def _ollama_endpoint(self, path: str) -> str:
        """Join the base URL (sans trailing slash) with an API path.

        Centralizes the URL construction previously duplicated in each
        property, guaranteeing no stray '/' or space ends up in the URL.
        """
        return f"{self.ollama_base_url.rstrip('/')}{path}"

    @property
    def ollama_api_generate_url(self) -> str:
        """URL for the Ollama generate API endpoint."""
        return self._ollama_endpoint("/api/generate")

    @property
    def ollama_api_tags_url(self) -> str:
        """URL for the Ollama tags API endpoint."""
        return self._ollama_endpoint("/api/tags")

    @property
    def ollama_api_pull_url(self) -> str:
        """URL for the Ollama pull API endpoint."""
        return self._ollama_endpoint("/api/pull")
# Create settings instance
@lru_cache()
def get_settings():
    """
    Returns a cached settings instance.
    Using lru_cache ensures settings are loaded only once during the application lifecycle.

    Returns:
        Settings: the process-wide configuration object; every caller
        receives the exact same instance (lru_cache with no arguments).
    """
    # Environment variables are read only on the first call; later changes
    # to the environment are NOT picked up unless the cache is cleared.
    return Settings()
# Default instructions for summarization (user-customizable, no {chunk} placeholder needed)
# Callers may override these via the `instructions` argument of build_prompt /
# generate_summary; the chunk itself is always appended via CHUNK_WRAPPER below.
DEFAULT_INSTRUCTIONS = """You are an expert analyst. Summarize the text below in a single, high-density narrative paragraph following these rules:
1. CORE OBJECTIVE:
- Capture the primary theme and all essential facts, numbers, and entities.
- Identify and preserve all unique entities (proper nouns, specific locations) with exact naming.
2. ACCURACY & ATTRIBUTION (CRITICAL):
- STRICTLY derive all information from the provided text. Do not add external knowledge.
- VERIFY ATTRIBUTION: Ensure quotes, actions, and events are assigned to the correct entity/character. Do not shift actions between characters.
- NO HALLUCINATIONS: Do not invent causal links or details to "smooth out" the narrative. If the text is disjointed, summarize it as disjointed.
- Use only information explicitly stated in the source text.
3. STYLE & FORMATTING:
- Maintain a strict 3rd-person objective perspective.
- Write ONE cohesive paragraph. No bullet points or lists.
- Start IMMEDIATELY with the content—no introductions or meta-commentary.
- Use the EXACT SAME LANGUAGE as the original text."""
# Fixed chunk wrapper (system-managed, never exposed to user)
# build_prompt fills the {chunk} placeholder; user instructions never need it.
CHUNK_WRAPPER = """
Text:
<text>
{chunk}
</text>
Summary:"""
# Keep PROMPT_TEMPLATE for backward compatibility (combines both parts)
PROMPT_TEMPLATE = DEFAULT_INSTRUCTIONS + CHUNK_WRAPPER
def build_prompt(chunk: str, instructions: Optional[str] = None) -> str:
    """
    Builds the final prompt for Ollama by combining instructions with chunk wrapper.

    Args:
        chunk: The text to summarize.
        instructions: Custom instructions (optional). If None, empty, or
            whitespace-only, DEFAULT_INSTRUCTIONS is used. User does NOT
            need to include {chunk} - it's added automatically.

    Returns:
        The complete formatted prompt ready for Ollama.
    """
    # Strip before the truthiness test so a whitespace-only override falls
    # back to the defaults instead of silently producing a prompt with no
    # instructions at all (previous behavior for e.g. "   ").
    instr = instructions.strip() if instructions else ""
    if not instr:
        instr = DEFAULT_INSTRUCTIONS
    return instr + CHUNK_WRAPPER.format(chunk=chunk)
def _get_ollama_headers() -> dict:
    """Return the Authorization header for Ollama requests, if configured.

    Yields an empty dict when OLLAMA_API_KEY is unset so call sites can
    unconditionally merge it:
        headers = {"Content-Type": "application/json", **_get_ollama_headers()}
    """
    api_key = get_settings().ollama_api_key.strip()
    return {"Authorization": f"Bearer {api_key}"} if api_key else {}
def check_ollama_server_reachable(
    ollama_base_url: Optional[str] = None, timeout: int = 5, verbose: bool = False
):
    """
    Check if the Ollama server is reachable before attempting any model operations.

    Args:
        ollama_base_url: Optional override of the base Ollama URL. If None, uses the value from settings.
        timeout: Timeout in seconds for each connection attempt.
        verbose: Whether to log detailed messages about server connectivity.

    Returns:
        bool: True if the server is reachable, False otherwise.
    """
    settings = get_settings()
    base_url = (ollama_base_url or settings.ollama_base_url).rstrip("/")
    # Compute once and reuse for both probes (was called per request).
    headers = _get_ollama_headers()
    try:
        # Ollama answers "Ollama is running" on the root endpoint.
        response = requests.get(f"{base_url}", headers=headers, timeout=timeout)
        if response.status_code == 200:
            if verbose:
                logger.info(f"Ollama server at {base_url} is reachable")
            return True
        # Fall back to /api/tags, which should exist on any live server.
        response = requests.get(f"{base_url}/api/tags", headers=headers, timeout=timeout)
        if response.status_code == 200:
            if verbose:
                logger.info(f"Ollama server at {base_url} is reachable via /api/tags")
            return True
        # The server answered, but with an unexpected status on both probes.
        logger.error(
            f"Ollama server at {base_url} returned unexpected status code {response.status_code}"
        )
        return False
    except requests.exceptions.ConnectionError:
        logger.error(
            f"Could not connect to Ollama server at {base_url}. Is Ollama running?"
        )
        return False
    except requests.exceptions.Timeout:
        logger.error(f"Timeout connecting to Ollama server at {base_url}")
        return False
    except Exception as e:
        # Catch-all so a reachability probe can never crash the caller.
        logger.error(f"Error checking Ollama server availability: {e}")
        return False
def check_ollama_model(model_name: str, ollama_base_url: str = None):
    """
    Ask the Ollama /api/tags endpoint whether a model is installed locally.

    Args:
        model_name: Exact model name to look for (e.g. "llama3:latest").
        ollama_base_url: Base URL of the Ollama API; settings value if None.

    Returns:
        True when the model appears in the local tag list, False otherwise
        (including on any connection/timeout/format problem).
    """
    settings = get_settings()
    # Use provided base URL or get from settings.
    base_url = (ollama_base_url or settings.ollama_base_url).rstrip("/")
    api_url = f"{base_url}/api/tags"
    try:
        response = requests.get(api_url, headers=_get_ollama_headers(), timeout=10)
        response.raise_for_status()
        data = response.json()
        models = data.get("models", [])
        if not isinstance(models, list):
            logger.error(
                f"Unexpected format from Ollama API /api/tags. Expected a list under 'models'. Response: {data}"
            )
            return False
        # Exact-name match against every installed model entry.
        return any(
            isinstance(entry, dict) and entry.get("name") == model_name
            for entry in models
        )
    except requests.exceptions.ConnectionError:
        logger.error(
            f"Could not connect to Ollama API at {ollama_base_url or settings.ollama_base_url}. Is Ollama running?"
        )
        return False
    except requests.exceptions.Timeout:
        logger.error(
            f"Timeout connecting to Ollama API at {ollama_base_url or settings.ollama_base_url}"
        )
        return False
    except requests.exceptions.RequestException as e:
        logger.error(f"Error checking for Ollama model: {e}")
        return False
def ensure_ollama_model(model_name: str, fallback_model: str = None) -> str:
    """
    Make sure an Ollama model is usable, pulling it (or a fallback) if needed.

    The server is probed first; when it is unreachable the requested name is
    returned unchanged so the application can still start. When the pull of
    the requested model fails and a distinct fallback is supplied, the
    fallback is checked/pulled and returned instead.

    Args:
        model_name: The model that should be available.
        fallback_model: Optional alternative if the requested model can't be pulled.

    Returns:
        The name of the model the caller should use (requested or fallback).
    """
    # Guard: no server, nothing we can verify or pull.
    if not check_ollama_server_reachable(verbose=False):
        logger.warning(
            f"Ollama server is not reachable. Cannot ensure model '{model_name}' is available."
        )
        # Returning the requested name lets the app start without Ollama.
        return model_name

    logger.info(f"Checking if Ollama model '{model_name}' is available locally...")
    if check_ollama_model(model_name):
        logger.info(f"Model '{model_name}' is available locally")
        return model_name

    logger.warning(f"Model '{model_name}' not found locally. Attempting to pull...")
    if pull_ollama_model(model_name, stream=False):
        logger.info(f"Successfully pulled model '{model_name}'")
        return model_name

    # Guard: no usable fallback — report and return the requested name anyway.
    if not fallback_model or fallback_model == model_name:
        logger.error(
            f"Failed to pull model '{model_name}' and no valid fallback available. Processing may fail."
        )
        return model_name

    logger.warning(
        f"Failed to pull model '{model_name}'. Trying fallback model '{fallback_model}'"
    )
    if check_ollama_model(fallback_model):
        logger.info(f"Fallback model '{fallback_model}' is available locally")
        return fallback_model

    logger.warning(
        f"Fallback model '{fallback_model}' not found locally. Attempting to pull..."
    )
    if pull_ollama_model(fallback_model, stream=False):
        logger.info(
            f"Successfully pulled fallback model '{fallback_model}'"
        )
    else:
        logger.error(
            f"Failed to pull fallback model '{fallback_model}'. Processing may fail."
        )
    # The fallback is our best option whether or not the pull succeeded.
    return fallback_model
def pull_ollama_model(
    model_name: str,
    ollama_base_url: Optional[str] = None,
    stream: bool = False,
) -> bool:
    """
    Triggers Ollama to pull a model using the API.

    Args:
        model_name: The name of the model to pull (e.g., "llama3:latest").
        ollama_base_url: The base URL of the Ollama API; settings value if None.
        stream: Whether to process the response as a stream (True) or wait for completion (False).
            This controls the "stream" flag sent to the Ollama API; the HTTP
            request itself is always made with requests' stream=True so the
            body can be consumed line by line in both modes.

    Returns:
        True if the pull request was successful, False otherwise.
    """
    settings = get_settings()
    # Use provided base URL or get from settings
    base_url = (ollama_base_url or settings.ollama_base_url).rstrip("/")
    api_url = f"{base_url}/api/pull"
    # Use the stream parameter as specified by the documentation
    # If stream=False in the API request, Ollama will wait until download completes and return single response
    payload = {"model": model_name, "stream": stream}
    logger.info(f"Pulling model '{model_name}' from Ollama...")
    try:
        # Always use stream=True for requests to allow processing response in chunks
        # NOTE(review): no timeout is set here — a hung server would block this
        # call indefinitely. Intentional for long downloads? TODO confirm.
        response = requests.post(api_url, json=payload, headers=_get_ollama_headers(), stream=True)
        response.raise_for_status()
        # If API stream=True, we'll get multiple status updates
        if stream:
            # Each line is an independent JSON status object.
            for line in response.iter_lines():
                if not line:
                    continue
                try:
                    status = json.loads(line.decode("utf-8"))
                    logger.info(f"Pull status: {status.get('status', 'unknown')}")
                    # Show progress percentage for downloads
                    if (
                        status.get("status", "").startswith("downloading")
                        and "total" in status
                        and "completed" in status
                        and status["total"] > 0
                    ):
                        progress = (status["completed"] / status["total"]) * 100
                        logger.info(f"Download progress: {progress:.1f}%")
                    if status.get("status") == "success":
                        return True
                except json.JSONDecodeError:
                    # Skip malformed lines; the stream may still succeed later.
                    continue
            # If we got here without returning True, the stream ended without success
            logger.error("Model pull stream ended without success status")
            return False
        # If API stream=False, we'll get a single response at the end
        else:
            # Even with API stream=False, we still need to process the response
            # body; only the last parsed status object matters.
            last_status = None
            for line in response.iter_lines():
                if line:
                    try:
                        status = json.loads(line.decode("utf-8"))
                        last_status = status
                    except json.JSONDecodeError:
                        continue
            # Check final status
            if last_status and last_status.get("status") == "success":
                logger.info(f"Model '{model_name}' pulled successfully")
                return True
            else:
                logger.error(f"Failed to pull model. Final status: {last_status}")
                return False
    except requests.exceptions.ConnectionError:
        logger.error(
            f"Could not connect to Ollama API at {ollama_base_url or settings.ollama_base_url}. Is Ollama running?"
        )
        return False
    except requests.exceptions.RequestException as e:
        # Covers HTTP errors raised by raise_for_status() and timeouts.
        logger.error(f"Error pulling Ollama model: {e}")
        if hasattr(e, "response") and e.response is not None:
            logger.error(f"Response status code: {e.response.status_code}")
            logger.error(f"Response text: {e.response.text}")
        return False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan context manager that initializes models before startup
    and cleans up resources on shutdown.

    Startup: probes the Ollama server, ensures the configured LLM is pulled
    (when reachable), preloads the embedding model, and warns if
    llm_max_workers exceeds OLLAMA_NUM_PARALLEL. Stores the server status on
    app.state.ollama_available. Shutdown: empties the embedding-model cache
    and frees GPU memory; errors during cleanup are logged, never raised.
    """
    # First check if the Ollama server is reachable at all
    # Check server with verbose logging (this is only done once at startup)
    server_reachable = check_ollama_server_reachable(verbose=False)
    if server_reachable:
        # Only try to ensure the model is available if the server is reachable
        logger.info(
            f"Ollama server at {get_settings().ollama_base_url} is reachable. Checking for required models..."
        )
        ensure_ollama_model(get_settings().llm_model)
    else:
        # Use critical level for more visibility in logs
        logger.critical("⚠️ WARNING: OLLAMA SERVER NOT AVAILABLE ⚠️")
        logger.critical(
            "The application is starting with LIMITED FUNCTIONALITY. "
            "LLM-dependent features (summarization and RAG) will NOT WORK "
            "until the Ollama server becomes available."
        )
        logger.warning(
            "Please ensure Ollama is running and accessible at: "
            + get_settings().ollama_base_url
        )
    # Continue with embedding model loading
    # (done regardless of Ollama status — embeddings don't need the LLM).
    logger.info("Loading embedding model during application startup...")
    _get_model(get_settings().embedder_model)
    logger.info("Embedding model loaded.")
    # Warn if parallel workers exceed Ollama's parallel capacity
    settings = get_settings()
    ollama_num_parallel = int(os.environ.get("OLLAMA_NUM_PARALLEL", 1))
    if settings.llm_max_workers > ollama_num_parallel:
        logger.warning(
            f"llm_max_workers={settings.llm_max_workers} exceeds OLLAMA_NUM_PARALLEL={ollama_num_parallel}. "
            f"Requests will be queued by Ollama. For full parallelization, set "
            f"OLLAMA_NUM_PARALLEL>={settings.llm_max_workers} when starting Ollama."
        )
    # Store server status for post-startup message
    app.state.ollama_available = server_reachable
    yield
    # Comprehensive cleanup on shutdown
    logger.info("Application shutting down, cleaning up resources...")
    try:
        with _model_lock:
            # Clean up all models in the cache
            # (iterate over a snapshot of keys since we mutate the dict).
            for model_name in list(_model_cache.keys()):
                if model_name in _model_cache:
                    logger.info(
                        f"Removing model {model_name} from cache during shutdown"
                    )
                    del _model_cache[model_name]
            # Clear the last used tracking dictionary
            _model_last_used.clear()
            logger.info("Model cache and tracking dictionaries cleared")
        # Clean up GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
            logger.info(
                f"GPU memory cleaned up. Current usage: {torch.cuda.memory_allocated() / (1024 * 1024):.2f} MB"
            )
    except Exception as e:
        logger.error(f"Error during shutdown cleanup: {str(e)}")
        # Don't re-raise the exception to avoid blocking shutdown
# Initialize FastAPI app.
app = FastAPI(
    title="RAPTOR API",
    description="API for Recursive Abstraction and Processing for Text Organization and Reduction",
    version="1.0.0",
    lifespan=lifespan,
)
# Suppress warnings.
warnings.filterwarnings("ignore", category=NumbaWarning)
warnings.filterwarnings("ignore", message=".*force_all_finite.*")
# Create logs directory if it doesn't exist.
logs_dir = Path("logs")
logs_dir.mkdir(exist_ok=True)
# Configure logging.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Get the logger.
# NOTE: functions defined earlier in this module reference `logger`; that is
# fine because the name is resolved at call time, after this assignment runs.
logger = logging.getLogger(__name__)
# Create a file handler for error logs.
error_log_path = logs_dir / "errors.log"
file_handler = RotatingFileHandler(
    error_log_path,
    maxBytes=10485760,  # 10 MB.
    backupCount=5,  # Keep 5 backup logs.
    encoding="utf-8",
)
# Set the file handler to only log errors and critical messages.
file_handler.setLevel(logging.ERROR)
# Create a formatter.
formatter = logging.Formatter(
    "%(asctime)s - %(levelname)s - %(message)s - %(pathname)s:%(lineno)d"
)
file_handler.setFormatter(formatter)
# Add the handler to the logger.
logger.addHandler(file_handler)
# Auth setup. Bearer auth is enforced only when API_TOKEN is non-empty;
# auto_error=False lets verify_token produce its own 403 instead of a 401.
API_TOKEN = os.environ.get("API_TOKEN", "").strip()
_security = HTTPBearer(auto_error=False)
# Create a singleton for model caching.
_model_cache = {}
_model_last_used = {}  # Track when each model was last used
_model_lock = threading.RLock()  # Thread-safe lock for model cache
# Cache timeout in seconds (1 hour)
MODEL_CACHE_TIMEOUT = int(os.environ.get("MODEL_CACHE_TIMEOUT", 3600))
# Thread-local storage for HTTP sessions (connection pooling)
_thread_local = threading.local()
async def verify_token(
    credentials: HTTPAuthorizationCredentials = Depends(_security),
):
    """Verify Bearer token if API_TOKEN is configured. No-op when API_TOKEN is unset."""
    if not API_TOKEN:
        return  # Auth disabled
    # Constant-time comparison to avoid leaking token contents via timing.
    token_ok = credentials is not None and hmac.compare_digest(
        credentials.credentials, API_TOKEN
    )
    if not token_ok:
        raise HTTPException(status_code=403, detail="Invalid or missing API token")
def _get_http_session() -> requests.Session:
    """Get or create a thread-local requests session with connection pooling.

    Each thread keeps its own session so connections are reused within the
    thread without any cross-thread sharing concerns.

    Returns:
        requests.Session: A configured session for HTTP requests.
    """
    session = getattr(_thread_local, "session", None)
    if session is None:
        from requests.adapters import HTTPAdapter

        session = requests.Session()
        # One pooled adapter serves both schemes.
        pooled_adapter = HTTPAdapter(
            pool_connections=10,
            pool_maxsize=10,
        )
        session.mount("http://", pooled_adapter)
        session.mount("https://", pooled_adapter)
        _thread_local.session = session
    return session
# Set random seed for reproducibility (the stdlib RNG is used below for
# retry jitter in generate_summary).
random.seed(get_settings().random_seed)
def _evict_expired_models(current_time: float) -> None:
    """Drop cached models unused for longer than MODEL_CACHE_TIMEOUT.

    Must be called while holding _model_lock. Also purges stale
    _model_last_used entries whose model is no longer in _model_cache, so
    the tracking dict cannot grow without bound (previously such entries
    were never removed).
    """
    expired_count = 0
    # Snapshot items(): we mutate both dicts while iterating.
    for name, last_used in list(_model_last_used.items()):
        if current_time - last_used <= MODEL_CACHE_TIMEOUT:
            continue
        if name not in _model_cache:
            # Stale tracking entry — model already evicted elsewhere.
            del _model_last_used[name]
            continue
        unused_minutes = int((current_time - last_used) / 60)
        logger.info(
            f"Removing expired model {name} from cache (unused for {unused_minutes} minutes)"
        )
        try:
            del _model_cache[name]
            del _model_last_used[name]
            expired_count += 1
            # Clean up GPU memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()
                after_mem = torch.cuda.memory_allocated() / (1024 * 1024)
                logger.info(f"GPU memory after cleanup: {after_mem:.2f} MB")
        except Exception as cleanup_error:
            logger.warning(
                f"Error during expired model cleanup for {name}: {str(cleanup_error)}"
            )
    if expired_count > 0:
        logger.info(f"Removed {expired_count} expired models from cache")


def _load_model_into_cache(model_name: str) -> None:
    """Load a SentenceTransformer into _model_cache, preferring local disk.

    Must be called while holding _model_lock. On first use the model is
    downloaded and persisted under ./models so subsequent loads come from
    disk.

    Raises:
        RuntimeError: On filesystem/I-O errors while loading the model.
        Exception: Any other loading failure, re-raised after GPU cleanup.
    """
    try:
        # Create models directory if it doesn't exist.
        models_dir = Path("models")
        models_dir.mkdir(exist_ok=True)
        # Local path for the model ("/" is not valid in a directory name).
        local_model_path = models_dir / model_name.replace("/", "_")
        if local_model_path.exists():
            # Load from local storage.
            logger.info(f"Loading model from local storage: {local_model_path}")
            start_time = time.time()
            _model_cache[model_name] = SentenceTransformer(str(local_model_path))
            load_time = time.time() - start_time
            logger.info(
                f"Model {model_name} loaded from disk in {load_time:.2f} seconds"
            )
        else:
            # Download, then persist for future startups.
            logger.info(
                f"Downloading model {model_name} and saving to {local_model_path}"
            )
            start_time = time.time()
            _model_cache[model_name] = SentenceTransformer(model_name)
            download_time = time.time() - start_time
            logger.info(
                f"Model {model_name} downloaded in {download_time:.2f} seconds, saving to disk..."
            )
            save_start = time.time()
            _model_cache[model_name].save(str(local_model_path))
            save_time = time.time() - save_start
            logger.info(
                f"Model {model_name} saved to disk in {save_time:.2f} seconds"
            )
        # Log approximate in-memory model size.
        model_size = sum(
            p.numel() * p.element_size()
            for p in _model_cache[model_name].parameters()
        ) / (1024 * 1024)
        logger.info(
            f"Model {model_name} loaded, approximate size: {model_size:.2f} MB"
        )
    except FileNotFoundError as e:
        error_msg = f"Model directory not accessible: {str(e)}"
        logger.error(error_msg)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
        raise RuntimeError(error_msg) from e
    except (OSError, IOError) as e:
        error_msg = f"I/O error loading model {model_name}: {str(e)}"
        logger.error(error_msg)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
        raise RuntimeError(error_msg) from e
    except Exception as e:
        error_msg = f"Unexpected error loading model {model_name}: {str(e)}"
        logger.error(error_msg)
        # Try to clean up memory in case of error
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
            logger.info("GPU memory cleaned up after model loading error")
        raise


def _get_model(model_name: str) -> SentenceTransformer:
    """Get model from cache or load it into RAM.

    Evicts cache entries unused for more than MODEL_CACHE_TIMEOUT seconds
    before serving the request; all cache access happens under _model_lock.

    Args:
        model_name (str): Name or path of the model to use.

    Returns:
        SentenceTransformer: The loaded model instance.

    Raises:
        ValueError: If model_name is None or empty.
        RuntimeError: If there's an error loading the model from disk or downloading it.
        Exception: For any other unexpected errors during model loading.
    """
    if not model_name:
        error_msg = "Model name cannot be None or empty"
        logger.error(error_msg)
        raise ValueError(error_msg)
    current_time = time.time()
    logger.debug(f"Requesting model: {model_name}")
    # Track memory usage before any operations
    if torch.cuda.is_available():
        before_mem = torch.cuda.memory_allocated() / (1024 * 1024)
        logger.debug(f"GPU memory before model operations: {before_mem:.2f} MB")
    with _model_lock:
        logger.debug(f"Acquired lock for model operations: {model_name}")
        _evict_expired_models(current_time)
        if model_name not in _model_cache:
            logger.info(f"Model {model_name} not found in cache, loading...")
            _load_model_into_cache(model_name)
        else:
            logger.debug(f"Model {model_name} found in cache")
        # Update last used timestamp
        _model_last_used[model_name] = current_time
        logger.debug(f"Updated last used timestamp for {model_name}")
        # Log current cache status
        logger.debug(f"Current model cache size: {len(_model_cache)} models")
        return _model_cache[model_name]
def generate_summary(
    chunk: str,
    model: Optional[str] = None,
    instructions: Optional[str] = None,
    temperature: Optional[float] = None,
    context_window: Optional[int] = None,
) -> str:
    """Generate a summary using OLLAMA.

    Retries with exponential backoff plus jitter on timeouts, connection
    errors, and non-200 responses.

    Args:
        chunk (str): The text chunk to summarize.
        model (str, optional): The LLM model identifier to use. Defaults to value from settings.
        instructions (str, optional): Custom instructions for summarization.
            User does NOT need to include {chunk} - it's added automatically.
            Defaults to DEFAULT_INSTRUCTIONS.
        temperature (float, optional): Controls randomness in output (0.0 to 1.0). Defaults to settings.temperature.
        context_window (int, optional): Maximum context window size. Defaults to settings.context_window.

    Returns:
        The generated summary content (or a fallback message string when the
        final attempt returns HTTP 200 with an empty/unparseable body).

    Raises:
        Exception: When all retry attempts fail.
    """
    settings = get_settings()
    # Resolve per-call overrides against settings defaults.
    model = model or settings.llm_model
    temperature = temperature if temperature is not None else settings.temperature
    context_window = (
        context_window if context_window is not None else settings.context_window
    )
    api_url = settings.ollama_api_generate_url
    # Build the prompt using instructions + chunk wrapper
    formatted_prompt = build_prompt(chunk, instructions)
    headers = {"Content-Type": "application/json", **_get_ollama_headers()}
    data = {
        "model": model,
        "prompt": formatted_prompt,
        "options": {
            # BUGFIX: temperature must be inside "options" — the Ollama
            # /api/generate endpoint ignores a top-level "temperature" field,
            # so the configured value previously had no effect.
            "temperature": temperature,
            "num_ctx": context_window,
            "num_thread": settings.ollama_num_thread,
            "num_gpu": settings.ollama_num_gpu,
            "num_predict": settings.ollama_num_predict,
        },
        "stream": False,
    }
    # Retry configuration from settings
    max_retries = settings.llm_max_retries
    base_delay = settings.llm_base_delay
    timeout = settings.llm_timeout
    last_exception = None
    # Implement retry with exponential backoff.
    for attempt in range(max_retries):
        try:
            # Use session with connection pooling for efficient parallel requests
            session = _get_http_session()
            response = session.post(
                api_url, headers=headers, data=json.dumps(data), timeout=timeout
            )
            if response.status_code == 200:
                # Parse single JSON response (stream: False returns complete response)
                try:
                    resp_json = response.json()
                    full_response = resp_json.get("response", "")
                except json.JSONDecodeError as e:
                    logger.error(f"Failed to parse JSON response from OLLAMA: {e}")
                    if attempt == max_retries - 1:
                        return "Unable to generate summary due to invalid JSON response from OLLAMA."
                    continue
                if not full_response.strip():
                    # If we got an empty response, log and retry.
                    logger.warning(
                        f"Empty response received from OLLAMA (attempt {attempt + 1}/{max_retries})"
                    )
                    if attempt == max_retries - 1:
                        return "Unable to generate summary due to empty response from OLLAMA."
                else:
                    return full_response
            else:
                # Log the error and prepare for retry.
                logger.warning(
                    f"Error response from OLLAMA (attempt {attempt + 1}/{max_retries}): "
                    f"Status {response.status_code}, Response: {response.text[:200]}..."
                )
                # BUGFIX: previously this raised inside the try block on the
                # last attempt, which the broad `except Exception` below
                # immediately swallowed, mislabelling HTTP errors as
                # connection failures. Record it so the final raise reports
                # the real cause instead.
                last_exception = Exception(
                    f"Error generating summary: {response.status_code}"
                )
        except requests.exceptions.Timeout as e:
            logger.warning(
                f"Timeout error connecting to OLLAMA (attempt {attempt + 1}/{max_retries}): {str(e)}"
            )
            last_exception = e
        except requests.exceptions.ConnectionError as e:
            logger.warning(
                f"Connection error to OLLAMA (attempt {attempt + 1}/{max_retries}): {str(e)}"
            )
            last_exception = e
        except requests.exceptions.RequestException as e:
            logger.warning(
                f"Request error to OLLAMA (attempt {attempt + 1}/{max_retries}): {str(e)}"
            )
            last_exception = e
        except Exception as e:
            logger.warning(
                f"Unexpected error during OLLAMA request (attempt {attempt + 1}/{max_retries}): {str(e)}"
            )
            last_exception = e
        # Only sleep if we're going to retry.
        if attempt < max_retries - 1:
            # Exponential backoff with jitter.
            delay = base_delay * (2**attempt) + random.uniform(0, 1)
            logger.info(f"Retrying in {delay:.2f} seconds...")
            time.sleep(delay)
    # If we've exhausted all retries, log the error and raise an exception.
    logger.error(f"Failed to connect to OLLAMA after {max_retries} attempts")
    if last_exception:
        raise Exception(
            f"Failed to connect to OLLAMA after {max_retries} attempts: {str(last_exception)}"
        )
    else:
        raise Exception(f"Failed to connect to OLLAMA after {max_retries} attempts")
def parallel_generate_summaries(
    chunks: List[str],
    model: str,
    instructions: str,
    temperature: float,
    context_window: int,
    desc: str = "Generating summaries",
) -> List[str]:
    """Summarize many chunks concurrently via a thread pool.

    Futures are consumed with as_completed for throughput, and the results
    are re-assembled in the original chunk order before returning.

    Args:
        chunks: List of text chunks to summarize.
        model: LLM model to use.
        instructions: Custom instructions for summarization (chunk is added automatically).
        temperature: Temperature for generation.
        context_window: Context window size.
        desc: Description for progress bar.

    Returns:
        List of generated summaries in original order.

    Raises:
        RuntimeError: If any chunk still fails after generate_summary's retries.
    """
    if not chunks:
        return []
    settings = get_settings()
    worker_count = min(settings.llm_max_workers, len(chunks))
    logger.info(
        f"Generating {len(chunks)} summaries using {worker_count} parallel workers"
    )

    # Completed summaries keyed by original chunk position.
    results: Dict[int, str] = {}
    with ThreadPoolExecutor(max_workers=worker_count) as pool:
        pending = {}
        for position, text in enumerate(chunks):
            future = pool.submit(
                generate_summary,
                chunk=text,
                model=model,
                instructions=instructions,
                temperature=temperature,
                context_window=context_window,
            )
            pending[future] = position
        for done in tqdm(
            as_completed(pending), desc=desc, unit="chunk", total=len(chunks)
        ):
            position = pending[done]
            try:
                results[position] = done.result()
            except Exception as e:
                logger.error(f"Error generating summary for chunk {position}: {e}")
                raise RuntimeError(
                    f"Summary generation failed for chunk {position} after all retries: {str(e)}"
                ) from e
    # Restore the caller's ordering.
    return [results[i] for i in range(len(chunks))]
def get_embeddings(
chunks: List[str],
model: str = None,
batch_size: int = 4,
show_progress_bar: bool = True,
convert_to_numpy: bool = True,
normalize_embeddings: bool = True,
) -> np.ndarray:
"""Generate embeddings for text chunks using a Sentence Transformer model.
Args:
chunks: List of text chunks to generate embeddings for.
model (str, optional): Name or path of the model to use.
Defaults to "BAAI/bge-m3".
batch_size (int, optional): Batch size for embedding generation.
Defaults to 4.
show_progress_bar (bool, optional): Whether to show progress bar.
Defaults to True.
convert_to_numpy (bool, optional): Whether to convert output to numpy array.
Defaults to True.
normalize_embeddings (bool, optional): Whether to normalize embeddings.
Defaults to True.
Returns:
A numpy array of embeddings.
Raises:
HTTPException: If there's an error during the embedding process.
"""
# Use settings model if none provided
if model is None:
model = get_settings().embedder_model
# Adjust batch size dynamically based on document size
adjusted_batch_size = batch_size
if len(chunks) > 1000:
adjusted_batch_size = max(1, batch_size // 2)
logger.info(
f"Large document detected ({len(chunks)} chunks), reducing batch size to {adjusted_batch_size}"
)
# Log memory usage before processing if GPU is available