From 59ab161e555e9cfdf2840caa5185de1f3a0c1d58 Mon Sep 17 00:00:00 2001 From: AKSHAT ANAND Date: Tue, 21 Oct 2025 12:41:25 +0000 Subject: [PATCH] feat: add wait_until_database_online helper for knowledge base database polling - Add wait_until_database_online method to KnowledgeBasesResource and AsyncKnowledgeBasesResource - Polls knowledge base database_status until it reaches ONLINE or encounters terminal failure - Implements configurable timeout and poll_interval parameters - Add two new exception types: KnowledgeBaseDatabaseError and KnowledgeBaseDatabaseTimeoutError - Expose new method through WithRawResponse and WithStreamingResponse wrappers - Add comprehensive unit tests (8 tests: 4 sync + 4 async) covering success, timeout, failure, and validation scenarios - Follows the same pattern as agents.wait_until_ready for consistency Closes #42 --- examples/agent_wait_until_ready.py | 28 ++--- src/gradient/_exceptions.py | 18 +++ .../knowledge_bases/knowledge_bases.py | 108 ++++++++++++++++ tests/api_resources/test_agents.py | 46 +++---- tests/api_resources/test_knowledge_bases.py | 119 ++++++++++++++++++ 5 files changed, 283 insertions(+), 36 deletions(-) diff --git a/examples/agent_wait_until_ready.py b/examples/agent_wait_until_ready.py index 3ea7b4a3..df8c8cc6 100644 --- a/examples/agent_wait_until_ready.py +++ b/examples/agent_wait_until_ready.py @@ -24,7 +24,7 @@ if agent_id: print(f"Agent created with ID: {agent_id}") print("Waiting for agent to be ready...") - + try: # Wait for the agent to be deployed and ready # This will poll the agent status every 5 seconds (default) @@ -32,24 +32,24 @@ ready_agent = client.agents.wait_until_ready( agent_id, poll_interval=5.0, # Check every 5 seconds - timeout=300.0, # Wait up to 5 minutes + timeout=300.0, # Wait up to 5 minutes ) - + if ready_agent.agent and ready_agent.agent.deployment: print(f"Agent is ready! Status: {ready_agent.agent.deployment.status}") print(f"Agent URL: {ready_agent.agent.url}") - + # Now you can use the agent # ... - + except AgentDeploymentError as e: print(f"Agent deployment failed: {e}") print(f"Failed status: {e.status}") - + except AgentDeploymentTimeoutError as e: print(f"Agent deployment timed out: {e}") print(f"Agent ID: {e.agent_id}") - + except Exception as e: print(f"Unexpected error: {e}") @@ -60,7 +60,7 @@ async def main() -> None: async_client = AsyncGradient() - + # Create a new agent agent_response = await async_client.agents.create( name="My Async Agent", @@ -68,13 +68,13 @@ async def main() -> None: model_uuid="", region="nyc1", ) - + agent_id = agent_response.agent.uuid if agent_response.agent else None - + if agent_id: print(f"Agent created with ID: {agent_id}") print("Waiting for agent to be ready...") - + try: # Wait for the agent to be deployed and ready (async) ready_agent = await async_client.agents.wait_until_ready( @@ -82,15 +82,15 @@ async def main() -> None: poll_interval=5.0, timeout=300.0, ) - + if ready_agent.agent and ready_agent.agent.deployment: print(f"Agent is ready! Status: {ready_agent.agent.deployment.status}") print(f"Agent URL: {ready_agent.agent.url}") - + except AgentDeploymentError as e: print(f"Agent deployment failed: {e}") print(f"Failed status: {e.status}") - + except AgentDeploymentTimeoutError as e: print(f"Agent deployment timed out: {e}") print(f"Agent ID: {e.agent_id}") diff --git a/src/gradient/_exceptions.py b/src/gradient/_exceptions.py index 0ced4aba..1d8533eb 100644 --- a/src/gradient/_exceptions.py +++ b/src/gradient/_exceptions.py @@ -15,6 +15,8 @@ "UnprocessableEntityError", "RateLimitError", "InternalServerError", + "KnowledgeBaseDatabaseError", + "KnowledgeBaseDatabaseTimeoutError", "AgentDeploymentError", "AgentDeploymentTimeoutError", ] @@ -124,3 +126,19 @@ class AgentDeploymentTimeoutError(GradientError): def __init__(self, message: str, agent_id: str) -> None: super().__init__(message) self.agent_id = agent_id + + +class KnowledgeBaseDatabaseError(GradientError): + """Raised when a knowledge base database creation fails.""" + + def __init__(self, message: str, status: str) -> None: + super().__init__(message) + self.status = status + + +class KnowledgeBaseDatabaseTimeoutError(GradientError): + """Raised when waiting for a knowledge base database creation times out.""" + + def __init__(self, message: str, knowledge_base_id: str) -> None: + super().__init__(message) + self.knowledge_base_id = knowledge_base_id diff --git a/src/gradient/resources/knowledge_bases/knowledge_bases.py b/src/gradient/resources/knowledge_bases/knowledge_bases.py index 00fa0659..adaa49e4 100644 --- a/src/gradient/resources/knowledge_bases/knowledge_bases.py +++ b/src/gradient/resources/knowledge_bases/knowledge_bases.py @@ -2,6 +2,7 @@ from __future__ import annotations +import time from typing import Iterable import httpx @@ -330,6 +331,56 @@ def delete( cast_to=KnowledgeBaseDeleteResponse, ) + def wait_until_database_online( + self, + uuid: str, + *, + timeout: float = 300.0, + poll_interval: float = 5.0, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + ) -> KnowledgeBaseRetrieveResponse: + """Wait for a knowledge base's associated database to reach ONLINE. + + This polls `retrieve` until `database_status` equals "ONLINE", or raises + on terminal failure or timeout. + """ + from ..._exceptions import KnowledgeBaseDatabaseError, KnowledgeBaseDatabaseTimeoutError + + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + start_time = time.time() + + while True: + kb_response = self.retrieve( + uuid, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body + ) + + status = kb_response.database_status if kb_response else None + + # Success + if status == "ONLINE": + return kb_response + + # Failure cases - treat some terminal statuses as failures + if status in ("DECOMMISSIONED", "UNHEALTHY"): + raise KnowledgeBaseDatabaseError( + f"Knowledge base database creation failed with status: {status}", status=status + ) + + # Timeout + elapsed_time = time.time() - start_time + if elapsed_time >= timeout: + current_status = status or "UNKNOWN" + raise KnowledgeBaseDatabaseTimeoutError( + f"Knowledge base database did not reach ONLINE within {timeout} seconds. Current status: {current_status}", + knowledge_base_id=uuid, + ) + + time.sleep(poll_interval) + class AsyncKnowledgeBasesResource(AsyncAPIResource): @cached_property @@ -618,6 +669,51 @@ async def delete( cast_to=KnowledgeBaseDeleteResponse, ) + async def wait_until_database_online( + self, + uuid: str, + *, + timeout: float = 300.0, + poll_interval: float = 5.0, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + ) -> KnowledgeBaseRetrieveResponse: + """Async version of `wait_until_database_online`.""" + import asyncio + + from ..._exceptions import KnowledgeBaseDatabaseError, KnowledgeBaseDatabaseTimeoutError + + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + start_time = time.time() + + while True: + kb_response = await self.retrieve( + uuid, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body + ) + + status = kb_response.database_status if kb_response else None + + if status == "ONLINE": + return kb_response + + if status in ("DECOMMISSIONED", "UNHEALTHY"): + raise KnowledgeBaseDatabaseError( + f"Knowledge base database creation failed with status: {status}", status=status + ) + + elapsed_time = time.time() - start_time + if elapsed_time >= timeout: + current_status = status or "UNKNOWN" + raise KnowledgeBaseDatabaseTimeoutError( + f"Knowledge base database did not reach ONLINE within {timeout} seconds. Current status: {current_status}", + knowledge_base_id=uuid, + ) + + await asyncio.sleep(poll_interval) + class KnowledgeBasesResourceWithRawResponse: def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None: @@ -638,6 +734,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None: self.delete = to_raw_response_wrapper( knowledge_bases.delete, ) + self.wait_until_database_online = to_raw_response_wrapper( + knowledge_bases.wait_until_database_online, + ) @cached_property def data_sources(self) -> DataSourcesResourceWithRawResponse: @@ -667,6 +766,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None: self.delete = async_to_raw_response_wrapper( knowledge_bases.delete, ) + self.wait_until_database_online = async_to_raw_response_wrapper( + knowledge_bases.wait_until_database_online, + ) @cached_property def data_sources(self) -> AsyncDataSourcesResourceWithRawResponse: @@ -696,6 +798,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None: self.delete = to_streamed_response_wrapper( knowledge_bases.delete, ) + self.wait_until_database_online = to_streamed_response_wrapper( + knowledge_bases.wait_until_database_online, + ) @cached_property def data_sources(self) -> DataSourcesResourceWithStreamingResponse: @@ -725,6 +830,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None: self.delete = async_to_streamed_response_wrapper( knowledge_bases.delete, ) + self.wait_until_database_online = async_to_streamed_response_wrapper( + knowledge_bases.wait_until_database_online, + ) @cached_property def data_sources(self) -> AsyncDataSourcesResourceWithStreamingResponse: diff --git a/tests/api_resources/test_agents.py b/tests/api_resources/test_agents.py index 5777c3ea..1ba3e093 100644 --- a/tests/api_resources/test_agents.py +++ b/tests/api_resources/test_agents.py @@ -368,9 +368,10 @@ def test_path_params_update_status(self, client: Gradient) -> None: def test_method_wait_until_ready(self, client: Gradient, respx_mock: Any) -> None: """Test successful wait_until_ready when agent becomes ready.""" agent_uuid = "test-agent-id" - + # Create side effect that returns different responses call_count = [0] + def get_response(_: httpx.Request) -> httpx.Response: call_count[0] += 1 if call_count[0] == 1: @@ -395,9 +396,9 @@ def get_response(_: httpx.Request) -> httpx.Response: } }, ) - + respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock(side_effect=get_response) - + agent = client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=10.0) assert_matches_type(AgentRetrieveResponse, agent, path=["response"]) assert agent.agent is not None @@ -408,9 +409,9 @@ def get_response(_: httpx.Request) -> httpx.Response: def test_wait_until_ready_timeout(self, client: Gradient, respx_mock: Any) -> None: """Test that wait_until_ready raises timeout error.""" from gradient._exceptions import AgentDeploymentTimeoutError - + agent_uuid = "test-agent-id" - + # Mock always returns deploying respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock( return_value=httpx.Response( @@ -423,10 +424,10 @@ def test_wait_until_ready_timeout(self, client: Gradient, respx_mock: Any) -> No }, ) ) - + with pytest.raises(AgentDeploymentTimeoutError) as exc_info: client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=0.5) - + assert "did not reach STATUS_RUNNING within" in str(exc_info.value) assert exc_info.value.agent_id == agent_uuid @@ -434,9 +435,9 @@ def test_wait_until_ready_timeout(self, client: Gradient, respx_mock: Any) -> No def test_wait_until_ready_deployment_failed(self, client: Gradient, respx_mock: Any) -> None: """Test that wait_until_ready raises error on deployment failure.""" from gradient._exceptions import AgentDeploymentError - + agent_uuid = "test-agent-id" - + # Mock returns failed status respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock( return_value=httpx.Response( @@ -449,10 +450,10 @@ def test_wait_until_ready_deployment_failed(self, client: Gradient, respx_mock: }, ) ) - + with pytest.raises(AgentDeploymentError) as exc_info: client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=10.0) - + assert "deployment failed with status: STATUS_FAILED" in str(exc_info.value) assert exc_info.value.status == "STATUS_FAILED" @@ -810,9 +811,10 @@ async def test_path_params_update_status(self, async_client: AsyncGradient) -> N async def test_method_wait_until_ready(self, async_client: AsyncGradient, respx_mock: Any) -> None: """Test successful async wait_until_ready when agent becomes ready.""" agent_uuid = "test-agent-id" - + # Create side effect that returns different responses call_count = [0] + def get_response(_: httpx.Request) -> httpx.Response: call_count[0] += 1 if call_count[0] == 1: @@ -837,9 +839,9 @@ def get_response(_: httpx.Request) -> httpx.Response: } }, ) - + respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock(side_effect=get_response) - + agent = await async_client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=10.0) assert_matches_type(AgentRetrieveResponse, agent, path=["response"]) assert agent.agent is not None @@ -850,9 +852,9 @@ def get_response(_: httpx.Request) -> httpx.Response: async def test_wait_until_ready_timeout(self, async_client: AsyncGradient, respx_mock: Any) -> None: """Test that async wait_until_ready raises timeout error.""" from gradient._exceptions import AgentDeploymentTimeoutError - + agent_uuid = "test-agent-id" - + # Mock always returns deploying respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock( return_value=httpx.Response( @@ -865,10 +867,10 @@ async def test_wait_until_ready_timeout(self, async_client: AsyncGradient, respx }, ) ) - + with pytest.raises(AgentDeploymentTimeoutError) as exc_info: await async_client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=0.5) - + assert "did not reach STATUS_RUNNING within" in str(exc_info.value) assert exc_info.value.agent_id == agent_uuid @@ -876,9 +878,9 @@ async def test_wait_until_ready_timeout(self, async_client: AsyncGradient, respx async def test_wait_until_ready_deployment_failed(self, async_client: AsyncGradient, respx_mock: Any) -> None: """Test that async wait_until_ready raises error on deployment failure.""" from gradient._exceptions import AgentDeploymentError - + agent_uuid = "test-agent-id" - + # Mock returns failed status respx_mock.get(f"/v2/gen-ai/agents/{agent_uuid}").mock( return_value=httpx.Response( @@ -891,9 +893,9 @@ async def test_wait_until_ready_deployment_failed(self, async_client: AsyncGradi }, ) ) - + with pytest.raises(AgentDeploymentError) as exc_info: await async_client.agents.wait_until_ready(agent_uuid, poll_interval=0.1, timeout=10.0) - + assert "deployment failed with status: STATUS_FAILED" in str(exc_info.value) assert exc_info.value.status == "STATUS_FAILED" diff --git a/tests/api_resources/test_knowledge_bases.py b/tests/api_resources/test_knowledge_bases.py index 62965775..ae0b1cd8 100644 --- a/tests/api_resources/test_knowledge_bases.py +++ b/tests/api_resources/test_knowledge_bases.py @@ -5,6 +5,7 @@ import os from typing import Any, cast +import httpx import pytest from gradient import Gradient, AsyncGradient @@ -140,6 +141,65 @@ def test_path_params_retrieve(self, client: Gradient) -> None: "", ) + @parametrize + def test_method_wait_until_database_online(self, client: Gradient, respx_mock: Any) -> None: + """Test successful wait_until_database_online when database becomes ONLINE.""" + kb_uuid = "test-kb-id" + + call_count = [0] + + def get_response(_: httpx.Request) -> httpx.Response: + call_count[0] += 1 + if call_count[0] == 1: + return httpx.Response(200, json={"database_status": "CREATING"}) + else: + return httpx.Response(200, json={"database_status": "ONLINE"}) + + respx_mock.get(f"/v2/gen-ai/knowledge_bases/{kb_uuid}").mock(side_effect=get_response) + + kb = client.knowledge_bases.wait_until_database_online(kb_uuid, poll_interval=0.1, timeout=10.0) + assert_matches_type(KnowledgeBaseRetrieveResponse, kb, path=["response"]) + assert kb.database_status == "ONLINE" + + @parametrize + def test_wait_until_database_online_timeout(self, client: Gradient, respx_mock: Any) -> None: + """Test that wait_until_database_online raises timeout error.""" + from gradient._exceptions import KnowledgeBaseDatabaseTimeoutError + + kb_uuid = "test-kb-id" + + respx_mock.get(f"/v2/gen-ai/knowledge_bases/{kb_uuid}").mock( + return_value=httpx.Response(200, json={"database_status": "CREATING"}) + ) + + with pytest.raises(KnowledgeBaseDatabaseTimeoutError) as exc_info: + client.knowledge_bases.wait_until_database_online(kb_uuid, poll_interval=0.1, timeout=0.5) + + assert "did not reach ONLINE within" in str(exc_info.value) + assert exc_info.value.knowledge_base_id == kb_uuid + + @parametrize + def test_wait_until_database_online_failed(self, client: Gradient, respx_mock: Any) -> None: + """Test that wait_until_database_online raises error on failure status.""" + from gradient._exceptions import KnowledgeBaseDatabaseError + + kb_uuid = "test-kb-id" + + respx_mock.get(f"/v2/gen-ai/knowledge_bases/{kb_uuid}").mock( + return_value=httpx.Response(200, json={"database_status": "UNHEALTHY"}) + ) + + with pytest.raises(KnowledgeBaseDatabaseError) as exc_info: + client.knowledge_bases.wait_until_database_online(kb_uuid, poll_interval=0.1, timeout=10.0) + + assert "failed with status: UNHEALTHY" in str(exc_info.value) + + @parametrize + def test_wait_until_database_online_empty_uuid(self, client: Gradient) -> None: + """Test that wait_until_database_online validates empty uuid.""" + with pytest.raises(ValueError, match=r"Expected a non-empty value for `uuid`"): + client.knowledge_bases.wait_until_database_online("") + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize def test_method_update(self, client: Gradient) -> None: @@ -398,6 +458,65 @@ async def test_path_params_retrieve(self, async_client: AsyncGradient) -> None: "", ) + @parametrize + async def test_method_wait_until_database_online(self, async_client: AsyncGradient, respx_mock: Any) -> None: + """Async: Test successful wait_until_database_online when database becomes ONLINE.""" + kb_uuid = "test-kb-id" + + call_count = [0] + + def get_response(_: httpx.Request) -> httpx.Response: + call_count[0] += 1 + if call_count[0] == 1: + return httpx.Response(200, json={"database_status": "CREATING"}) + else: + return httpx.Response(200, json={"database_status": "ONLINE"}) + + respx_mock.get(f"/v2/gen-ai/knowledge_bases/{kb_uuid}").mock(side_effect=get_response) + + kb = await async_client.knowledge_bases.wait_until_database_online(kb_uuid, poll_interval=0.1, timeout=10.0) + assert_matches_type(KnowledgeBaseRetrieveResponse, kb, path=["response"]) + assert kb.database_status == "ONLINE" + + @parametrize + async def test_wait_until_database_online_timeout(self, async_client: AsyncGradient, respx_mock: Any) -> None: + """Async: Test that wait_until_database_online raises timeout error.""" + from gradient._exceptions import KnowledgeBaseDatabaseTimeoutError + + kb_uuid = "test-kb-id" + + respx_mock.get(f"/v2/gen-ai/knowledge_bases/{kb_uuid}").mock( + return_value=httpx.Response(200, json={"database_status": "CREATING"}) + ) + + with pytest.raises(KnowledgeBaseDatabaseTimeoutError) as exc_info: + await async_client.knowledge_bases.wait_until_database_online(kb_uuid, poll_interval=0.1, timeout=0.5) + + assert "did not reach ONLINE within" in str(exc_info.value) + assert exc_info.value.knowledge_base_id == kb_uuid + + @parametrize + async def test_wait_until_database_online_failed(self, async_client: AsyncGradient, respx_mock: Any) -> None: + """Async: Test that wait_until_database_online raises error on failure status.""" + from gradient._exceptions import KnowledgeBaseDatabaseError + + kb_uuid = "test-kb-id" + + respx_mock.get(f"/v2/gen-ai/knowledge_bases/{kb_uuid}").mock( + return_value=httpx.Response(200, json={"database_status": "UNHEALTHY"}) + ) + + with pytest.raises(KnowledgeBaseDatabaseError) as exc_info: + await async_client.knowledge_bases.wait_until_database_online(kb_uuid, poll_interval=0.1, timeout=10.0) + + assert "failed with status: UNHEALTHY" in str(exc_info.value) + + @parametrize + async def test_wait_until_database_online_empty_uuid(self, async_client: AsyncGradient) -> None: + """Async: Test that wait_until_database_online validates empty uuid.""" + with pytest.raises(ValueError, match=r"Expected a non-empty value for `uuid`"): + await async_client.knowledge_bases.wait_until_database_online("") + @pytest.mark.skip(reason="Prism tests are disabled") @parametrize async def test_method_update(self, async_client: AsyncGradient) -> None: