Skip to content

Commit e36fc33

Browse files
authored
Merge pull request #153 from Runware/feature-audioInference
Add audio inference support with ElevenLabs provider settings
2 parents 4e750f9 + 1205e43 commit e36fc33

2 files changed

Lines changed: 229 additions & 0 deletions

File tree

runware/base.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
IControlNet,
3434
IVideo,
3535
IVideoInference,
36+
IAudio,
37+
IAudioInference,
38+
IAudioSettings,
3639
IGoogleProviderSettings,
3740
IKlingAIProviderSettings,
3841
IFrameImage,
@@ -1616,6 +1619,161 @@ def _processVideoPollingResponse(self, responses: List[Dict[str, Any]]) -> List[
16161619
def _hasPendingVideos(self, responses: List[Dict[str, Any]]) -> bool:
16171620
return any(response.get("status") == "pending" for response in responses)
16181621

1622+
async def audioInference(self, requestAudio: IAudioInference) -> List[IAudio]:
    """
    Run an audio inference task and return the generated audio results.

    The connection is ensured first; the actual request is then handed to
    asyncRetry as a zero-argument callable.

    :param requestAudio: Description of the audio task to submit.
    :return: The list of audio results produced by _requestAudio.
    """
    await self.ensureConnection()

    def attempt():
        # Returns the coroutine; asyncRetry is responsible for awaiting it.
        return self._requestAudio(requestAudio)

    return await asyncRetry(attempt)
1625+
1626+
async def _requestAudio(self, requestAudio: IAudioInference) -> List[IAudio]:
    """
    Send a single audio inference request and collect its results.

    A task UUID is generated when the request does not carry one yet; the
    value is stored back on the request object.

    :param requestAudio: Description of the audio task to submit (mutated: taskUUID may be filled in).
    :return: The audio results gathered by _handleInitialAudioResponse.
    """
    if not requestAudio.taskUUID:
        requestAudio.taskUUID = getUUID()

    payload = self._buildAudioRequest(requestAudio)
    await self.send([payload])

    task_uuid = requestAudio.taskUUID
    expected_results = requestAudio.numberResults
    return await self._handleInitialAudioResponse(task_uuid, expected_results)
1631+
1632+
def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]:
    """
    Assemble the request dictionary for an audio inference task.

    The mandatory keys are always present. "positivePrompt" (stripped of
    surrounding whitespace) and "duration" are added whenever they are set on
    the request; this method does not itself check whether a composition plan
    is in use. The remaining optional fields, audio settings and provider
    settings are contributed by the dedicated helper methods.

    :param requestAudio: Description of the audio task to submit.
    :return: The request dictionary ready to be sent.
    """
    payload: Dict[str, Any] = dict(
        deliveryMethod=requestAudio.deliveryMethod,
        taskType=ETaskType.AUDIO_INFERENCE.value,
        taskUUID=requestAudio.taskUUID,
        model=requestAudio.model,
        numberResults=requestAudio.numberResults,
    )

    prompt = requestAudio.positivePrompt
    if prompt is not None:
        payload["positivePrompt"] = prompt.strip()

    length = requestAudio.duration
    if length is not None:
        payload["duration"] = length

    # Each helper mutates the payload in place; the call order is kept so the
    # resulting key order matches the sequence the fields are added in.
    for contribute in (
        self._addOptionalAudioFields,
        self._addAudioSettings,
        self._addAudioProviderSettings,
    ):
        contribute(payload, requestAudio)

    return payload
