 import asyncio
-import copy
-import datetime
 import json
-import uuid
 from pathlib import Path
-from typing import AsyncGenerator, AsyncIterator, List, Optional
+from typing import List, Optional

 import structlog
-from litellm import ChatCompletionRequest, ModelResponse
 from pydantic import BaseModel
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
     GetAlertsWithPromptAndOutputRow,
     GetPromptWithOutputsRow,
 )
+from codegate.pipeline.base import PipelineContext

 logger = structlog.get_logger("codegate")
 alert_queue = asyncio.Queue()
@@ -103,97 +100,51 @@ async def _insert_pydantic_model(
             logger.error(f"Failed to insert model: {model}.", error=str(e))
             return None

-    async def record_request(
-        self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str
-    ) -> Optional[Prompt]:
-        request_str = None
-        if isinstance(normalized_request, BaseModel):
-            request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True)
-        else:
-            try:
-                request_str = json.dumps(normalized_request)
-            except Exception as e:
-                logger.error(f"Failed to serialize output: {normalized_request}", error=str(e))
-
-        if request_str is None:
-            logger.warning("No request found to record.")
-            return
-
-        # Create a new prompt record
-        prompt_params = Prompt(
-            id=str(uuid.uuid4()),  # Generate a new UUID for the prompt
-            timestamp=datetime.datetime.now(datetime.timezone.utc),
-            provider=provider_str,
-            type="fim" if is_fim_request else "chat",
-            request=request_str,
-        )
+    async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
+        if prompt_params is None:
+            return None
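+        # Insert the prompt row; RETURNING * hands the stored record back to the caller.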
         sql = text(
             """
             INSERT INTO prompts (id, timestamp, provider, request, type)
             VALUES (:id, :timestamp, :provider, :request, :type)
             RETURNING *
             """
         )
-        return await self._insert_pydantic_model(prompt_params, sql)
-
-    async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]:
-        output_params = Output(
-            id=str(uuid.uuid4()),
-            prompt_id=prompt.id,
-            timestamp=datetime.datetime.now(datetime.timezone.utc),
-            output=output_str,
+        recorded_request = await self._insert_pydantic_model(prompt_params, sql)
+        logger.debug(f"Recorded request: {recorded_request}")
+        return recorded_request
+
+    async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
+        if not outputs:
+            return
+
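+        # The first chunk supplies the id, prompt_id, and timestamp reused for the stored row.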
+        first_output = outputs[0]
+        # Create a single DB entry, but encode all of the chunks in the stream as a list
+        # of JSON objects in the field `output`
+        output_db = Output(
+            id=first_output.id,
+            prompt_id=first_output.prompt_id,
+            timestamp=first_output.timestamp,
+            output=first_output.output,
         )
+        full_outputs = []
+        # Just store the model responses in the list of JSON objects.
+        for output in outputs:
+            full_outputs.append(output.output)
+        output_db.output = json.dumps(full_outputs)
+
         sql = text(
             """
             INSERT INTO outputs (id, prompt_id, timestamp, output)
             VALUES (:id, :prompt_id, :timestamp, :output)
             RETURNING *
             """
         )
-        return await self._insert_pydantic_model(output_params, sql)
-
-    async def record_output_stream(
-        self, prompt: Prompt, model_response: AsyncIterator
-    ) -> AsyncGenerator:
-        output_chunks = []
-        async for chunk in model_response:
-            if isinstance(chunk, BaseModel):
-                chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True)
-                output_chunks.append(chunk_to_record)
-            elif isinstance(chunk, dict):
-                output_chunks.append(copy.deepcopy(chunk))
-            else:
-                output_chunks.append({"chunk": str(chunk)})
-            yield chunk
-
-        if output_chunks:
-            # Record the output chunks
-            output_str = json.dumps(output_chunks)
-            await self._record_output(prompt, output_str)
-
-    async def record_output_non_stream(
-        self, prompt: Optional[Prompt], model_response: ModelResponse
-    ) -> Optional[Output]:
-        if prompt is None:
-            logger.warning("No prompt found to record output.")
-            return
+        recorded_output = await self._insert_pydantic_model(output_db, sql)
+        logger.debug(f"Recorded output: {recorded_output}")
+        return recorded_output

-        output_str = None
-        if isinstance(model_response, BaseModel):
-            output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
-        else:
-            try:
-                output_str = json.dumps(model_response)
-            except Exception as e:
-                logger.error(f"Failed to serialize output: {model_response}", error=str(e))
-
-        if output_str is None:
-            logger.warning("No output found to record.")
-            return
-
-        return await self._record_output(prompt, output_str)
-
-    async def record_alerts(self, alerts: List[Alert]) -> None:
+    async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         if not alerts:
             return
         sql = text(
@@ -208,15 +159,33 @@ async def record_alerts(self, alerts: List[Alert]) -> None:
             """
         )
         # We can insert each alert independently in parallel.
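+        # Keep the task handles so the inserted rows can be collected once the group finishes.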
+        alerts_tasks = []
         async with asyncio.TaskGroup() as tg:
             for alert in alerts:
                 try:
                     result = tg.create_task(self._insert_pydantic_model(alert, sql))
-                    if result and alert.trigger_category == "critical":
-                        await alert_queue.put(f"New alert detected: {alert.timestamp}")
+                    alerts_tasks.append(result)
                 except Exception as e:
                     logger.error(f"Failed to record alert: {alert}.", error=str(e))
-        return None
+
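+        # The TaskGroup has awaited every insert by now, so calling .result() cannot block.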
+        recorded_alerts = []
+        for alert_task in alerts_tasks:
+            alert_result = alert_task.result()
+            recorded_alerts.append(alert_result)
+            if alert_result and alert_result.trigger_category == "critical":
+                await alert_queue.put(f"New alert detected: {alert_result.timestamp}")
+
+        logger.debug(f"Recorded alerts: {recorded_alerts}")
+        return recorded_alerts
+
+    async def record_context(self, context: PipelineContext) -> None:
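+        # Persist everything the pipeline captured: the request, its output chunks, and any alerts.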
+        logger.info(
+            f"Recording context in DB. Output chunks: {len(context.output_responses)}. "
+            f"Alerts: {len(context.alerts_raised)}."
+        )
+        await self.record_request(context.input_request)
+        await self.record_outputs(context.output_responses)
+        await self.record_alerts(context.alerts_raised)


 class DbReader(DbCodeGate):