|
33 | 33 | IControlNet, |
34 | 34 | IVideo, |
35 | 35 | IVideoInference, |
| 36 | + IAudio, |
| 37 | + IAudioInference, |
| 38 | + IAudioSettings, |
36 | 39 | IGoogleProviderSettings, |
37 | 40 | IKlingAIProviderSettings, |
38 | 41 | IFrameImage, |
@@ -1616,6 +1619,161 @@ def _processVideoPollingResponse(self, responses: List[Dict[str, Any]]) -> List[ |
1616 | 1619 | def _hasPendingVideos(self, responses: List[Dict[str, Any]]) -> bool: |
1617 | 1620 | return any(response.get("status") == "pending" for response in responses) |
1618 | 1621 |
|
| 1622 | + async def audioInference(self, requestAudio: IAudioInference) -> List[IAudio]: |
| 1623 | + await self.ensureConnection() |
| 1624 | + return await asyncRetry(lambda: self._requestAudio(requestAudio)) |
| 1625 | + |
| 1626 | + async def _requestAudio(self, requestAudio: IAudioInference) -> List[IAudio]: |
| 1627 | + requestAudio.taskUUID = requestAudio.taskUUID or getUUID() |
| 1628 | + request_object = self._buildAudioRequest(requestAudio) |
| 1629 | + await self.send([request_object]) |
| 1630 | + return await self._handleInitialAudioResponse(requestAudio.taskUUID, requestAudio.numberResults) |
| 1631 | + |
| 1632 | + def _buildAudioRequest(self, requestAudio: IAudioInference) -> Dict[str, Any]: |
| 1633 | + request_object = { |
| 1634 | + "deliveryMethod": requestAudio.deliveryMethod, |
| 1635 | + "taskType": ETaskType.AUDIO_INFERENCE.value, |
| 1636 | + "taskUUID": requestAudio.taskUUID, |
| 1637 | + "model": requestAudio.model, |
| 1638 | + "numberResults": requestAudio.numberResults, |
| 1639 | + } |
| 1640 | + |
| 1641 | + # Only add positivePrompt if it's provided |
| 1642 | + if requestAudio.positivePrompt is not None: |
| 1643 | + request_object["positivePrompt"] = requestAudio.positivePrompt.strip() |
| 1644 | + |
| 1645 | + # Only add duration if it's provided and not using composition plan |
| 1646 | + if requestAudio.duration is not None: |
| 1647 | + request_object["duration"] = requestAudio.duration |
| 1648 | + |
| 1649 | + self._addOptionalAudioFields(request_object, requestAudio) |
| 1650 | + self._addAudioSettings(request_object, requestAudio) |
| 1651 | + self._addAudioProviderSettings(request_object, requestAudio) |
| 1652 | + |
| 1653 | + return request_object |
| 1654 | + |
| 1655 | + def _addOptionalAudioFields(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None: |
| 1656 | + optional_fields = [ |
| 1657 | + "outputType", "outputFormat", "includeCost", "uploadEndpoint", "webhookURL" |
| 1658 | + ] |
| 1659 | + |
| 1660 | + for field in optional_fields: |
| 1661 | + value = getattr(requestAudio, field, None) |
| 1662 | + if value is not None: |
| 1663 | + request_object[field] = value |
| 1664 | + |
| 1665 | + def _addAudioSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None: |
| 1666 | + if requestAudio.audioSettings: |
| 1667 | + audio_settings_dict = asdict(requestAudio.audioSettings) |
| 1668 | + # Remove None values |
| 1669 | + audio_settings_dict = {k: v for k, v in audio_settings_dict.items() if v is not None} |
| 1670 | + if audio_settings_dict: |
| 1671 | + request_object["audioSettings"] = audio_settings_dict |
| 1672 | + |
| 1673 | + def _addAudioProviderSettings(self, request_object: Dict[str, Any], requestAudio: IAudioInference) -> None: |
| 1674 | + if not requestAudio.providerSettings: |
| 1675 | + return |
| 1676 | + provider_dict = requestAudio.providerSettings.to_request_dict() |
| 1677 | + if provider_dict: |
| 1678 | + request_object["providerSettings"] = provider_dict |
| 1679 | + |
| 1680 | + async def _handleInitialAudioResponse(self, task_uuid: str, number_results: int) -> List[IAudio]: |
| 1681 | + if number_results == 1: |
| 1682 | + # Single result - wait for completion |
| 1683 | + response = await self._waitForAudioCompletion(task_uuid) |
| 1684 | + return [response] if response else [] |
| 1685 | + else: |
| 1686 | + # Multiple results - use polling |
| 1687 | + return await self._pollForAudioResults(task_uuid, number_results) |
| 1688 | + |
| 1689 | + async def _waitForAudioCompletion(self, task_uuid: str) -> Optional[IAudio]: |
| 1690 | + lis = self.globalListener(taskUUID=task_uuid) |
| 1691 | + |
| 1692 | + def check(resolve: Callable, reject: Callable, *args: Any) -> bool: |
| 1693 | + response = self._globalMessages.get(task_uuid) |
| 1694 | + if response: |
| 1695 | + audio_response = response[0] if isinstance(response, list) else response |
| 1696 | + else: |
| 1697 | + audio_response = response |
| 1698 | + |
| 1699 | + if audio_response and audio_response.get("error"): |
| 1700 | + reject(audio_response) |
| 1701 | + return True |
| 1702 | + |
| 1703 | + if audio_response: |
| 1704 | + del self._globalMessages[task_uuid] |
| 1705 | + resolve(audio_response) |
| 1706 | + return True |
| 1707 | + |
| 1708 | + return False |
| 1709 | + |
| 1710 | + try: |
| 1711 | + response = await getIntervalWithPromise( |
| 1712 | + check, debugKey="audio-inference", timeOutDuration=self._timeout |
| 1713 | + ) |
| 1714 | + lis["destroy"]() |
| 1715 | + |
| 1716 | + if "code" in response: |
| 1717 | + raise RunwareAPIError(response) |
| 1718 | + |
| 1719 | + return self._createAudioFromResponse(response) if response else None |
| 1720 | + except Exception as e: |
| 1721 | + lis["destroy"]() |
| 1722 | + raise e |
| 1723 | + |
| 1724 | + async def _pollForAudioResults(self, task_uuid: str, number_results: int) -> List[IAudio]: |
| 1725 | + completed_results = [] |
| 1726 | + lis = self.globalListener(taskUUID=task_uuid) |
| 1727 | + |
| 1728 | + try: |
| 1729 | + while len(completed_results) < number_results: |
| 1730 | + responses = self._globalMessages.get(task_uuid, []) |
| 1731 | + if not isinstance(responses, list): |
| 1732 | + responses = [responses] if responses else [] |
| 1733 | + |
| 1734 | + processed_responses = self._processAudioPollingResponse(responses) |
| 1735 | + completed_results.extend(processed_responses) |
| 1736 | + |
| 1737 | + if len(completed_results) >= number_results: |
| 1738 | + break |
| 1739 | + |
| 1740 | + await asyncio.sleep(1) # Poll every second |
| 1741 | + |
| 1742 | + # Clean up |
| 1743 | + if task_uuid in self._globalMessages: |
| 1744 | + del self._globalMessages[task_uuid] |
| 1745 | + lis["destroy"]() |
| 1746 | + |
| 1747 | + return [self._createAudioFromResponse(response) for response in completed_results[:number_results]] |
| 1748 | + |
| 1749 | + except Exception as e: |
| 1750 | + lis["destroy"]() |
| 1751 | + raise e |
| 1752 | + |
| 1753 | + def _processAudioPollingResponse(self, responses: List[Dict[str, Any]]) -> List[Dict[str, Any]]: |
| 1754 | + completed_results = [] |
| 1755 | + |
| 1756 | + for response in responses: |
| 1757 | + if response.get("code"): |
| 1758 | + raise RunwareAPIError(response) |
| 1759 | + status = response.get("status") |
| 1760 | + if status == "success": |
| 1761 | + completed_results.append(response) |
| 1762 | + |
| 1763 | + return completed_results |
| 1764 | + |
| 1765 | + def _createAudioFromResponse(self, response: Dict[str, Any]) -> IAudio: |
| 1766 | + return IAudio( |
| 1767 | + taskType=response.get("taskType", ""), |
| 1768 | + taskUUID=response.get("taskUUID", ""), |
| 1769 | + status=response.get("status"), |
| 1770 | + audioUUID=response.get("audioUUID"), |
| 1771 | + audioURL=response.get("audioURL"), |
| 1772 | + audioBase64Data=response.get("audioBase64Data"), |
| 1773 | + audioDataURI=response.get("audioDataURI"), |
| 1774 | + cost=response.get("cost") |
| 1775 | + ) |
| 1776 | + |
1619 | 1777 | def connected(self) -> bool: |
1620 | 1778 | """ |
1621 | 1779 | Check if the current WebSocket connection is active and authenticated. |
|
0 commit comments