1654+
1655+
def _addOptionalAudioFields(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None:
1656+
optional_fields = [
1657+
"outputType", "outputFormat", "includeCost", "uploadEndpoint", "webhookURL"
1658+
]
1659+
1660+
for field in optional_fields:
1661+
value = getattr(requestAudio, field, None)
1662+
if value is not None:
1663+
request_object[field] = value
1664+
1665+
def _addAudioSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None:
1666+
if requestAudio.audioSettings:
1667+
audio_settings_dict = asdict(requestAudio.audioSettings)
1668+
# Remove None values
1669+
audio_settings_dict = {k: v for k, v in audio_settings_dict.items() if v is not None}
1670+
if audio_settings_dict:
1671+
request_object["audioSettings"] = audio_settings_dict
1672+
1673+
def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None:
1674+
if not requestAudio.providerSettings:
1675+
return
1676+
provider_dict = requestAudio.providerSettings.to_request_dict()
1677+
if provider_dict:
1678+
request_object["providerSettings"] = provider_dict
1679+
1680+
async def _handleInitialAudioResponse(self, task_uuid: str, number_results: int) -> List[IAudio]:
1681+
if number_results == 1:
1682+
# Single result - wait for completion
1683+
response = await self._waitForAudioCompletion(task_uuid)
1684+
return [response] if response else []
1685+
else:
1686+
# Multiple results - use polling
1687+
return await self._pollForAudioResults(task_uuid, number_results)
1688+
1689+
async def _waitForAudioCompletion(self, task_uuid: str) -> Optional[IAudio]:
    """
    Wait for the single response of an audio task and convert it to an IAudio.

    A listener is registered for the task, and getIntervalWithPromise is given
    a check callback that inspects the stored message for the task. The
    listener and any stored message are always released when the wait ends,
    including on errors and task cancellation.

    :param task_uuid: Identifier of the task to wait for.
    :return: The audio result, or None when the wait produced no response.
    :raises RunwareAPIError: If the response carries a "code" key.
    Any exception raised by getIntervalWithPromise propagates unchanged
    (presumably on rejection or timeout - confirm against that helper).
    """
    lis = self.globalListener(taskUUID=task_uuid)

    def check(resolve: Callable, reject: Callable, *args: Any) -> bool:
        stored = self._globalMessages.get(task_uuid)
        if not stored:
            return False

        # The store may hold either a list of messages or a single message.
        audio_response = stored[0] if isinstance(stored, list) else stored
        if not audio_response:
            return False

        if audio_response.get("error"):
            reject(audio_response)
            return True

        del self._globalMessages[task_uuid]
        resolve(audio_response)
        return True

    try:
        response = await getIntervalWithPromise(
            check, debugKey="audio-inference", timeOutDuration=self._timeout
        )
    finally:
        # finally (rather than except Exception) also covers cancellation and
        # guarantees the listener is destroyed exactly once.
        lis["destroy"]()
        # Drop any message left behind (e.g. a rejected error response) so a
        # retry with the same task UUID does not read stale data.
        self._globalMessages.pop(task_uuid, None)

    if not response:
        return None

    if "code" in response:
        raise RunwareAPIError(response)

    return self._createAudioFromResponse(response)
1723+
1724+
async def _pollForAudioResults(self, task_uuid: str, number_results: int) -> List[IAudio]:
1725+
completed_results = []
1726+
lis = self.globalListener(taskUUID=task_uuid)
1727+
1728+
try:
1729+
while len(completed_results) < number_results:
1730+
responses = self._globalMessages.get(task_uuid, [])
1731+
if not isinstance(responses, list):
1732+
responses = [responses] if responses else []
1733+
1734+
processed_responses = self._processAudioPollingResponse(responses)
1735+
completed_results.extend(processed_responses)
1736+
1737+
if len(completed_results) >= number_results:
1738+
break
1739+
1740+
await asyncio.sleep(1) # Poll every second
1741+
1742+
# Clean up
1743+
if task_uuid in self._globalMessages:
1744+
del self._globalMessages[task_uuid]
1745+
lis["destroy"]()
1746+
1747+
return [self._createAudioFromResponse(response) for response in completed_results[:number_results]]
1748+
1749+
except Exception as e:
1750+
lis["destroy"]()
1751+
raise e
1752+
1753+
def _processAudioPollingResponse(self, responses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
1754+
completed_results = []
1755+
1756+
for response in responses:
1757+
if response.get("code"):
1758+
raise RunwareAPIError(response)
1759+
status = response.get("status")
1760+
if status == "success":
1761+
completed_results.append(response)
1762+
1763+
return completed_results
1764+
1765+
def _createAudioFromResponse(self, response: Dict[str, Any]) -> IAudio:
    """
    Build an IAudio object from a raw response dictionary.

    "taskType" and "taskUUID" default to an empty string when absent; every
    other field defaults to None.

    :param response: Raw response dictionary for one audio result.
    :return: The populated IAudio instance.
    """
    optional_keys = (
        "status",
        "audioUUID",
        "audioURL",
        "audioBase64Data",
        "audioDataURI",
        "cost",
    )
    optional_values = {key: response.get(key) for key in optional_keys}

    return IAudio(
        taskType=response.get("taskType", ""),
        taskUUID=response.get("taskUUID", ""),
        **optional_values,
    )
1776+
16191777
def connected(self) -> bool:
16201778
"""
16211779
Check if the current WebSocket connection is active and authenticated.

runware/types.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class ETaskType(Enum):
3939
MODEL_UPLOAD = "modelUpload"
4040
MODEL_SEARCH = "modelSearch"
4141
VIDEO_INFERENCE = "videoInference"
42+
AUDIO_INFERENCE = "audioInference"
4243
GET_RESPONSE = "getResponse"
4344

4445

@@ -98,6 +99,8 @@ class EOpenPosePreProcessor(Enum):
9899
# Define the types using Literal
99100
# Delivery forms accepted for outputs.
IOutputType = Literal[
    "base64Data",
    "dataURI",
    "URL",
]
# Accepted output file formats.
IOutputFormat = Literal[
    "JPG",
    "PNG",
    "WEBP",
]
# Delivery forms accepted for audio outputs.
IAudioOutputType = Literal[
    "base64Data",
    "dataURI",
    "URL",
]
# Accepted audio output file formats.
IAudioOutputFormat = Literal["MP3"]
101104

102105

103106
@dataclass
@@ -481,6 +484,33 @@ class IImageCaption:
481484
template: Optional[str] = None
482485

483486

487+
@dataclass
class IAudioSettings:
    """
    Optional encoding parameters for generated audio.

    Both fields default to None. The ranges and defaults noted on the fields
    are documentation only; nothing in this class enforces them.
    """

    # Min: 8000, Max: 48000, Default: 44100
    sampleRate: Optional[int] = None
    # Min: 32, Max: 320, Default: 128
    bitrate: Optional[int] = None
491+
492+
493+
@dataclass
class IElevenLabsCompositionSection:
    """
    One section of an ElevenLabs composition plan.

    The length limits noted on the fields are documentation only; nothing in
    this class enforces them.
    """

    # Name of the section (1-100 characters).
    sectionName: str
    # Styles that should be present in this section.
    positiveLocalStyles: List[str]
    # Styles that should not be present in this section.
    negativeLocalStyles: List[str]
    # Lyrics of the section.
    lines: List[str]
    # Duration in seconds (3-120s).
    duration: Optional[int] = None
500+
501+
502+
@dataclass
class IElevenLabsCompositionPlan:
    """
    Structure of a song for ElevenLabs: global styles plus its sections.

    All three fields are required.
    """

    # Styles that should be present in the entire song.
    positiveGlobalStyles: List[str]
    # Styles that should not be present in the entire song.
    negativeGlobalStyles: List[str]
    # Sections of the song.
    sections: List[IElevenLabsCompositionSection]
507+
508+
509+
@dataclass
class IElevenLabsMusicSettings:
    """
    Music-specific ElevenLabs settings; wraps the required composition plan.
    """

    # Music composition structure.
    compositionPlan: IElevenLabsCompositionPlan
512+
513+
484514
@dataclass
485515
class IImageToText:
486516
taskType: ETaskType
@@ -733,8 +763,19 @@ def provider_key(self) -> str:
733763
return "vidu"
734764

735765

766+
@dataclass
class IElevenLabsProviderSettings(BaseProviderSettings):
    """
    Provider settings for ElevenLabs.

    Holds the optional music settings and identifies itself with the provider
    key "elevenlabs".
    """

    # Music generation settings; left out when None.
    music: Optional[IElevenLabsMusicSettings] = None

    @property
    def provider_key(self) -> str:
        """Return the key that identifies this provider."""
        provider_name = "elevenlabs"
        return provider_name
773+
774+
736775
# Union of the provider settings classes accepted for video inference.
VideoProviderSettings = (
    IKlingAIProviderSettings
    | IGoogleProviderSettings
    | IMinimaxProviderSettings
    | IBytedanceProviderSettings
    | IPixverseProviderSettings
    | IViduProviderSettings
)

# Provider settings accepted for audio inference.
AudioProviderSettings = IElevenLabsProviderSettings
778+
738779
@dataclass
739780
class IVideoInference:
740781
model: str
@@ -762,6 +803,24 @@ class IVideoInference:
762803
providerSettings: Optional[VideoProviderSettings] = None
763804
speech: Optional[IPixverseSpeechSettings] = None
764805

806+
807+
@dataclass
class IAudioInference:
    """
    Parameters of an audio inference request.

    Only the model is required; every other field has a default. The value
    ranges noted on the fields are documentation only and are not enforced by
    this class.
    """

    model: str
    # Optional when using composition plan.
    positivePrompt: Optional[str] = None
    # Min: 10, Max: 300 - optional when using composition plan.
    duration: Optional[float] = None
    taskUUID: Optional[str] = None
    outputType: Optional[IAudioOutputType] = None
    outputFormat: Optional[IAudioOutputFormat] = None
    audioSettings: Optional[IAudioSettings] = None
    includeCost: Optional[bool] = None
    numberResults: Optional[int] = 1
    # "sync" | "async"
    deliveryMethod: str = "sync"
    uploadEndpoint: Optional[str] = None
    webhookURL: Optional[str] = None
    # ElevenLabs provider settings.
    providerSettings: Optional[AudioProviderSettings] = None
822+
823+
765824
@dataclass
766825
class IVideo:
767826
taskType: str
@@ -773,6 +832,18 @@ class IVideo:
773832
seed: Optional[int] = None
774833

775834

835+
@dataclass
class IAudio:
    """
    One audio result returned for an audio inference task.

    The task type and task UUID are required; the remaining fields default to
    None and are filled from whatever the response provides.
    """

    taskType: str
    taskUUID: str
    status: Optional[str] = None
    audioUUID: Optional[str] = None
    # The audio payload, in whichever of the three delivery forms is present.
    audioURL: Optional[str] = None
    audioBase64Data: Optional[str] = None
    audioDataURI: Optional[str] = None
    cost: Optional[float] = None
845+
846+
776847
# The GetWithPromiseCallBackType is defined using the Callable type from the typing module. It represents a function that takes a dictionary
777848
# with specific keys and returns either a boolean or None.
778849
# The dictionary should have the following keys:

0 commit comments

Comments
 (0)