Skip to content

Commit 2723ab1

Browse files
authored
Merge pull request #236 from Runware/feature-textInference
Add textInference support
2 parents 36399b5 + 6a88757 commit 2723ab1

3 files changed

Lines changed: 214 additions & 4 deletions

File tree

runware/base.py

Lines changed: 144 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
IVectorize,
5353
I3dInference,
5454
I3d,
55+
ITextInference,
56+
IText,
5557
)
5658
from .types import IImage, IError, SdkType, ListenerType
5759
from .utils import (
@@ -75,6 +77,7 @@
7577
process_image,
7678
createAsyncTaskResponse,
7779
VIDEO_INITIAL_TIMEOUT,
80+
TEXT_INITIAL_TIMEOUT,
7881
VIDEO_POLLING_DELAY,
7982
WEBHOOK_TIMEOUT,
8083
IMAGE_INFERENCE_TIMEOUT,
@@ -84,6 +87,7 @@
8487
MODEL_UPLOAD_TIMEOUT,
8588
IMAGE_INITIAL_TIMEOUT,
8689
IMAGE_POLLING_DELAY,
90+
TEXT_POLLING_DELAY,
8791
AUDIO_INITIAL_TIMEOUT,
8892
AUDIO_INFERENCE_TIMEOUT,
8993
AUDIO_POLLING_DELAY,
@@ -1880,18 +1884,25 @@ async def _inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsync
18801884
await self.ensureConnection()
18811885
return await self._request3d(request3d)
18821886

1887+
async def textInference(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]:
1888+
return await self._retry_with_reconnect(self._textInference, requestText)
1889+
1890+
async def _textInference(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]:
1891+
await self.ensureConnection()
1892+
return await self._requestText(requestText)
1893+
18831894
async def getResponse(
18841895
self,
18851896
taskUUID: str,
18861897
numberResults: Optional[int] = 1,
1887-
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d]]:
1898+
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText]]:
18881899
return await self._retry_with_reconnect(self._getResponse, taskUUID, numberResults)
18891900

18901901
async def _getResponse(
18911902
self,
18921903
taskUUID: str,
18931904
numberResults: Optional[int] = 1,
1894-
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d]]:
1905+
) -> Union[List[IVideo], List[IAudio], List[IVideoToText], List[IImage], List[I3d], List[IText]]:
18951906
await self.ensureConnection()
18961907

