From 27d28dea55026213e8ffd919b8d39ef061c6bb38 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 8 Oct 2025 01:26:30 +0530 Subject: [PATCH 1/4] Add wait_for_database helper function to poll for Knowledge Base database creation (#1) * Initial plan * Add wait_for_database helper for knowledge base polling * Add unit tests and README documentation for wait_for_database * Fix linting and type checking for wait_for_database implementation --- README.md | 43 ++++ pyproject.toml | 2 +- .../resources/knowledge_bases/__init__.py | 2 + .../knowledge_bases/knowledge_bases.py | 181 ++++++++++++++++- tests/api_resources/test_knowledge_bases.py | 192 ++++++++++++++++++ 5 files changed, 418 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 30b7c75a..c5116430 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,49 @@ we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/) to add `DIGITALOCEAN_ACCESS_TOKEN="My Access Token"`, `GRADIENT_MODEL_ACCESS_KEY="My Model Access Key"` to your `.env` file so that your keys are not stored in source control. +## Knowledge Base Database Polling + +When creating a Knowledge Base, the database deployment can take several minutes. The `wait_for_database()` helper function simplifies polling for the database status: + +```python +from gradient import Gradient +from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError +from gradient._exceptions import APITimeoutError + +client = Gradient() + +# Create a knowledge base +kb_response = client.knowledge_bases.create( + name="My Knowledge Base", + region="nyc1", + embedding_model_uuid="your-embedding-model-uuid", +) + +kb_uuid = kb_response.knowledge_base.uuid + +try: + # Wait for the database to be ready (default: 10 minute timeout, 5 second poll interval) + result = client.knowledge_bases.wait_for_database(kb_uuid) + print(f"Database status: {result.database_status}") # "ONLINE" + + # Custom timeout and poll interval + result = client.knowledge_bases.wait_for_database( + kb_uuid, + timeout=900.0, # 15 minutes + poll_interval=10.0 # Check every 10 seconds + ) + +except KnowledgeBaseDatabaseError as e: + # Database entered a failed state (DECOMMISSIONED or UNHEALTHY) + print(f"Database failed: {e}") + +except APITimeoutError: + # Database did not become ready within the timeout period + print("Timeout: Database did not become ready in time") +``` + +The helper handles all state transitions and will raise appropriate exceptions for failed states or timeouts. + ## Async usage Simply import `AsyncGradient` instead of `Gradient` and use `await` with each API call: diff --git a/pyproject.toml b/pyproject.toml index dade45c8..13bc2865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -246,5 +246,5 @@ known-first-party = ["gradient", "tests"] [tool.ruff.lint.per-file-ignores] "bin/**.py" = ["T201", "T203"] "scripts/**.py" = ["T201", "T203"] -"tests/**.py" = ["T201", "T203"] +"tests/**.py" = ["T201", "T203", "ARG001"] "examples/**.py" = ["T201", "T203"] diff --git a/src/gradient/resources/knowledge_bases/__init__.py b/src/gradient/resources/knowledge_bases/__init__.py index 80d04328..90ebea00 100644 --- a/src/gradient/resources/knowledge_bases/__init__.py +++ b/src/gradient/resources/knowledge_bases/__init__.py @@ -19,6 +19,7 @@ from .knowledge_bases import ( KnowledgeBasesResource, AsyncKnowledgeBasesResource, + KnowledgeBaseDatabaseError, KnowledgeBasesResourceWithRawResponse, AsyncKnowledgeBasesResourceWithRawResponse, KnowledgeBasesResourceWithStreamingResponse, @@ -40,6 +41,7 @@ "AsyncIndexingJobsResourceWithStreamingResponse", "KnowledgeBasesResource", "AsyncKnowledgeBasesResource", + "KnowledgeBaseDatabaseError", "KnowledgeBasesResourceWithRawResponse", "AsyncKnowledgeBasesResourceWithRawResponse", "KnowledgeBasesResourceWithStreamingResponse", diff --git a/src/gradient/resources/knowledge_bases/knowledge_bases.py b/src/gradient/resources/knowledge_bases/knowledge_bases.py index 00fa0659..d92622a8 100644 --- a/src/gradient/resources/knowledge_bases/knowledge_bases.py +++ b/src/gradient/resources/knowledge_bases/knowledge_bases.py @@ -2,6 +2,8 @@ from __future__ import annotations +import time +import asyncio from typing import Iterable import httpx @@ -25,6 +27,7 @@ DataSourcesResourceWithStreamingResponse, AsyncDataSourcesResourceWithStreamingResponse, ) +from ..._exceptions import APITimeoutError from .indexing_jobs import ( IndexingJobsResource, AsyncIndexingJobsResource, @@ -40,7 +43,13 @@ from ...types.knowledge_base_update_response import KnowledgeBaseUpdateResponse from ...types.knowledge_base_retrieve_response import KnowledgeBaseRetrieveResponse -__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource"] +__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource", "KnowledgeBaseDatabaseError"] + + +class KnowledgeBaseDatabaseError(Exception): + """Raised when a knowledge base database enters a failed state.""" + + pass class KnowledgeBasesResource(SyncAPIResource): @@ -330,6 +339,85 @@ def delete( cast_to=KnowledgeBaseDeleteResponse, ) + def wait_for_database( + self, + uuid: str, + *, + timeout: float = 600.0, + poll_interval: float = 5.0, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + ) -> KnowledgeBaseRetrieveResponse: + """ + Poll the knowledge base until the database status is ONLINE or a failed state is reached. + + This helper function repeatedly calls retrieve() to check the database_status field. + It will wait for the database to become ONLINE, or raise an exception if it enters + a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded. + + Args: + uuid: The knowledge base UUID to poll + + timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes) + + poll_interval: Time to wait between polls in seconds (default: 5 seconds) + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + Returns: + The final KnowledgeBaseRetrieveResponse when the database status is ONLINE + + Raises: + KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY) + + APITimeoutError: If the timeout is exceeded before the database becomes ONLINE + """ + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + start_time = time.time() + failed_states = {"DECOMMISSIONED", "UNHEALTHY"} + + while True: + elapsed = time.time() - start_time + if elapsed >= timeout: + raise APITimeoutError( + request=httpx.Request( + method="GET", + url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}", + ) + ) + + response = self.retrieve( + uuid, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + ) + + status = response.database_status + + if status == "ONLINE": + return response + + if status in failed_states: + raise KnowledgeBaseDatabaseError( + f"Knowledge base database entered failed state: {status}" + ) + + # Sleep before next poll, but don't exceed timeout + remaining_time = timeout - elapsed + sleep_time = min(poll_interval, remaining_time) + if sleep_time > 0: + time.sleep(sleep_time) + class AsyncKnowledgeBasesResource(AsyncAPIResource): @cached_property @@ -618,6 +706,85 @@ async def delete( cast_to=KnowledgeBaseDeleteResponse, ) + async def wait_for_database( + self, + uuid: str, + *, + timeout: float = 600.0, + poll_interval: float = 5.0, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + ) -> KnowledgeBaseRetrieveResponse: + """ + Poll the knowledge base until the database status is ONLINE or a failed state is reached. + + This helper function repeatedly calls retrieve() to check the database_status field. + It will wait for the database to become ONLINE, or raise an exception if it enters + a failed state (DECOMMISSIONED or UNHEALTHY) or if the timeout is exceeded. + + Args: + uuid: The knowledge base UUID to poll + + timeout: Maximum time to wait in seconds (default: 600 seconds / 10 minutes) + + poll_interval: Time to wait between polls in seconds (default: 5 seconds) + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + Returns: + The final KnowledgeBaseRetrieveResponse when the database status is ONLINE + + Raises: + KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY) + + APITimeoutError: If the timeout is exceeded before the database becomes ONLINE + """ + if not uuid: + raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") + + start_time = time.time() + failed_states = {"DECOMMISSIONED", "UNHEALTHY"} + + while True: + elapsed = time.time() - start_time + if elapsed >= timeout: + raise APITimeoutError( + request=httpx.Request( + method="GET", + url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}", + ) + ) + + response = await self.retrieve( + uuid, + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + ) + + status = response.database_status + + if status == "ONLINE": + return response + + if status in failed_states: + raise KnowledgeBaseDatabaseError( + f"Knowledge base database entered failed state: {status}" + ) + + # Sleep before next poll, but don't exceed timeout + remaining_time = timeout - elapsed + sleep_time = min(poll_interval, remaining_time) + if sleep_time > 0: + await asyncio.sleep(sleep_time) + class KnowledgeBasesResourceWithRawResponse: def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None: @@ -638,6 +805,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None: self.delete = to_raw_response_wrapper( knowledge_bases.delete, ) + self.wait_for_database = to_raw_response_wrapper( + knowledge_bases.wait_for_database, + ) @cached_property def data_sources(self) -> DataSourcesResourceWithRawResponse: @@ -667,6 +837,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None: self.delete = async_to_raw_response_wrapper( knowledge_bases.delete, ) + self.wait_for_database = async_to_raw_response_wrapper( + knowledge_bases.wait_for_database, + ) @cached_property def data_sources(self) -> AsyncDataSourcesResourceWithRawResponse: @@ -696,6 +869,9 @@ def __init__(self, knowledge_bases: KnowledgeBasesResource) -> None: self.delete = to_streamed_response_wrapper( knowledge_bases.delete, ) + self.wait_for_database = to_streamed_response_wrapper( + knowledge_bases.wait_for_database, + ) @cached_property def data_sources(self) -> DataSourcesResourceWithStreamingResponse: @@ -725,6 +901,9 @@ def __init__(self, knowledge_bases: AsyncKnowledgeBasesResource) -> None: self.delete = async_to_streamed_response_wrapper( knowledge_bases.delete, ) + self.wait_for_database = async_to_streamed_response_wrapper( + knowledge_bases.wait_for_database, + ) @cached_property def data_sources(self) -> AsyncDataSourcesResourceWithStreamingResponse: diff --git a/tests/api_resources/test_knowledge_bases.py b/tests/api_resources/test_knowledge_bases.py index 62965775..16773a0b 100644 --- a/tests/api_resources/test_knowledge_bases.py +++ b/tests/api_resources/test_knowledge_bases.py @@ -275,6 +275,102 @@ def test_path_params_delete(self, client: Gradient) -> None: "", ) + @parametrize + def test_method_wait_for_database_success(self, client: Gradient) -> None: + """Test wait_for_database with successful database status transition.""" + from unittest.mock import Mock + + call_count = [0] + + def mock_retrieve(uuid, **kwargs): + call_count[0] += 1 + response = Mock() + # Simulate CREATING -> ONLINE transition + response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" + return response + + client.knowledge_bases.retrieve = mock_retrieve + + result = client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + assert result.database_status == "ONLINE" + assert call_count[0] == 2 + + @parametrize + def test_method_wait_for_database_failed_state(self, client: Gradient) -> None: + """Test wait_for_database with failed database status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "UNHEALTHY" + return response + + client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): + client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + def test_method_wait_for_database_timeout(self, client: Gradient) -> None: + """Test wait_for_database with timeout.""" + from unittest.mock import Mock + + from gradient._exceptions import APITimeoutError + + def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "CREATING" + return response + + client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(APITimeoutError): + client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=0.3, + poll_interval=0.1, + ) + + @parametrize + def test_method_wait_for_database_decommissioned(self, client: Gradient) -> None: + """Test wait_for_database with DECOMMISSIONED status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "DECOMMISSIONED" + return response + + client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): + client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + def test_path_params_wait_for_database(self, client: Gradient) -> None: + """Test wait_for_database validates uuid parameter.""" + with pytest.raises(ValueError, match=r"Expected a non-empty value for `uuid` but received ''"): + client.knowledge_bases.wait_for_database( + "", + ) + class TestAsyncKnowledgeBases: parametrize = pytest.mark.parametrize( @@ -532,3 +628,99 @@ async def test_path_params_delete(self, async_client: AsyncGradient) -> None: await async_client.knowledge_bases.with_raw_response.delete( "", ) + + @parametrize + async def test_method_wait_for_database_success(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with successful database status transition.""" + from unittest.mock import Mock + + call_count = [0] + + async def mock_retrieve(uuid, **kwargs): + call_count[0] += 1 + response = Mock() + # Simulate CREATING -> ONLINE transition + response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve + + result = await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + assert result.database_status == "ONLINE" + assert call_count[0] == 2 + + @parametrize + async def test_method_wait_for_database_failed_state(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with failed database status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + async def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "UNHEALTHY" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): + await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + async def test_method_wait_for_database_timeout(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with timeout.""" + from unittest.mock import Mock + + from gradient._exceptions import APITimeoutError + + async def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "CREATING" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(APITimeoutError): + await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=0.3, + poll_interval=0.1, + ) + + @parametrize + async def test_method_wait_for_database_decommissioned(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database with DECOMMISSIONED status.""" + from unittest.mock import Mock + + from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError + + async def mock_retrieve(uuid, **kwargs): + response = Mock() + response.database_status = "DECOMMISSIONED" + return response + + async_client.knowledge_bases.retrieve = mock_retrieve + + with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): + await async_client.knowledge_bases.wait_for_database( + "test-uuid", + timeout=10.0, + poll_interval=0.1, + ) + + @parametrize + async def test_path_params_wait_for_database(self, async_client: AsyncGradient) -> None: + """Test async wait_for_database validates uuid parameter.""" + with pytest.raises(ValueError, match=r"Expected a non-empty value for `uuid` but received ''"): + await async_client.knowledge_bases.wait_for_database( + "", + ) From 5e8e5b1b8e2e197214df17d547712e5519983a00 Mon Sep 17 00:00:00 2001 From: kashyapdayal Date: Tue, 21 Oct 2025 19:34:43 +0530 Subject: [PATCH 2/4] Address review feedback and add KnowledgeBaseTimeoutError --- README.md | 43 ------------- examples/wait_for_knowledge_base.py | 60 +++++++++++++++++ pyproject.toml | 1 - .../resources/knowledge_bases/__init__.py | 4 +- .../knowledge_bases/knowledge_bases.py | 42 ++++++------ tests/api_resources/test_knowledge_bases.py | 64 +++++++++---------- tests/test_client.py | 8 +-- tests/test_files.py | 10 +-- 8 files changed, 126 insertions(+), 106 deletions(-) create mode 100644 examples/wait_for_knowledge_base.py diff --git a/README.md b/README.md index 7a456619..c9186c03 100644 --- a/README.md +++ b/README.md @@ -96,49 +96,6 @@ we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/) to add `DIGITALOCEAN_ACCESS_TOKEN="My Access Token"`, `GRADIENT_MODEL_ACCESS_KEY="My Model Access Key"` to your `.env` file so that your keys are not stored in source control. -## Knowledge Base Database Polling - -When creating a Knowledge Base, the database deployment can take several minutes. The `wait_for_database()` helper function simplifies polling for the database status: - -```python -from gradient import Gradient -from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError -from gradient._exceptions import APITimeoutError - -client = Gradient() - -# Create a knowledge base -kb_response = client.knowledge_bases.create( - name="My Knowledge Base", - region="nyc1", - embedding_model_uuid="your-embedding-model-uuid", -) - -kb_uuid = kb_response.knowledge_base.uuid - -try: - # Wait for the database to be ready (default: 10 minute timeout, 5 second poll interval) - result = client.knowledge_bases.wait_for_database(kb_uuid) - print(f"Database status: {result.database_status}") # "ONLINE" - - # Custom timeout and poll interval - result = client.knowledge_bases.wait_for_database( - kb_uuid, - timeout=900.0, # 15 minutes - poll_interval=10.0 # Check every 10 seconds - ) - -except KnowledgeBaseDatabaseError as e: - # Database entered a failed state (DECOMMISSIONED or UNHEALTHY) - print(f"Database failed: {e}") - -except APITimeoutError: - # Database did not become ready within the timeout period - print("Timeout: Database did not become ready in time") -``` - -The helper handles all state transitions and will raise appropriate exceptions for failed states or timeouts. - ## Async usage Simply import `AsyncGradient` instead of `Gradient` and use `await` with each API call: diff --git a/examples/wait_for_knowledge_base.py b/examples/wait_for_knowledge_base.py new file mode 100644 index 00000000..ee503674 --- /dev/null +++ b/examples/wait_for_knowledge_base.py @@ -0,0 +1,60 @@ +""" +Example demonstrating how to use the wait_for_database helper function. + +This example shows how to: +1. Create a knowledge base +2. Wait for its database to be ready +3. Handle errors and timeouts appropriately +""" + +import os + +from gradient import Gradient +from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError, KnowledgeBaseDatabaseError + + +def main() -> None: + """Create a knowledge base and wait for its database to be ready.""" + # Initialize the Gradient client + # Note: DIGITALOCEAN_ACCESS_TOKEN must be set in your environment + client = Gradient( + access_token=os.environ.get("DIGITALOCEAN_ACCESS_TOKEN"), + ) + + # Create a knowledge base + # Replace these values with your actual configuration + kb_response = client.knowledge_bases.create( + name="My Knowledge Base", + region="nyc1", # Choose your preferred region + embedding_model_uuid="your-embedding-model-uuid", # Use your embedding model UUID + ) + + kb_uuid = kb_response.knowledge_base.uuid + print(f"Created knowledge base: {kb_uuid}") + + try: + # Wait for the database to be ready + # Default: 10 minute timeout, 5 second poll interval + print("Waiting for database to be ready...") + result = client.knowledge_bases.wait_for_database(kb_uuid) + print(f"Database status: {result.database_status}") # "ONLINE" + print("Knowledge base is ready!") + + # Alternative: Custom timeout and poll interval + # result = client.knowledge_bases.wait_for_database( + # kb_uuid, + # timeout=900.0, # 15 minutes + # poll_interval=10.0 # Check every 10 seconds + # ) + + except KnowledgeBaseDatabaseError as e: + # Database entered a failed state (DECOMMISSIONED or UNHEALTHY) + print(f"Database failed: {e}") + + except KnowledgeBaseTimeoutError as e: + # Database did not become ready within the timeout period + print(f"Timeout: {e}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 683b9d79..0e83a25b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -253,5 +253,4 @@ known-first-party = ["gradient", "tests"] [tool.ruff.lint.per-file-ignores] "bin/**.py" = ["T201", "T203"] "scripts/**.py" = ["T201", "T203"] -"tests/**.py" = ["T201", "T203", "ARG001"] "examples/**.py" = ["T201", "T203"] diff --git a/src/gradient/resources/knowledge_bases/__init__.py b/src/gradient/resources/knowledge_bases/__init__.py index 90ebea00..353dc05c 100644 --- a/src/gradient/resources/knowledge_bases/__init__.py +++ b/src/gradient/resources/knowledge_bases/__init__.py @@ -18,8 +18,9 @@ ) from .knowledge_bases import ( KnowledgeBasesResource, - AsyncKnowledgeBasesResource, + KnowledgeBaseTimeoutError, KnowledgeBaseDatabaseError, + AsyncKnowledgeBasesResource, KnowledgeBasesResourceWithRawResponse, AsyncKnowledgeBasesResourceWithRawResponse, KnowledgeBasesResourceWithStreamingResponse, @@ -42,6 +43,7 @@ "KnowledgeBasesResource", "AsyncKnowledgeBasesResource", "KnowledgeBaseDatabaseError", + "KnowledgeBaseTimeoutError", "KnowledgeBasesResourceWithRawResponse", "AsyncKnowledgeBasesResourceWithRawResponse", "KnowledgeBasesResourceWithStreamingResponse", diff --git a/src/gradient/resources/knowledge_bases/knowledge_bases.py b/src/gradient/resources/knowledge_bases/knowledge_bases.py index d92622a8..4325148c 100644 --- a/src/gradient/resources/knowledge_bases/knowledge_bases.py +++ b/src/gradient/resources/knowledge_bases/knowledge_bases.py @@ -27,7 +27,6 @@ DataSourcesResourceWithStreamingResponse, AsyncDataSourcesResourceWithStreamingResponse, ) -from ..._exceptions import APITimeoutError from .indexing_jobs import ( IndexingJobsResource, AsyncIndexingJobsResource, @@ -43,7 +42,12 @@ from ...types.knowledge_base_update_response import KnowledgeBaseUpdateResponse from ...types.knowledge_base_retrieve_response import KnowledgeBaseRetrieveResponse -__all__ = ["KnowledgeBasesResource", "AsyncKnowledgeBasesResource", "KnowledgeBaseDatabaseError"] +__all__ = [ + "KnowledgeBasesResource", + "AsyncKnowledgeBasesResource", + "KnowledgeBaseDatabaseError", + "KnowledgeBaseTimeoutError", +] class KnowledgeBaseDatabaseError(Exception): @@ -52,6 +56,12 @@ class KnowledgeBaseDatabaseError(Exception): pass +class KnowledgeBaseTimeoutError(Exception): + """Raised when waiting for a knowledge base database times out.""" + + pass + + class KnowledgeBasesResource(SyncAPIResource): @cached_property def data_sources(self) -> DataSourcesResource: @@ -377,7 +387,7 @@ def wait_for_database( Raises: KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY) - APITimeoutError: If the timeout is exceeded before the database becomes ONLINE + KnowledgeBaseTimeoutError: If the timeout is exceeded before the database becomes ONLINE """ if not uuid: raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") @@ -388,11 +398,9 @@ def wait_for_database( while True: elapsed = time.time() - start_time if elapsed >= timeout: - raise APITimeoutError( - request=httpx.Request( - method="GET", - url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}", - ) + raise KnowledgeBaseTimeoutError( + f"Timeout waiting for knowledge base database to become ready. " + f"Database did not reach ONLINE status within {timeout} seconds." ) response = self.retrieve( @@ -408,9 +416,7 @@ def wait_for_database( return response if status in failed_states: - raise KnowledgeBaseDatabaseError( - f"Knowledge base database entered failed state: {status}" - ) + raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}") # Sleep before next poll, but don't exceed timeout remaining_time = timeout - elapsed @@ -744,7 +750,7 @@ async def wait_for_database( Raises: KnowledgeBaseDatabaseError: If the database enters a failed state (DECOMMISSIONED, UNHEALTHY) - APITimeoutError: If the timeout is exceeded before the database becomes ONLINE + KnowledgeBaseTimeoutError: If the timeout is exceeded before the database becomes ONLINE """ if not uuid: raise ValueError(f"Expected a non-empty value for `uuid` but received {uuid!r}") @@ -755,11 +761,9 @@ async def wait_for_database( while True: elapsed = time.time() - start_time if elapsed >= timeout: - raise APITimeoutError( - request=httpx.Request( - method="GET", - url=f"https://api.digitalocean.com/v2/gen-ai/knowledge_bases/{uuid}", - ) + raise KnowledgeBaseTimeoutError( + f"Timeout waiting for knowledge base database to become ready. " + f"Database did not reach ONLINE status within {timeout} seconds." ) response = await self.retrieve( @@ -775,9 +779,7 @@ async def wait_for_database( return response if status in failed_states: - raise KnowledgeBaseDatabaseError( - f"Knowledge base database entered failed state: {status}" - ) + raise KnowledgeBaseDatabaseError(f"Knowledge base database entered failed state: {status}") # Sleep before next poll, but don't exceed timeout remaining_time = timeout - elapsed diff --git a/tests/api_resources/test_knowledge_bases.py b/tests/api_resources/test_knowledge_bases.py index 16773a0b..1b0cc1a8 100644 --- a/tests/api_resources/test_knowledge_bases.py +++ b/tests/api_resources/test_knowledge_bases.py @@ -279,24 +279,24 @@ def test_path_params_delete(self, client: Gradient) -> None: def test_method_wait_for_database_success(self, client: Gradient) -> None: """Test wait_for_database with successful database status transition.""" from unittest.mock import Mock - + call_count = [0] - + def mock_retrieve(uuid, **kwargs): call_count[0] += 1 response = Mock() # Simulate CREATING -> ONLINE transition response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" return response - + client.knowledge_bases.retrieve = mock_retrieve - + result = client.knowledge_bases.wait_for_database( "test-uuid", timeout=10.0, poll_interval=0.1, ) - + assert result.database_status == "ONLINE" assert call_count[0] == 2 @@ -306,14 +306,14 @@ def test_method_wait_for_database_failed_state(self, client: Gradient) -> None: from unittest.mock import Mock from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - + def mock_retrieve(uuid, **kwargs): response = Mock() response.database_status = "UNHEALTHY" return response - + client.knowledge_bases.retrieve = mock_retrieve - + with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): client.knowledge_bases.wait_for_database( "test-uuid", @@ -326,16 +326,16 @@ def test_method_wait_for_database_timeout(self, client: Gradient) -> None: """Test wait_for_database with timeout.""" from unittest.mock import Mock - from gradient._exceptions import APITimeoutError - + from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError + def mock_retrieve(uuid, **kwargs): response = Mock() response.database_status = "CREATING" return response - + client.knowledge_bases.retrieve = mock_retrieve - - with pytest.raises(APITimeoutError): + + with pytest.raises(KnowledgeBaseTimeoutError): client.knowledge_bases.wait_for_database( "test-uuid", timeout=0.3, @@ -348,14 +348,14 @@ def test_method_wait_for_database_decommissioned(self, client: Gradient) -> None from unittest.mock import Mock from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - + def mock_retrieve(uuid, **kwargs): response = Mock() response.database_status = "DECOMMISSIONED" return response - + client.knowledge_bases.retrieve = mock_retrieve - + with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): client.knowledge_bases.wait_for_database( "test-uuid", @@ -633,24 +633,24 @@ async def test_path_params_delete(self, async_client: AsyncGradient) -> None: async def test_method_wait_for_database_success(self, async_client: AsyncGradient) -> None: """Test async wait_for_database with successful database status transition.""" from unittest.mock import Mock - + call_count = [0] - + async def mock_retrieve(uuid, **kwargs): call_count[0] += 1 response = Mock() # Simulate CREATING -> ONLINE transition response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" return response - + async_client.knowledge_bases.retrieve = mock_retrieve - + result = await async_client.knowledge_bases.wait_for_database( "test-uuid", timeout=10.0, poll_interval=0.1, ) - + assert result.database_status == "ONLINE" assert call_count[0] == 2 @@ -660,14 +660,14 @@ async def test_method_wait_for_database_failed_state(self, async_client: AsyncGr from unittest.mock import Mock from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - + async def mock_retrieve(uuid, **kwargs): response = Mock() response.database_status = "UNHEALTHY" return response - + async_client.knowledge_bases.retrieve = mock_retrieve - + with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): await async_client.knowledge_bases.wait_for_database( "test-uuid", @@ -680,16 +680,16 @@ async def test_method_wait_for_database_timeout(self, async_client: AsyncGradien """Test async wait_for_database with timeout.""" from unittest.mock import Mock - from gradient._exceptions import APITimeoutError - + from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError + async def mock_retrieve(uuid, **kwargs): response = Mock() response.database_status = "CREATING" return response - + async_client.knowledge_bases.retrieve = mock_retrieve - - with pytest.raises(APITimeoutError): + + with pytest.raises(KnowledgeBaseTimeoutError): await async_client.knowledge_bases.wait_for_database( "test-uuid", timeout=0.3, @@ -702,14 +702,14 @@ async def test_method_wait_for_database_decommissioned(self, async_client: Async from unittest.mock import Mock from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - + async def mock_retrieve(uuid, **kwargs): response = Mock() response.database_status = "DECOMMISSIONED" return response - + async_client.knowledge_bases.retrieve = mock_retrieve - + with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): await async_client.knowledge_bases.wait_for_database( "test-uuid", diff --git a/tests/test_client.py b/tests/test_client.py index ddf1c4db..846c0bb6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -287,9 +287,9 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic add_leak(leaks, diff) if leaks: for leak in leaks: - print("MEMORY LEAK:", leak) + print("MEMORY LEAK:", leak) # noqa: T201 for frame in leak.traceback: - print(frame) + print(frame) # noqa: T201 raise AssertionError() def test_request_timeout(self) -> None: @@ -1304,9 +1304,9 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic add_leak(leaks, diff) if leaks: for leak in leaks: - print("MEMORY LEAK:", leak) + print("MEMORY LEAK:", leak) # noqa: T201 for frame in leak.traceback: - print(frame) + print(frame) # noqa: T201 raise AssertionError() async def test_request_timeout(self) -> None: diff --git a/tests/test_files.py b/tests/test_files.py index 4d9f4066..54210e83 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -11,34 +11,34 @@ def test_pathlib_includes_file_name() -> None: result = to_httpx_files({"file": readme_path}) - print(result) + print(result) # noqa: T201 assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) def test_tuple_input() -> None: result = to_httpx_files([("file", readme_path)]) - print(result) + print(result) # noqa: T201 assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes()))) @pytest.mark.asyncio async def test_async_pathlib_includes_file_name() -> None: result = await async_to_httpx_files({"file": readme_path}) - print(result) + print(result) # noqa: T201 assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) @pytest.mark.asyncio async def test_async_supports_anyio_path() -> None: result = await async_to_httpx_files({"file": anyio.Path(readme_path)}) - print(result) + print(result) # noqa: T201 assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) @pytest.mark.asyncio async def test_async_tuple_input() -> None: result = await async_to_httpx_files([("file", readme_path)]) - print(result) + print(result) # noqa: T201 assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes()))) From 9daa957fa0a9d82e9f4e0ffdaab98658bcb0e364 Mon Sep 17 00:00:00 2001 From: kashyapdayal Date: Wed, 22 Oct 2025 22:44:25 +0530 Subject: [PATCH 3/4] Resolve linting and formatting errors --- tests/api_resources/test_knowledge_bases.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/api_resources/test_knowledge_bases.py b/tests/api_resources/test_knowledge_bases.py index 1b0cc1a8..3b84a4d2 100644 --- a/tests/api_resources/test_knowledge_bases.py +++ b/tests/api_resources/test_knowledge_bases.py @@ -282,7 +282,7 @@ def test_method_wait_for_database_success(self, client: Gradient) -> None: call_count = [0] - def mock_retrieve(uuid, **kwargs): + def mock_retrieve(_uuid, **_kwargs): call_count[0] += 1 response = Mock() # Simulate CREATING -> ONLINE transition @@ -307,7 +307,7 @@ def test_method_wait_for_database_failed_state(self, client: Gradient) -> None: from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - def mock_retrieve(uuid, **kwargs): + def mock_retrieve(_uuid, **_kwargs): response = Mock() response.database_status = "UNHEALTHY" return response @@ -328,7 +328,7 @@ def test_method_wait_for_database_timeout(self, client: Gradient) -> None: from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError - def mock_retrieve(uuid, **kwargs): + def mock_retrieve(_uuid, **_kwargs): response = Mock() response.database_status = "CREATING" return response @@ -349,7 +349,7 @@ def test_method_wait_for_database_decommissioned(self, client: Gradient) -> None from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - def mock_retrieve(uuid, **kwargs): + def mock_retrieve(_uuid, **_kwargs): response = Mock() response.database_status = "DECOMMISSIONED" return response @@ -636,7 +636,7 @@ async def test_method_wait_for_database_success(self, async_client: AsyncGradien call_count = [0] - async def mock_retrieve(uuid, **kwargs): + async def mock_retrieve(_uuid, **_kwargs): call_count[0] += 1 response = Mock() # Simulate CREATING -> ONLINE transition @@ -661,7 +661,7 @@ async def test_method_wait_for_database_failed_state(self, async_client: AsyncGr from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - async def mock_retrieve(uuid, **kwargs): + async def mock_retrieve(_uuid, **_kwargs): response = Mock() response.database_status = "UNHEALTHY" return response @@ -682,7 +682,7 @@ async def test_method_wait_for_database_timeout(self, async_client: AsyncGradien from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError - async def mock_retrieve(uuid, **kwargs): + async def mock_retrieve(_uuid, **_kwargs): response = Mock() response.database_status = "CREATING" return response @@ -703,7 +703,7 @@ async def test_method_wait_for_database_decommissioned(self, async_client: Async from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - async def mock_retrieve(uuid, **kwargs): + async def mock_retrieve(_uuid, **_kwargs): response = Mock() response.database_status = "DECOMMISSIONED" return response From 0b1f3cf0e671994d32fc3c865f9eb613f234dedb Mon Sep 17 00:00:00 2001 From: Navaneeth K Date: Thu, 23 Oct 2025 16:28:52 +0530 Subject: [PATCH 4/4] Resolve type checking errors in tests and examples --- examples/wait_for_knowledge_base.py | 4 +++ tests/api_resources/test_knowledge_bases.py | 32 ++++++++++----------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/examples/wait_for_knowledge_base.py b/examples/wait_for_knowledge_base.py index ee503674..739ff80e 100644 --- a/examples/wait_for_knowledge_base.py +++ b/examples/wait_for_knowledge_base.py @@ -29,6 +29,10 @@ def main() -> None: embedding_model_uuid="your-embedding-model-uuid", # Use your embedding model UUID ) + if not kb_response.knowledge_base or not kb_response.knowledge_base.uuid: + print("Failed to create knowledge base") + return + kb_uuid = kb_response.knowledge_base.uuid print(f"Created knowledge base: {kb_uuid}") diff --git a/tests/api_resources/test_knowledge_bases.py b/tests/api_resources/test_knowledge_bases.py index 3b84a4d2..a42277e4 100644 --- a/tests/api_resources/test_knowledge_bases.py +++ b/tests/api_resources/test_knowledge_bases.py @@ -282,14 +282,14 @@ def test_method_wait_for_database_success(self, client: Gradient) -> None: call_count = [0] - def mock_retrieve(_uuid, **_kwargs): + def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 call_count[0] += 1 response = Mock() # Simulate CREATING -> ONLINE transition response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" return response - client.knowledge_bases.retrieve = mock_retrieve + client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] result = client.knowledge_bases.wait_for_database( "test-uuid", @@ -307,12 +307,12 @@ def test_method_wait_for_database_failed_state(self, client: Gradient) -> None: from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - def mock_retrieve(_uuid, **_kwargs): + def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 response = Mock() response.database_status = "UNHEALTHY" return response - client.knowledge_bases.retrieve = mock_retrieve + client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): client.knowledge_bases.wait_for_database( @@ -328,12 +328,12 @@ def test_method_wait_for_database_timeout(self, client: Gradient) -> None: from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError - def mock_retrieve(_uuid, **_kwargs): + def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 response = Mock() response.database_status = "CREATING" return response - client.knowledge_bases.retrieve = mock_retrieve + client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] with pytest.raises(KnowledgeBaseTimeoutError): client.knowledge_bases.wait_for_database( @@ -349,12 +349,12 @@ def test_method_wait_for_database_decommissioned(self, client: Gradient) -> None from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - def mock_retrieve(_uuid, **_kwargs): + def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 response = Mock() response.database_status = "DECOMMISSIONED" return response - client.knowledge_bases.retrieve = mock_retrieve + client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): client.knowledge_bases.wait_for_database( @@ -636,14 +636,14 @@ async def test_method_wait_for_database_success(self, async_client: AsyncGradien call_count = [0] - async def mock_retrieve(_uuid, **_kwargs): + async def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 call_count[0] += 1 response = Mock() # Simulate CREATING -> ONLINE transition response.database_status = "CREATING" if call_count[0] == 1 else "ONLINE" return response - async_client.knowledge_bases.retrieve = mock_retrieve + async_client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] result = await async_client.knowledge_bases.wait_for_database( "test-uuid", @@ -661,12 +661,12 @@ async def test_method_wait_for_database_failed_state(self, async_client: AsyncGr from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - async def mock_retrieve(_uuid, **_kwargs): + async def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 response = Mock() response.database_status = "UNHEALTHY" return response - async_client.knowledge_bases.retrieve = mock_retrieve + async_client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] with pytest.raises(KnowledgeBaseDatabaseError, match="UNHEALTHY"): await async_client.knowledge_bases.wait_for_database( @@ -682,12 +682,12 @@ async def test_method_wait_for_database_timeout(self, async_client: AsyncGradien from gradient.resources.knowledge_bases import KnowledgeBaseTimeoutError - async def mock_retrieve(_uuid, **_kwargs): + async def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 response = Mock() response.database_status = "CREATING" return response - async_client.knowledge_bases.retrieve = mock_retrieve + async_client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] with pytest.raises(KnowledgeBaseTimeoutError): await async_client.knowledge_bases.wait_for_database( @@ -703,12 +703,12 @@ async def test_method_wait_for_database_decommissioned(self, async_client: Async from gradient.resources.knowledge_bases import KnowledgeBaseDatabaseError - async def mock_retrieve(_uuid, **_kwargs): + async def mock_retrieve(uuid: str, **kwargs: object) -> Mock: # noqa: ARG001 response = Mock() response.database_status = "DECOMMISSIONED" return response - async_client.knowledge_bases.retrieve = mock_retrieve + async_client.knowledge_bases.retrieve = mock_retrieve # type: ignore[method-assign] with pytest.raises(KnowledgeBaseDatabaseError, match="DECOMMISSIONED"): await async_client.knowledge_bases.wait_for_database(