44import sys
55import threading
66import warnings
7+ from contextlib import asynccontextmanager
78from json import JSONDecodeError
8- from typing import AsyncGenerator , Dict , Iterator , Optional , Tuple , Union , overload
9+ from typing import (
10+ AsyncGenerator ,
11+ AsyncIterator ,
12+ Dict ,
13+ Iterator ,
14+ Optional ,
15+ Tuple ,
16+ Union ,
17+ overload ,
18+ )
919from urllib .parse import urlencode , urlsplit , urlunsplit
1020
1121import aiohttp
@@ -284,17 +294,19 @@ async def arequest(
284294 request_id : Optional [str ] = None ,
285295 request_timeout : Optional [Union [float , Tuple [float , float ]]] = None ,
286296 ) -> Tuple [Union [OpenAIResponse , AsyncGenerator [OpenAIResponse , None ]], bool , str ]:
287- result = await self .arequest_raw (
288- method .lower (),
289- url ,
290- params = params ,
291- supplied_headers = headers ,
292- files = files ,
293- request_id = request_id ,
294- request_timeout = request_timeout ,
295- )
296- resp , got_stream = await self ._interpret_async_response (result , stream )
297- return resp , got_stream , self .api_key
297+ async with aiohttp_session () as session :
298+ result = await self .arequest_raw (
299+ method .lower (),
300+ url ,
301+ session ,
302+ params = params ,
303+ supplied_headers = headers ,
304+ files = files ,
305+ request_id = request_id ,
306+ request_timeout = request_timeout ,
307+ )
308+ resp , got_stream = await self ._interpret_async_response (result , stream )
309+ return resp , got_stream , self .api_key
298310
299311 def handle_error_response (self , rbody , rcode , resp , rheaders , stream_error = False ):
300312 try :
@@ -514,6 +526,7 @@ async def arequest_raw(
514526 self ,
515527 method ,
516528 url ,
529+ session ,
517530 * ,
518531 params = None ,
519532 supplied_headers : Optional [Dict [str , str ]] = None ,
@@ -534,7 +547,6 @@ async def arequest_raw(
534547 timeout = aiohttp .ClientTimeout (
535548 total = request_timeout if request_timeout else TIMEOUT_SECS
536549 )
537- user_set_session = openai .aiosession .get ()
538550
539551 if files :
540552 # TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
@@ -552,11 +564,7 @@ async def arequest_raw(
552564 "timeout" : timeout ,
553565 }
554566 try :
555- if user_set_session :
556- result = await user_set_session .request (** request_kwargs )
557- else :
558- async with aiohttp .ClientSession () as session :
559- result = await session .request (** request_kwargs )
567+ result = await session .request (** request_kwargs )
560568 util .log_info (
561569 "OpenAI API response" ,
562570 path = abs_url ,
@@ -648,3 +656,13 @@ def _interpret_response_line(
648656 rbody , rcode , resp .data , rheaders , stream_error = stream_error
649657 )
650658 return resp
659+
660+
661+ @asynccontextmanager
662+ async def aiohttp_session () -> AsyncIterator [aiohttp .ClientSession ]:
663+ user_set_session = openai .aiosession .get ()
664+ if user_set_session :
665+ yield user_set_session
666+ else :
667+ async with aiohttp .ClientSession () as session :
668+ yield session
0 commit comments