11import asyncio
2+ import json
23import re
34import ssl
45from dataclasses import dataclass
1112from codegate .config import Config
1213from codegate .pipeline .base import PipelineContext
1314from codegate .pipeline .factory import PipelineFactory
15+ from codegate .pipeline .output import OutputPipelineInstance
1416from codegate .pipeline .secrets .manager import SecretsManager
1517from codegate .providers .copilot .mapping import VALIDATED_ROUTES
1618from codegate .providers .copilot .pipeline import (
1719 CopilotChatPipeline ,
1820 CopilotFimPipeline ,
1921 CopilotPipeline ,
2022)
23+ from codegate .providers .copilot .streaming import SSEProcessor
2124
2225logger = structlog .get_logger ("codegate" )
2326
@@ -139,13 +142,13 @@ async def _body_through_pipeline(
139142 path : str ,
140143 headers : list [str ],
141144 body : bytes ,
142- ) -> bytes :
145+ ) -> ( bytes , PipelineContext ) :
143146 logger .debug (f"Processing body through pipeline: { len (body )} bytes" )
144147 strategy = self ._select_pipeline (method , path )
145148 if strategy is None :
146149 # if we didn't select any strategy that would change the request
147150 # let's just pass through the body as-is
148- return body
151+ return body , None
149152 return await strategy .process_body (headers , body )
150153
151154 async def _request_to_target (self , headers : list [str ], body : bytes ):
@@ -154,13 +157,16 @@ async def _request_to_target(self, headers: list[str], body: bytes):
154157 ).encode ()
155158 logger .debug (f"Request Line: { request_line } " )
156159
157- body = await self ._body_through_pipeline (
160+ body , context = await self ._body_through_pipeline (
158161 self .request .method ,
159162 self .request .path ,
160163 headers ,
161164 body ,
162165 )
163166
167+ if context :
168+ self .context_tracking = context
169+
164170 for header in headers :
165171 if header .lower ().startswith ("content-length:" ):
166172 headers .remove (header )
@@ -243,12 +249,13 @@ async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
243249 # we couldn't parse this into an HTTP request, so we just pass through
244250 return data
245251
246- http_request .body = await self ._body_through_pipeline (
252+ http_request .body , context = await self ._body_through_pipeline (
247253 http_request .method ,
248254 http_request .path ,
249255 http_request .headers ,
250256 http_request .body ,
251257 )
258+ self .context_tracking = context
252259
253260 for header in http_request .headers :
254261 if header .lower ().startswith ("content-length:" ):
@@ -549,15 +556,68 @@ def __init__(self, proxy: CopilotProvider):
549556 self .proxy = proxy
550557 self .transport : Optional [asyncio .Transport ] = None
551558
559+ self .headers_sent = False
560+ self .sse_processor : Optional [SSEProcessor ] = None
561+ self .output_pipeline_instance : Optional [OutputPipelineInstance ] = None
562+
552563 def connection_made (self , transport : asyncio .Transport ) -> None :
553564 """Handle successful connection to target"""
554565 self .transport = transport
555566 self .proxy .target_transport = transport
556567
568+ def _process_chunk (self , chunk : bytes ):
569+ records = self .sse_processor .process_chunk (chunk )
570+
571+ for record in records :
572+ if record ["type" ] == "done" :
573+ sse_data = b"data: [DONE]\n \n "
574+ # Add chunk size for DONE message too
575+ chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
576+ self ._proxy_transport_write (chunk_size .encode ())
577+ self ._proxy_transport_write (sse_data )
578+ self ._proxy_transport_write (b"\r \n " )
579+ # Now send the final zero chunk
580+ self ._proxy_transport_write (b"0\r \n \r \n " )
581+ else :
582+ sse_data = f"data: { json .dumps (record ['content' ])} \n \n " .encode ("utf-8" )
583+ chunk_size = hex (len (sse_data ))[2 :] + "\r \n "
584+ self ._proxy_transport_write (chunk_size .encode ())
585+ self ._proxy_transport_write (sse_data )
586+ self ._proxy_transport_write (b"\r \n " )
587+
588+ def _proxy_transport_write (self , data : bytes ):
589+ self .proxy .transport .write (data )
590+
557591 def data_received (self , data : bytes ) -> None :
558592 """Handle data received from target"""
593+ if self .proxy .context_tracking is not None and self .sse_processor is None :
594+ logger .debug ("Tracking context for pipeline processing" )
595+ self .sse_processor = SSEProcessor ()
596+ out_pipeline_processor = self .proxy .pipeline_factory .create_output_pipeline ()
597+ self .output_pipeline_instance = OutputPipelineInstance (
598+ pipeline_steps = out_pipeline_processor .pipeline_steps ,
599+ input_context = self .proxy .context_tracking ,
600+ )
601+
559602 if self .proxy .transport and not self .proxy .transport .is_closing ():
560- self .proxy .transport .write (data )
603+ if not self .sse_processor :
604+ # Pass through non-SSE data unchanged
605+ self .proxy .transport .write (data )
606+ return
607+
608+ # Check if this is the first chunk with headers
609+ if not self .headers_sent :
610+ header_end = data .find (b"\r \n \r \n " )
611+ if header_end != - 1 :
612+ self .headers_sent = True
613+ # Send headers first
614+ headers = data [: header_end + 4 ]
615+ self ._proxy_transport_write (headers )
616+ logger .debug (f"Headers sent: { headers } " )
617+
618+ data = data [header_end + 4 :]
619+
620+ self ._process_chunk (data )
561621
562622 def connection_lost (self , exc : Optional [Exception ]) -> None :
563623 """Handle connection loss to target"""
@@ -570,3 +630,5 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
570630 self .proxy .transport .close ()
571631 except Exception as e :
572632 logger .error (f"Error closing proxy transport: { e } " )
633+
634+ # todo: clear the context to erase the sensitive data
0 commit comments