5252 IVectorize ,
5353 I3dInference ,
5454 I3d ,
55+ ITextInference ,
56+ IText ,
5557)
5658from .types import IImage , IError , SdkType , ListenerType
5759from .utils import (
7577 process_image ,
7678 createAsyncTaskResponse ,
7779 VIDEO_INITIAL_TIMEOUT ,
80+ TEXT_INITIAL_TIMEOUT ,
7881 VIDEO_POLLING_DELAY ,
7982 WEBHOOK_TIMEOUT ,
8083 IMAGE_INFERENCE_TIMEOUT ,
8487 MODEL_UPLOAD_TIMEOUT ,
8588 IMAGE_INITIAL_TIMEOUT ,
8689 IMAGE_POLLING_DELAY ,
90+ TEXT_POLLING_DELAY ,
8791 AUDIO_INITIAL_TIMEOUT ,
8892 AUDIO_INFERENCE_TIMEOUT ,
8993 AUDIO_POLLING_DELAY ,
@@ -1880,18 +1884,25 @@ async def _inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsync
18801884 await self .ensureConnection ()
18811885 return await self ._request3d (request3d )
18821886
1887+ async def textInference (self , requestText : ITextInference ) -> Union [List [IText ], IAsyncTaskResponse ]:
1888+ return await self ._retry_with_reconnect (self ._textInference , requestText )
1889+
1890+ async def _textInference (self , requestText : ITextInference ) -> Union [List [IText ], IAsyncTaskResponse ]:
1891+ await self .ensureConnection ()
1892+ return await self ._requestText (requestText )
1893+
18831894 async def getResponse (
18841895 self ,
18851896 taskUUID : str ,
18861897 numberResults : Optional [int ] = 1 ,
1887- ) -> Union [List [IVideo ], List [IAudio ], List [IVideoToText ], List [IImage ], List [I3d ]]:
1898+ ) -> Union [List [IVideo ], List [IAudio ], List [IVideoToText ], List [IImage ], List [I3d ], List [ IText ] ]:
18881899 return await self ._retry_with_reconnect (self ._getResponse , taskUUID , numberResults )
18891900
18901901 async def _getResponse (
18911902 self ,
18921903 taskUUID : str ,
18931904 numberResults : Optional [int ] = 1 ,
1894- ) -> Union [List [IVideo ], List [IAudio ], List [IVideoToText ], List [IImage ], List [I3d ]]:
1905+ ) -> Union [List [IVideo ], List [IAudio ], List [IVideoToText ], List [IImage ], List [I3d ], List [ IText ] ]:
18951906 await self .ensureConnection ()
18961907
18971908 return await self ._pollResults (
@@ -2059,6 +2070,121 @@ async def _request3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTa
20592070 "3d-inference-initial" ,
20602071 )
20612072
2073+ def _buildTextRequest (self , requestText : ITextInference ) -> Dict [str , Any ]:
2074+ request_object : Dict [str , Any ] = {
2075+ "taskType" : ETaskType .TEXT_INFERENCE .value ,
2076+ "taskUUID" : requestText .taskUUID ,
2077+ "model" : requestText .model ,
2078+ "deliveryMethod" : requestText .deliveryMethod ,
2079+ "messages" : [asdict (m ) for m in requestText .messages ],
2080+ }
2081+ if requestText .maxTokens is not None :
2082+ request_object ["maxTokens" ] = requestText .maxTokens
2083+ if requestText .temperature is not None :
2084+ request_object ["temperature" ] = requestText .temperature
2085+ if requestText .topP is not None :
2086+ request_object ["topP" ] = requestText .topP
2087+ if requestText .topK is not None :
2088+ request_object ["topK" ] = requestText .topK
2089+ if requestText .seed is not None :
2090+ request_object ["seed" ] = requestText .seed
2091+ if requestText .stopSequences is not None :
2092+ request_object ["stopSequences" ] = requestText .stopSequences
2093+ if requestText .includeCost is not None :
2094+ request_object ["includeCost" ] = requestText .includeCost
2095+ self ._addTextProviderSettings (request_object , requestText )
2096+ return request_object
2097+
2098+ async def _requestText (self , requestText : ITextInference ) -> Union [List [IText ], IAsyncTaskResponse ]:
2099+ requestText .taskUUID = requestText .taskUUID or getUUID ()
2100+ request_object = self ._buildTextRequest (requestText )
2101+ await self .send ([request_object ])
2102+ return await self ._handleInitialTextResponse (
2103+ requestText .taskUUID ,
2104+ requestText .deliveryMethod ,
2105+ "text-inference-initial" ,
2106+ )
2107+
2108+ async def _handleInitialTextResponse (
2109+ self ,
2110+ task_uuid : str ,
2111+ delivery_method : Union [str , EDeliveryMethod ] = EDeliveryMethod .SYNC ,
2112+ debug_key : str = "text-inference-initial" ,
2113+ ) -> Union [List [IText ], IAsyncTaskResponse ]:
2114+ lis = self .globalListener (taskUUID = task_uuid )
2115+ delivery_method_enum = delivery_method if isinstance (delivery_method , EDeliveryMethod ) else EDeliveryMethod (delivery_method )
2116+
2117+ async def check_initial_response (resolve : callable , reject : callable , * args : Any ) -> bool :
2118+ if not self .connected () or not self .isWebsocketReadyState ():
2119+ reject (ConnectionError (
2120+ f"Connection lost while waiting for text response | "
2121+ f"TaskUUID: { task_uuid } | "
2122+ f"Delivery method: { delivery_method_enum } "
2123+ ))
2124+ return True
2125+
2126+ async with self ._messages_lock :
2127+ response_list = self ._globalMessages .get (task_uuid , [])
2128+
2129+ if not response_list :
2130+ return False
2131+
2132+ response = response_list [0 ]
2133+
2134+ if self ._is_error_response (response ):
2135+ del self ._globalMessages [task_uuid ]
2136+ raise RunwareAPIError (response )
2137+
2138+ if response .get ("status" ) == "success" or response .get ("text" ) is not None :
2139+ del self ._globalMessages [task_uuid ]
2140+ resolve ([response ])
2141+ return True
2142+
2143+ if delivery_method_enum is EDeliveryMethod .ASYNC :
2144+ del self ._globalMessages [task_uuid ]
2145+ async_response = createAsyncTaskResponse (response )
2146+ resolve ([async_response ])
2147+ return True
2148+
2149+ return False
2150+
2151+ try :
2152+ initial_response = await getIntervalWithPromise (
2153+ check_initial_response ,
2154+ debugKey = debug_key ,
2155+ timeOutDuration = TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod .SYNC else TEXT_INITIAL_TIMEOUT ,
2156+ )
2157+ except RunwareAPIError :
2158+ raise
2159+ except Exception as e :
2160+ if not self .connected () or not self .isWebsocketReadyState ():
2161+ raise ConnectionError (
2162+ f"Connection lost while waiting for text response | "
2163+ f"TaskUUID: { task_uuid } | "
2164+ f"Delivery method: { delivery_method_enum } "
2165+ )
2166+ if delivery_method_enum is EDeliveryMethod .SYNC :
2167+ error_msg = (
2168+ f"Timeout waiting for text generation | "
2169+ f"TaskUUID: { task_uuid } | "
2170+ f"Timeout: { TIMEOUT_DURATION } ms | "
2171+ f"Original error: { str (e )} "
2172+ )
2173+ raise ConnectionError (error_msg )
2174+ initial_response = None
2175+ finally :
2176+ lis ["destroy" ]()
2177+
2178+ if not initial_response or len (initial_response ) == 0 :
2179+ raise ConnectionError (
2180+ f"No initial response received for text generation | delivery_method={ delivery_method_enum } | taskUUID={ task_uuid } "
2181+ )
2182+
2183+ if isinstance (initial_response [0 ], IAsyncTaskResponse ):
2184+ return initial_response [0 ]
2185+
2186+ return instantiateDataclassList (IText , initial_response )
2187+
20622188 def _buildImageRequest (self , requestImage : IImageInference , prompt : Optional [str ], control_net_data_dicts : List [Dict ], instant_id_data : Optional [Dict ], ip_adapters_data : Optional [List [Dict ]], ace_plus_plus_data : Optional [Dict ], pulid_data : Optional [Dict ]) -> Dict [str , Any ]:
20632189 request_object = {
20642190 "taskType" : ETaskType .IMAGE_INFERENCE .value ,
@@ -2580,6 +2706,13 @@ def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio
25802706 if provider_dict :
25812707 request_object ["providerSettings" ] = provider_dict
25822708
2709+ def _addTextProviderSettings (self , request_object : Dict [str , Any ], requestText : ITextInference ) -> None :
2710+ if not requestText .providerSettings :
2711+ return
2712+ provider_dict = requestText .providerSettings .to_request_dict ()
2713+ if provider_dict :
2714+ request_object ["providerSettings" ] = provider_dict
2715+
25832716 async def _handleInitialAudioResponse (
25842717 self ,
25852718 task_uuid : str ,
@@ -2717,7 +2850,7 @@ async def _pollResults(
27172850 self ,
27182851 task_uuid : str ,
27192852 number_results : Optional [int ],
2720- ) -> Union [List [IVideo ], List [IVideoToText ], List [IAudio ], List [IImage ], List [I3d ]]:
2853+ ) -> Union [List [IVideo ], List [IVideoToText ], List [IAudio ], List [IImage ], List [I3d ], List [ IText ] ]:
27212854 # Default to 1 if number_results is None
27222855 if number_results is None :
27232856 number_results = 1
@@ -2726,7 +2859,7 @@ async def _pollResults(
27262859 lis = self .globalListener (taskUUID = task_uuid )
27272860
27282861 task_type = None
2729- response_cls : Optional [Union [IVideo , IVideoToText , IAudio , IImage , I3d ]] = None
2862+ response_cls : Optional [Union [IVideo , IVideoToText , IAudio , IImage , I3d , IText ]] = None
27302863 max_polls : int = MAX_POLLS
27312864 polling_delay : int = VIDEO_POLLING_DELAY
27322865 timeout_message : str = f"Polling timeout after { MAX_POLLS } polls"
@@ -2775,6 +2908,13 @@ def configure_from_task_type(task_type: Optional[str]) -> Optional[tuple]:
27752908 VIDEO_POLLING_DELAY ,
27762909 f"3d generation timeout after { MAX_POLLS } polls"
27772910 )
2911+ case ETaskType .TEXT_INFERENCE .value :
2912+ return (
2913+ IText ,
2914+ MAX_POLLS ,
2915+ TEXT_POLLING_DELAY ,
2916+ f"Text generation timeout after { MAX_POLLS } polls"
2917+ )
27782918 case _:
27792919 raise ValueError (f"Unsupported task type for polling: { task_type } " )
27802920
0 commit comments