Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit e6521ba

Browse files
authored
Merge pull request #375 from jhrozek/close_async_task
Cancel the processing_task on connection_lost and on exception
2 parents 9626bfd + 71e6bc9 commit e6521ba

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

src/codegate/providers/copilot/provider.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,7 @@ def __init__(self, proxy: CopilotProvider):
580580
self.sse_processor: Optional[SSEProcessor] = None
581581
self.output_pipeline_instance: Optional[OutputPipelineInstance] = None
582582
self.stream_queue: Optional[asyncio.Queue] = None
583+
self.processing_task: Optional[asyncio.Task] = None
583584

584585
def connection_made(self, transport: asyncio.Transport) -> None:
585586
"""Handle successful connection to target"""
@@ -629,9 +630,7 @@ async def stream_iterator():
629630
StreamingChoices(
630631
finish_reason=choice.get("finish_reason", None),
631632
index=0,
632-
delta=Delta(
633-
content=content, role="assistant"
634-
),
633+
delta=Delta(content=content, role="assistant"),
635634
logprobs=None,
636635
)
637636
)
@@ -663,8 +662,17 @@ async def stream_iterator():
663662
# Now send the final zero chunk
664663
self._proxy_transport_write(b"0\r\n\r\n")
665664

665+
except asyncio.CancelledError:
666+
logger.debug("Stream processing cancelled")
667+
raise
666668
except Exception as e:
667669
logger.error(f"Error processing stream: {e}")
670+
finally:
671+
# Clean up
672+
if self.processing_task and not self.processing_task.done():
673+
self.processing_task.cancel()
674+
if self.proxy.context_tracking and self.proxy.context_tracking.sensitive:
675+
self.proxy.context_tracking.sensitive.secure_cleanup()
668676

669677
def _process_chunk(self, chunk: bytes):
670678
records = self.sse_processor.process_chunk(chunk)
@@ -709,6 +717,7 @@ def data_received(self, data: bytes) -> None:
709717

710718
def connection_lost(self, exc: Optional[Exception]) -> None:
711719
"""Handle connection loss to target"""
720+
712721
if (
713722
not self.proxy._closing
714723
and self.proxy.transport
@@ -719,4 +728,8 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
719728
except Exception as e:
720729
logger.error(f"Error closing proxy transport: {e}")
721730

722-
# todo: clear the context to erase the sensitive data
731+
# Clean up resources
732+
if self.processing_task and not self.processing_task.done():
733+
self.processing_task.cancel()
734+
if self.proxy.context_tracking and self.proxy.context_tracking.sensitive:
735+
self.proxy.context_tracking.sensitive.secure_cleanup()

0 commit comments

Comments
 (0)