From ff1cddb6d78d4945edb0691750f762e5385ac686 Mon Sep 17 00:00:00 2001
From: om mistry
Date: Tue, 14 Oct 2025 22:45:25 +0530
Subject: [PATCH] Create a helper function to poll for agent deployment
 readiness

---
 src/gradient/resources/agents/agents.py | 113 ++++++++++++++++++++++++
 1 file changed, 113 insertions(+)

diff --git a/src/gradient/resources/agents/agents.py b/src/gradient/resources/agents/agents.py
index a4d32fca..1c94bb53 100644
--- a/src/gradient/resources/agents/agents.py
+++ b/src/gradient/resources/agents/agents.py
@@ -296,6 +296,62 @@ def retrieve(
             cast_to=AgentRetrieveResponse,
         )
 
+    def wait_for_deployment(
+        self,
+        uuid: str,
+        *,
+        timeout: float = 120.0,
+        interval: float = 2.0,
+        raise_on_failed: bool = True,
+    ) -> AgentRetrieveResponse:
+        """Poll the agent deployment until it reaches `STATUS_RUNNING` or a terminal failed state.
+
+        Args:
+            uuid: Agent UUID to poll
+            timeout: Maximum seconds to wait before raising TimeoutError
+            interval: Seconds between polls
+            raise_on_failed: If True, raise RuntimeError when deployment enters a failed state
+
+        Returns:
+            The final `AgentRetrieveResponse`: status is `STATUS_RUNNING`, or a failed
+            state when `raise_on_failed` is False.
+
+        Raises:
+            ValueError: if `uuid` is empty
+            TimeoutError: if the timeout is exceeded
+            RuntimeError: if deployment enters a failed terminal state and `raise_on_failed` is True
+        """
+        import time
+
+        if not uuid:
+            raise ValueError("Expected a non-empty value for `uuid`")
+
+        # Monotonic clock: the deadline must not move if the wall clock is adjusted.
+        deadline = time.monotonic() + timeout
+        failed_states = {"STATUS_FAILED", "STATUS_UNDEPLOYMENT_FAILED", "STATUS_DELETED"}
+
+        while True:
+            resp = self.retrieve(uuid)
+            agent = resp.agent
+            status = None
+            if agent and agent.deployment and agent.deployment.status:
+                status = agent.deployment.status
+
+            if status == "STATUS_RUNNING":
+                return resp
+
+            if status in failed_states:
+                if raise_on_failed:
+                    raise RuntimeError(f"Agent {uuid} deployment entered failed state: {status}")
+                return resp
+
+            remaining = deadline - time.monotonic()
+            if remaining <= 0:
+                raise TimeoutError(f"Timed out waiting for agent {uuid} to be running")
+
+            # Never sleep past the deadline, so `timeout` is honored closely.
+            time.sleep(min(interval, remaining))
+
     def update(
         self,
         path_uuid: str,
@@ -792,6 +848,63 @@ async def retrieve(
             cast_to=AgentRetrieveResponse,
         )
 
+    async def wait_for_deployment(
+        self,
+        uuid: str,
+        *,
+        timeout: float = 120.0,
+        interval: float = 2.0,
+        raise_on_failed: bool = True,
+    ) -> AgentRetrieveResponse:
+        """Async poll until agent deployment reaches `STATUS_RUNNING` or a terminal failed state.
+
+        Args:
+            uuid: Agent UUID to poll
+            timeout: Maximum seconds to wait before raising TimeoutError
+            interval: Seconds between polls
+            raise_on_failed: If True, raise RuntimeError when deployment enters a failed state
+
+        Returns:
+            The final `AgentRetrieveResponse`: status is `STATUS_RUNNING`, or a failed
+            state when `raise_on_failed` is False.
+
+        Raises:
+            ValueError: if `uuid` is empty
+            TimeoutError: if the timeout is exceeded
+            RuntimeError: if deployment enters a failed terminal state and `raise_on_failed` is True
+        """
+        import asyncio
+        import time
+
+        if not uuid:
+            raise ValueError("Expected a non-empty value for `uuid`")
+
+        # Monotonic clock: the deadline must not move if the wall clock is adjusted.
+        deadline = time.monotonic() + timeout
+        failed_states = {"STATUS_FAILED", "STATUS_UNDEPLOYMENT_FAILED", "STATUS_DELETED"}
+
+        while True:
+            resp = await self.retrieve(uuid)
+            agent = resp.agent
+            status = None
+            if agent and agent.deployment and agent.deployment.status:
+                status = agent.deployment.status
+
+            if status == "STATUS_RUNNING":
+                return resp
+
+            if status in failed_states:
+                if raise_on_failed:
+                    raise RuntimeError(f"Agent {uuid} deployment entered failed state: {status}")
+                return resp
+
+            remaining = deadline - time.monotonic()
+            if remaining <= 0:
+                raise TimeoutError(f"Timed out waiting for agent {uuid} to be running")
+
+            # Never sleep past the deadline, so `timeout` is honored closely.
+            await asyncio.sleep(min(interval, remaining))
+
     async def update(
         self,
         path_uuid: str,