11import asyncio
2- import copy
3- import datetime
42import json
5- import uuid
63from pathlib import Path
7- from typing import AsyncGenerator , AsyncIterator , List , Optional
4+ from typing import List , Optional
85
96import structlog
10- from litellm import ChatCompletionRequest , ModelResponse
117from pydantic import BaseModel
128from sqlalchemy import text
139from sqlalchemy .ext .asyncio import create_async_engine
@@ -35,7 +31,7 @@ def __init__(self, sqlite_path: Optional[str] = None):
3531 )
3632 self ._db_path = Path (sqlite_path ).absolute ()
3733 self ._db_path .parent .mkdir (parents = True , exist_ok = True )
38- logger .debug (f"Initializing DB from path: { self ._db_path } " )
34+ logger .info (f"Initializing DB from path: { self ._db_path } " )
3935 engine_dict = {
4036 "url" : f"sqlite+aiosqlite:///{ self ._db_path } " ,
4137 "echo" : False , # Set to False in production
@@ -104,9 +100,7 @@ async def _insert_pydantic_model(
104100 logger .error (f"Failed to insert model: { model } ." , error = str (e ))
105101 return None
106102
107- async def record_request (
108- self , prompt_params : Optional [Prompt ] = None
109- ) -> Optional [Prompt ]:
103+ async def record_request (self , prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
110104 if prompt_params is None :
111105 return None
112106 sql = text (
@@ -117,87 +111,38 @@ async def record_request(
117111 """
118112 )
119113 recorded_request = await self ._insert_pydantic_model (prompt_params , sql )
120- logger .info (f"Recorded request: { recorded_request } " )
114+ logger .debug (f"Recorded request: { recorded_request } " )
121115 return recorded_request
122116
async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
    """Record a stream of output chunks as a single `outputs` row.

    The chunks are collapsed into one row keyed by the first chunk's
    id / prompt_id / timestamp; the `output` column stores a JSON list
    containing every chunk's `output` payload, in arrival order.

    Args:
        outputs: The output chunks produced for a single prompt.

    Returns:
        The inserted row, or None when `outputs` is empty or the insert
        failed (per `_insert_pydantic_model`'s error handling).
    """
    if not outputs:
        return None

    first_output = outputs[0]
    # Create a single entry on DB but encode all of the chunks in the stream
    # as a list of JSON objects in the field `output`. Building the final
    # payload up front avoids constructing the model with a throwaway value
    # and then clobbering it.
    output_db = Output(
        id=first_output.id,
        prompt_id=first_output.prompt_id,
        timestamp=first_output.timestamp,
        # Just store the model responses in the list of JSON objects.
        output=json.dumps([output.output for output in outputs]),
    )

    sql = text(
        """
        INSERT INTO outputs (id, prompt_id, timestamp, output)
        VALUES (:id, :prompt_id, :timestamp, :output)
        RETURNING *
        """
    )
    recorded_output = await self._insert_pydantic_model(output_db, sql)
    logger.debug(f"Recorded output: {recorded_output}")
    return recorded_output
202147 async def record_alerts (self , alerts : List [Alert ]) -> List [Alert ]:
203148 if not alerts :
@@ -220,16 +165,24 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
220165 try :
221166 result = tg .create_task (self ._insert_pydantic_model (alert , sql ))
222167 alerts_tasks .append (result )
223- if result and alert .trigger_category == "critical" :
224- await alert_queue .put (f"New alert detected: { alert .timestamp } " )
225168 except Exception as e :
226169 logger .error (f"Failed to record alert: { alert } ." , error = str (e ))
227- recorded_alerts = [alert .result () for alert in alerts_tasks ]
228- logger .info (f"Recorded alerts: { recorded_alerts } " )
170+
171+ recorded_alerts = []
172+ for alert_coro in alerts_tasks :
173+ alert_result = alert_coro .result ()
174+ recorded_alerts .append (alert_result )
175+ if alert_result and alert_result .trigger_category == "critical" :
176+ await alert_queue .put (f"New alert detected: { alert .timestamp } " )
177+
178+ logger .debug (f"Recorded alerts: { recorded_alerts } " )
229179 return recorded_alerts
230180
async def record_context(self, context: PipelineContext) -> None:
    """Persist an entire pipeline context to the DB.

    Records the input request first, then all of its output responses,
    then every alert the pipeline raised.
    """
    output_count = len(context.output_responses)
    alert_count = len(context.alerts_raised)
    logger.info(
        f"Recording context in DB. Output chunks: {output_count}. "
        f"Alerts: {alert_count}."
    )
    await self.record_request(context.input_request)
    await self.record_outputs(context.output_responses)
    await self.record_alerts(context.alerts_raised)
0 commit comments