18971908
return await self._pollResults(
@@ -2059,6 +2070,121 @@ async def _request3d(self, request3d: I3dInference) -> Union[List[I3d], IAsyncTa
20592070
"3d-inference-initial",
20602071
)
20612072

2073+
def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]:
2074+
request_object: Dict[str, Any] = {
2075+
"taskType": ETaskType.TEXT_INFERENCE.value,
2076+
"taskUUID": requestText.taskUUID,
2077+
"model": requestText.model,
2078+
"deliveryMethod": requestText.deliveryMethod,
2079+
"messages": [asdict(m) for m in requestText.messages],
2080+
}
2081+
if requestText.maxTokens is not None:
2082+
request_object["maxTokens"] = requestText.maxTokens
2083+
if requestText.temperature is not None:
2084+
request_object["temperature"] = requestText.temperature
2085+
if requestText.topP is not None:
2086+
request_object["topP"] = requestText.topP
2087+
if requestText.topK is not None:
2088+
request_object["topK"] = requestText.topK
2089+
if requestText.seed is not None:
2090+
request_object["seed"] = requestText.seed
2091+
if requestText.stopSequences is not None:
2092+
request_object["stopSequences"] = requestText.stopSequences
2093+
if requestText.includeCost is not None:
2094+
request_object["includeCost"] = requestText.includeCost
2095+
self._addTextProviderSettings(request_object, requestText)
2096+
return request_object
2097+
2098+
async def _requestText(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]:
2099+
requestText.taskUUID = requestText.taskUUID or getUUID()
2100+
request_object = self._buildTextRequest(requestText)
2101+
await self.send([request_object])
2102+
return await self._handleInitialTextResponse(
2103+
requestText.taskUUID,
2104+
requestText.deliveryMethod,
2105+
"text-inference-initial",
2106+
)
2107+
2108+
async def _handleInitialTextResponse(
2109+
self,
2110+
task_uuid: str,
2111+
delivery_method: Union[str, EDeliveryMethod] = EDeliveryMethod.SYNC,
2112+
debug_key: str = "text-inference-initial",
2113+
) -> Union[List[IText], IAsyncTaskResponse]:
2114+
lis = self.globalListener(taskUUID=task_uuid)
2115+
delivery_method_enum = delivery_method if isinstance(delivery_method, EDeliveryMethod) else EDeliveryMethod(delivery_method)
2116+
2117+
async def check_initial_response(resolve: callable, reject: callable, *args: Any) -> bool:
2118+
if not self.connected() or not self.isWebsocketReadyState():
2119+
reject(ConnectionError(
2120+
f"Connection lost while waiting for text response | "
2121+
f"TaskUUID: {task_uuid} | "
2122+
f"Delivery method: {delivery_method_enum}"
2123+
))
2124+
return True
2125+
2126+
async with self._messages_lock:
2127+
response_list = self._globalMessages.get(task_uuid, [])
2128+
2129+
if not response_list:
2130+
return False
2131+
2132+
response = response_list[0]
2133+
2134+
if self._is_error_response(response):
2135+
del self._globalMessages[task_uuid]
2136+
raise RunwareAPIError(response)
2137+
2138+
if response.get("status") == "success" or response.get("text") is not None:
2139+
del self._globalMessages[task_uuid]
2140+
resolve([response])
2141+
return True
2142+
2143+
if delivery_method_enum is EDeliveryMethod.ASYNC:
2144+
del self._globalMessages[task_uuid]
2145+
async_response = createAsyncTaskResponse(response)
2146+
resolve([async_response])
2147+
return True
2148+
2149+
return False
2150+
2151+
try:
2152+
initial_response = await getIntervalWithPromise(
2153+
check_initial_response,
2154+
debugKey=debug_key,
2155+
timeOutDuration=TIMEOUT_DURATION if delivery_method_enum is EDeliveryMethod.SYNC else TEXT_INITIAL_TIMEOUT,
2156+
)
2157+
except RunwareAPIError:
2158+
raise
2159+
except Exception as e:
2160+
if not self.connected() or not self.isWebsocketReadyState():
2161+
raise ConnectionError(
2162+
f"Connection lost while waiting for text response | "
2163+
f"TaskUUID: {task_uuid} | "
2164+
f"Delivery method: {delivery_method_enum}"
2165+
)
2166+
if delivery_method_enum is EDeliveryMethod.SYNC:
2167+
error_msg = (
2168+
f"Timeout waiting for text generation | "
2169+
f"TaskUUID: {task_uuid} | "
2170+
f"Timeout: {TIMEOUT_DURATION}ms | "
2171+
f"Original error: {str(e)}"
2172+
)
2173+
raise ConnectionError(error_msg)
2174+
initial_response = None
2175+
finally:
2176+
lis["destroy"]()
2177+
2178+
if not initial_response or len(initial_response) == 0:
2179+
raise ConnectionError(
2180+
f"No initial response received for text generation | delivery_method={delivery_method_enum} | taskUUID={task_uuid}"
2181+
)
2182+
2183+
if isinstance(initial_response[0], IAsyncTaskResponse):
2184+
return initial_response[0]
2185+
2186+
return instantiateDataclassList(IText, initial_response)
2187+
20622188
def _buildImageRequest(self, requestImage: IImageInference, prompt: Optional[str], control_net_data_dicts: List[Dict], instant_id_data: Optional[Dict], ip_adapters_data: Optional[List[Dict]], ace_plus_plus_data: Optional[Dict], pulid_data: Optional[Dict]) -> Dict[str, Any]:
20632189
request_object = {
20642190
"taskType": ETaskType.IMAGE_INFERENCE.value,
@@ -2580,6 +2706,13 @@ def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio
25802706
if provider_dict:
25812707
request_object["providerSettings"] = provider_dict
25822708

2709+
def _addTextProviderSettings(self, request_object: Dict[str, Any], requestText: ITextInference) -> None:
2710+
if not requestText.providerSettings:
2711+
return
2712+
provider_dict = requestText.providerSettings.to_request_dict()
2713+
if provider_dict:
2714+
request_object["providerSettings"] = provider_dict
2715+
25832716
async def _handleInitialAudioResponse(
25842717
self,
25852718
task_uuid: str,
@@ -2717,7 +2850,7 @@ async def _pollResults(
27172850
self,
27182851
task_uuid: str,
27192852
number_results: Optional[int],
2720-
) -> Union[List[IVideo], List[IVideoToText], List[IAudio], List[IImage], List[I3d]]:
2853+
) -> Union[List[IVideo], List[IVideoToText], List[IAudio], List[IImage], List[I3d], List[IText]]:
27212854
# Default to 1 if number_results is None
27222855
if number_results is None:
27232856
number_results = 1
@@ -2726,7 +2859,7 @@ async def _pollResults(
27262859
lis = self.globalListener(taskUUID=task_uuid)
27272860

27282861
task_type = None
2729-
response_cls: Optional[Union[IVideo, IVideoToText, IAudio, IImage, I3d]] = None
2862+
response_cls: Optional[Union[IVideo, IVideoToText, IAudio, IImage, I3d, IText]] = None
27302863
max_polls: int = MAX_POLLS
27312864
polling_delay: int = VIDEO_POLLING_DELAY
27322865
timeout_message: str = f"Polling timeout after {MAX_POLLS} polls"
@@ -2775,6 +2908,13 @@ def configure_from_task_type(task_type: Optional[str]) -> Optional[tuple]:
27752908
VIDEO_POLLING_DELAY,
27762909
f"3d generation timeout after {MAX_POLLS} polls"
27772910
)
2911+
case ETaskType.TEXT_INFERENCE.value:
2912+
return (
2913+
IText,
2914+
MAX_POLLS,
2915+
TEXT_POLLING_DELAY,
2916+
f"Text generation timeout after {MAX_POLLS} polls"
2917+
)
27782918
case _:
27792919
raise ValueError(f"Unsupported task type for polling: {task_type}")
27802920

runware/types.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class ETaskType(Enum):
4343
MODEL_SEARCH = "modelSearch"
4444
VIDEO_INFERENCE = "videoInference"
4545
INFERENCE_3D = "3dInference"
46+
TEXT_INFERENCE = "textInference"
4647
AUDIO_INFERENCE = "audioInference"
4748
VIDEO_CAPTION = "caption"
4849
MEDIA_STORAGE = "mediaStorage"
@@ -1470,6 +1471,59 @@ class I3d:
14701471
outputs: Optional[I3dOutput] = None
14711472

14721473

1474+
@dataclass
1475+
class ITextInferenceMessage:
1476+
role: str
1477+
content: str
1478+
1479+
1480+
@dataclass
1481+
class ITextInferenceUsage:
1482+
promptTokens: Optional[int] = None
1483+
completionTokens: Optional[int] = None
1484+
totalTokens: Optional[int] = None
1485+
thinkingTokens: Optional[int] = None
1486+
1487+
1488+
@dataclass
1489+
class IGoogleTextProviderSettings(BaseProviderSettings):
1490+
thinkingLevel: Optional[str] = None
1491+
1492+
@property
1493+
def provider_key(self) -> str:
1494+
return "google"
1495+
1496+
1497+
TextProviderSettings = IGoogleTextProviderSettings
1498+
1499+
1500+
@dataclass
1501+
class ITextInference:
1502+
model: str
1503+
messages: List[ITextInferenceMessage]
1504+
taskUUID: Optional[str] = None
1505+
deliveryMethod: str = "sync"
1506+
maxTokens: Optional[int] = None
1507+
temperature: Optional[float] = None
1508+
topP: Optional[float] = None
1509+
topK: Optional[int] = None
1510+
seed: Optional[int] = None
1511+
stopSequences: Optional[List[str]] = None
1512+
includeCost: Optional[bool] = None
1513+
providerSettings: Optional[TextProviderSettings] = None
1514+
1515+
1516+
@dataclass
1517+
class IText:
1518+
taskType: str
1519+
taskUUID: str
1520+
text: Optional[str] = None
1521+
finishReason: Optional[str] = None
1522+
usage: Optional[ITextInferenceUsage] = None
1523+
cost: Optional[float] = None
1524+
status: Optional[str] = None
1525+
1526+
14731527
@dataclass
14741528
class IAudio:
14751529
taskType: str

runware/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@
117117
30000
118118
))
119119

120+
# Text initial response timeout (milliseconds)
121+
# Maximum time to wait for the initial text response before falling back to async handling
122+
# Used in: _handleInitialTextResponse() for async delivery method
123+
TEXT_INITIAL_TIMEOUT = int(os.environ.get(
124+
"RUNWARE_TEXT_INITIAL_TIMEOUT",
125+
30000
126+
))
127+
120128
# Audio generation timeout (milliseconds)
121129
# Maximum time to wait for audio generation completion
122130
# Used in: _waitForAudioCompletion() for single audio generation
@@ -149,6 +157,14 @@
149157
1000
150158
))
151159

160+
# Text polling delay (milliseconds)
161+
# Delay between consecutive polling requests for text generation status
162+
# Used in: _pollResults() for checking textInference task progress
163+
TEXT_POLLING_DELAY = int(os.environ.get(
164+
"RUNWARE_TEXT_POLLING_DELAY",
165+
1000
166+
))
167+
152168
# Prompt enhancement timeout (milliseconds)
153169
# Maximum time to wait for prompt enhancement completion
154170
# Used in: promptEnhance() for enhancing text prompts

0 commit comments

Comments
 (0)