diff --git a/docs/README.md b/docs/README.md index 2b9871a..b3191db 100644 --- a/docs/README.md +++ b/docs/README.md @@ -163,8 +163,8 @@ trace_eval = client.trace_evaluations.create( judge_id=judge.id, ) -# Get results -results = client.trace_evaluations.get_results(trace_eval.id) +# Wait for completion and get results +result = client.trace_evaluations.wait_for_completion(trace_eval.id) ``` ### Custom models diff --git a/docs/api-reference/client.md b/docs/api-reference/client.md index a7e534e..8787787 100644 --- a/docs/api-reference/client.md +++ b/docs/api-reference/client.md @@ -33,13 +33,14 @@ client = AsyncStratix(api_key="your_api_key") ## Constructor Parameters -### `Stratix(api_key, base_url, timeout)` and `AsyncStratix(api_key, base_url, timeout)` +### `Stratix(api_key, base_url, timeout, max_retries)` and `AsyncStratix(api_key, base_url, timeout, max_retries)` -| Parameter | Type | Required | Default | Description | -| ---------- | -------------------------------- | -------- | ------------- | ----------------------------- | -| `api_key` | `str \| None` | Yes\* | `None` | Your LayerLens Stratix API key | -| `base_url` | `str \| httpx.URL \| None` | No | Stratix API URL | Custom API base URL | -| `timeout` | `float \| httpx.Timeout \| None` | No | 10 minutes | Request timeout configuration | +| Parameter | Type | Required | Default | Description | +| ------------- | -------------------------------- | -------- | ------------- | ----------------------------- | +| `api_key` | `str \| None` | Yes\* | `None` | Your LayerLens Stratix API key | +| `base_url` | `str \| httpx.URL \| None` | No | Stratix API URL | Custom API base URL | +| `timeout` | `float \| httpx.Timeout \| None` | No | 10 minutes | Request timeout configuration | +| `max_retries` | `int` | No | `2` | Maximum number of retries on retryable errors (429, 500, 502, 503, 504) | \*Required unless set via environment variables @@ -81,6 +82,23 @@ from layerlens import Stratix client = Stratix(timeout=30.0) ``` +### Retry Configuration + +The client automatically retries requests that fail with retryable status codes (429 Too Many Requests, 500, 502, 503, 504) using exponential backoff. If the server sends a `Retry-After` header, the client respects it. + +```python +from layerlens import Stratix + +# Default: 2 retries +client = Stratix() + +# More retries for batch-heavy workloads +client = Stratix(max_retries=5) + +# Disable retries entirely +client = Stratix(max_retries=0) +``` + ### Per-Request Timeout Override ```python diff --git a/docs/api-reference/judges.md b/docs/api-reference/judges.md index e7b4906..c15a76c 100644 --- a/docs/api-reference/judges.md +++ b/docs/api-reference/judges.md @@ -67,7 +67,7 @@ Creates a new judge with the specified evaluation criteria. | ----------------- | -------------------------------- | -------- | -------------------------------------------- | | `name` | `str` | Yes | Display name for the judge | | `evaluation_goal` | `str` | Yes | Description of what the judge should evaluate | -| `model_id` | `str \| None` | Yes* | ID of the LLM model to use (required by API)| +| `model_id` | `str \| None` | No | ID of the LLM model to use. If omitted, the server uses a default model | | `timeout` | `float \| httpx.Timeout \| None` | No | Override request timeout | #### Returns diff --git a/docs/api-reference/trace-evaluations.md b/docs/api-reference/trace-evaluations.md index 0e88b61..8cec9c8 100644 --- a/docs/api-reference/trace-evaluations.md +++ b/docs/api-reference/trace-evaluations.md @@ -26,9 +26,9 @@ evaluation = client.trace_evaluations.create( judge_id="judge-123", ) -# Get results -results = client.trace_evaluations.get_results(evaluation.id) -for result in results.results: +# Wait for completion and get results +result = client.trace_evaluations.wait_for_completion(evaluation.id) +if result: print(f"Score: {result.score}, Passed: {result.passed}") print(f"Reasoning: {result.reasoning}") ``` @@ -47,8 +47,8 @@ async def main(): judge_id="judge-123", ) - results = await client.trace_evaluations.get_results(evaluation.id) - for result in results.results: + result = await client.trace_evaluations.wait_for_completion(evaluation.id) + if result: print(f"Score: {result.score}, Passed: {result.passed}") if __name__ == "__main__": @@ -149,6 +149,8 @@ response = client.trace_evaluations.get_many( Retrieves the detailed results of a completed trace evaluation, including scores, reasoning, and step-by-step analysis. +Returns `None` if results are not yet available (evaluation still pending or in progress). + #### Parameters | Parameter | Type | Required | Description | @@ -158,23 +160,62 @@ Retrieves the detailed results of a completed trace evaluation, including scores #### Returns -Returns a `TraceEvaluationResultsResponse` object containing: +Returns a `TraceEvaluationResultsResponse` object with the evaluation result fields (score, passed, reasoning, etc.). -- `results`: List of `TraceEvaluationResult` objects +Returns `None` if the evaluation has not completed yet or if the request fails. -Returns `None` if the request fails. +#### Example + +```python +result = client.trace_evaluations.get_results("eval-123") +if result: + print(f"Score: {result.score}") + print(f"Passed: {result.passed}") + print(f"Reasoning: {result.reasoning}") + for step in result.steps: + print(f" Tool: {step.tool}, Result: {step.result}") +``` + +### `wait_for_completion(id, interval_seconds=3, timeout_seconds=300)` + +Polls the evaluation status until it reaches a terminal state (success or failure), then returns the results. This is the recommended way to wait for trace evaluation results. + +#### Parameters + +| Parameter | Type | Required | Default | Description | +| ------------------ | -------------- | -------- | ------- | ------------------------------------------------ | +| `id` | `str` | Yes | | The unique trace evaluation ID | +| `interval_seconds` | `int` | No | `3` | Seconds between status polls | +| `timeout_seconds` | `int \| None` | No | `300` | Maximum wait time. `None` waits indefinitely | + +#### Returns + +Returns a `TraceEvaluationResultsResponse` object if the evaluation completes successfully. + +Returns `None` if the evaluation failed or no results are available. + +Raises `TimeoutError` if `timeout_seconds` is exceeded. #### Example ```python -results_response = client.trace_evaluations.get_results("eval-123") -if results_response: - for result in results_response.results: - print(f"Score: {result.score}") - print(f"Passed: {result.passed}") - print(f"Reasoning: {result.reasoning}") - for step in result.steps: - print(f" Step {step.step}: {step.reasoning}") +evaluation = client.trace_evaluations.create( + trace_id="trace-abc", + judge_id="judge-xyz", +) + +# Wait up to 5 minutes for results +result = client.trace_evaluations.wait_for_completion(evaluation.id) +if result: + print(f"Score: {result.score}, Passed: {result.passed}") + print(f"Reasoning: {result.reasoning}") + +# Custom timeout and polling interval +result = client.trace_evaluations.wait_for_completion( + evaluation.id, + interval_seconds=5, + timeout_seconds=600, +) ``` ### `estimate_cost(trace_ids, judge_id, timeout=None)` diff --git a/docs/examples/judges-and-traces.md b/docs/examples/judges-and-traces.md index 9d5e80b..9288c42 100644 --- a/docs/examples/judges-and-traces.md +++ b/docs/examples/judges-and-traces.md @@ -103,14 +103,10 @@ from layerlens import Stratix client = Stratix() -# Fetch a model and create a judge -models = client.models.get(type="public", name="gpt-4o") -model = models[0] - +# Create a judge (no model_id → server uses default model) judge = client.judges.create( name=f"Trace Eval Demo Judge {int(time.time())}", evaluation_goal="Evaluate whether the response is accurate, complete, and well-structured", - model_id=model.id, ) print(f"Created judge {judge.id}: {judge.name}") @@ -133,28 +129,16 @@ evaluation = client.trace_evaluations.create( ) print(f"Created evaluation {evaluation.id}, status: {evaluation.status}") -# --- Wait for evaluation to complete -for _ in range(30): - evaluation = client.trace_evaluations.get(evaluation.id) - print(f"Evaluation status: {evaluation.status}") - if evaluation.status.value in ("success", "failure"): - break - time.sleep(2) - -# --- Get evaluation results -try: - results_response = client.trace_evaluations.get_results(evaluation.id) - if results_response and results_response.results: - for result in results_response.results: - print(f" Score: {result.score}, Passed: {result.passed}") - print(f" Reasoning: {result.reasoning}") - if result.steps: - for step in result.steps: - print(f" Step {step.step}: {step.reasoning}") - else: - print(" No results returned") -except Exception: - print(" No results yet (evaluation may still be in progress)") +# --- Wait for completion and get results +result = client.trace_evaluations.wait_for_completion(evaluation.id) +if result: + print(f" Score: {result.score}, Passed: {result.passed}") + print(f" Reasoning: {result.reasoning}") + if result.steps: + for step in result.steps: + print(f" Tool: {step.tool}, Result: {step.result[:80]}") +else: + print(" No results returned (evaluation may have failed)") # --- List all trace evaluations response = client.trace_evaluations.get_many() @@ -287,20 +271,17 @@ async def main(): if evaluation: print(f" Evaluation {evaluation.id}: {evaluation.status}") - # --- Wait and fetch results - await asyncio.sleep(10) - for evaluation in evaluations: - if not evaluation: - continue - try: - results_response = await client.trace_evaluations.get_results(evaluation.id) - if results_response and results_response.results: - for result in results_response.results: - print(f" Score: {result.score}, Passed: {result.passed}") - else: - print(f" Evaluation {evaluation.id}: no results yet") - except Exception: - print(f" Evaluation {evaluation.id}: results not available yet") + # --- Wait for results concurrently + result_tasks = [ + client.trace_evaluations.wait_for_completion(e.id) + for e in evaluations if e + ] + results = await asyncio.gather(*result_tasks) + for result in results: + if result: + print(f" Score: {result.score}, Passed: {result.passed}") + else: + print(f" No results (evaluation may have failed)") await client.judges.delete(judge.id) diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index f9f32b4..3e436b0 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -34,7 +34,6 @@ print(f"Accuracy: {result.accuracy}") ## Create a Judge and Evaluate Traces ```python -import time from layerlens import Stratix client = Stratix() @@ -55,15 +54,8 @@ trace_eval = client.trace_evaluations.create( judge_id=judge.id, ) -# Poll until complete -while True: - evaluation = client.trace_evaluations.get(trace_eval.id) - if evaluation.status.value in ("success", "failure"): - break - time.sleep(2) - -# Get results -result = client.trace_evaluations.get_results(trace_eval.id) +# Wait for completion and get results +result = client.trace_evaluations.wait_for_completion(trace_eval.id) if result: print(f"Score: {result.score}, Passed: {result.passed}") print(f"Reasoning: {result.reasoning}") diff --git a/examples/trace_evaluations.py b/examples/trace_evaluations.py index 0891f19..c1bbd18 100644 --- a/examples/trace_evaluations.py +++ b/examples/trace_evaluations.py @@ -7,19 +7,10 @@ # Construct sync client (API key from env or inline) client = Stratix() -# --- Fetch a model to use for judge creation -models = client.models.get(type="public", name="gpt-4o") -if not models: - print("No models found, exiting") - exit(1) -model = models[0] -print(f"Using model: {model.name} ({model.id})") - -# --- Create a judge to use for evaluations +# --- Create a judge (no model_id → server uses default model) judge = client.judges.create( name=f"Trace Eval Demo Judge {int(time.time())}", evaluation_goal="Evaluate whether the response is accurate, complete, and well-structured", - model_id=model.id, ) print(f"Created judge {judge.id}: {judge.name}") @@ -27,7 +18,6 @@ traces_response = client.traces.get_many(page_size=3) if not traces_response or len(traces_response.traces) == 0: print("No traces found. Upload some traces first using traces.py") - # Clean up the judge client.judges.delete(judge.id) exit(1) @@ -48,27 +38,16 @@ ) print(f"Created evaluation {evaluation.id}, status: {evaluation.status}") -# --- Wait for evaluation to complete (poll every 2 seconds, up to 60s) -for _ in range(30): - evaluation = client.trace_evaluations.get(evaluation.id) - print(f"Evaluation status: {evaluation.status}") - if evaluation.status.value in ("success", "failure"): - break - time.sleep(2) - -# --- Get evaluation results (may 404 if still in progress) -try: - result = client.trace_evaluations.get_results(evaluation.id) - if result: - print(f" Score: {result.score}, Passed: {result.passed}") - print(f" Reasoning: {result.reasoning}") - if result.steps: - for step in result.steps: - print(f" Tool: {step.tool}, Result: {step.result[:80]}") - else: - print(" No results returned") -except Exception: - print(" No results yet (evaluation may still be in progress)") +# --- Wait for completion and get results in one call +result = client.trace_evaluations.wait_for_completion(evaluation.id) +if result: + print(f" Score: {result.score}, Passed: {result.passed}") + print(f" Reasoning: {result.reasoning}") + if result.steps: + for step in result.steps: + print(f" Tool: {step.tool}, Result: {step.result[:80]}") +else: + print(" No results returned (evaluation may have failed)") # --- List all trace evaluations response = client.trace_evaluations.get_many() diff --git a/src/layerlens/_base_client.py b/src/layerlens/_base_client.py index 8cd7650..a43d2fe 100644 --- a/src/layerlens/_base_client.py +++ b/src/layerlens/_base_client.py @@ -24,14 +24,18 @@ class BaseClient(httpx.Client): + _max_retries: int + def __init__( self, *, base_url: URL | str, headers: Optional[Dict[str, str]] = None, timeout: Union[float, httpx.Timeout, None] = None, + max_retries: int = MAX_RETRIES, **kwargs: Any, ): + self._max_retries = max_retries super().__init__(base_url=base_url, headers=headers, timeout=timeout, **kwargs) @property @@ -58,7 +62,7 @@ def _request_cast( **kwargs: Any, ) -> Union[ResponseT, httpx.Response]: combined_headers = {**self.default_headers, **(headers or {})} - retries_left = MAX_RETRIES + retries_left = self._max_retries delay = INITIAL_RETRY_DELAY while True: @@ -168,14 +172,18 @@ def _make_status_error( class BaseAsyncClient(httpx.AsyncClient): + _max_retries: int + def __init__( self, *, base_url: URL | str, headers: Optional[Dict[str, str]] = None, timeout: Union[float, httpx.Timeout, None] = None, + max_retries: int = MAX_RETRIES, **kwargs: Any, ): + self._max_retries = max_retries super().__init__(base_url=base_url, headers=headers, timeout=timeout, **kwargs) @property @@ -204,7 +212,7 @@ async def _request_cast( import asyncio combined_headers = {**self.default_headers, **(headers or {})} - retries_left = MAX_RETRIES + retries_left = self._max_retries delay = INITIAL_RETRY_DELAY while True: diff --git a/src/layerlens/_client.py b/src/layerlens/_client.py index 873ab4c..33fae85 100644 --- a/src/layerlens/_client.py +++ b/src/layerlens/_client.py @@ -45,6 +45,7 @@ def __init__( api_key: str | None = None, base_url: str | httpx.URL | None = None, timeout: Union[float, httpx.Timeout, None] = DEFAULT_TIMEOUT, + max_retries: int = 2, use_bearer_auth: bool = False, ) -> None: """Construct a new synchronous Stratix client instance. @@ -73,6 +74,7 @@ def __init__( super().__init__( base_url=base_url, timeout=timeout, + max_retries=max_retries, ) organization = self._get_organization() @@ -270,6 +272,7 @@ def __init__( api_key: str | None = None, base_url: str | httpx.URL | None = None, timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, + max_retries: int = 2, use_bearer_auth: bool = False, ) -> None: """Construct a new asynchronous Stratix client instance. @@ -295,7 +298,7 @@ def __init__( if use_bearer_auth: base_url = str(base_url).rstrip("/") + DIRTY_ROUTER_PREFIX - super().__init__(base_url=base_url, timeout=timeout) + super().__init__(base_url=base_url, timeout=timeout, max_retries=max_retries) organization = self._get_organization() if organization is None: diff --git a/src/layerlens/models/integration.py b/src/layerlens/models/integration.py index 5262259..8b18ffd 100644 --- a/src/layerlens/models/integration.py +++ b/src/layerlens/models/integration.py @@ -8,9 +8,13 @@ class Integration(BaseModel): id: str organization_id: str - project_id: str name: str type: Optional[str] = None - status: Optional[str] = None + host_url: Optional[str] = None + active: Optional[bool] = None created_at: Optional[str] = None + created_by: Optional[str] = None + # Legacy/convenience aliases kept optional for backwards compatibility + project_id: Optional[str] = None + status: Optional[str] = None config: Dict[str, Any] = {} diff --git a/src/layerlens/resources/integrations/integrations.py b/src/layerlens/resources/integrations/integrations.py index ecd856b..bb73556 100644 --- a/src/layerlens/resources/integrations/integrations.py +++ b/src/layerlens/resources/integrations/integrations.py @@ -72,12 +72,20 @@ def get_many( return None data = _unwrap(resp) - if not isinstance(data, dict): + + # The API returns the integrations array directly (wrapped in the + # standard {"status": ..., "data": [...]} envelope). After unwrapping, + # ``data`` is a list of integration dicts. + if isinstance(data, list): + items = data + elif isinstance(data, dict): + items = data.get("integrations", []) + else: return None - integrations = [i if isinstance(i, Integration) else Integration(**i) for i in data.get("integrations", [])] - count: int = data.get("count", len(integrations)) - total_count: int = data.get("total_count", count) + integrations = [i if isinstance(i, Integration) else Integration(**i) for i in items] + count = len(integrations) + total_count = count try: return IntegrationsResponse(integrations=integrations, count=count, total_count=total_count) @@ -153,12 +161,20 @@ async def get_many( return None data = _unwrap(resp) - if not isinstance(data, dict): + + # The API returns the integrations array directly (wrapped in the + # standard {"status": ..., "data": [...]} envelope). After unwrapping, + # ``data`` is a list of integration dicts. + if isinstance(data, list): + items = data + elif isinstance(data, dict): + items = data.get("integrations", []) + else: return None - integrations = [i if isinstance(i, Integration) else Integration(**i) for i in data.get("integrations", [])] - count: int = data.get("count", len(integrations)) - total_count: int = data.get("total_count", count) + integrations = [i if isinstance(i, Integration) else Integration(**i) for i in items] + count = len(integrations) + total_count = count try: return IntegrationsResponse(integrations=integrations, count=count, total_count=total_count) diff --git a/src/layerlens/resources/models/models.py b/src/layerlens/resources/models/models.py index 411afed..627b60f 100644 --- a/src/layerlens/resources/models/models.py +++ b/src/layerlens/resources/models/models.py @@ -170,8 +170,10 @@ def add( timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, ) -> bool: """Add models to the project by their IDs.""" - current = self.get(timeout=timeout) or [] - current_ids = [m.id for m in current] + # Only fetch public (platform) models — custom models are managed + # separately and must not be included in the project patch payload. + current = self.get(timeout=timeout, type="public") or [] + current_ids = [str(m.id) for m in current] new_ids = list(dict.fromkeys(current_ids + list(model_ids))) return self._patch_project_models(new_ids, timeout) @@ -181,9 +183,11 @@ def remove( timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, ) -> bool: """Remove models from the project by their IDs.""" - current = self.get(timeout=timeout) or [] + # Only fetch public (platform) models — custom models are managed + # separately and must not be included in the project patch payload. + current = self.get(timeout=timeout, type="public") or [] remove_set = set(model_ids) - new_ids = [m.id for m in current if m.id not in remove_set] + new_ids = [str(m.id) for m in current if str(m.id) not in remove_set] return self._patch_project_models(new_ids, timeout) def _patch_project_models( @@ -198,7 +202,11 @@ def _patch_project_models( timeout=timeout, cast_to=dict, ) - return isinstance(resp, dict) and "id" in resp + if isinstance(resp, dict): + data = resp.get("data", resp) + if isinstance(data, dict) and "id" in data: + return True + return False def create_custom( self, @@ -361,8 +369,10 @@ async def add( timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, ) -> bool: """Add models to the project by their IDs.""" - current = await self.get(timeout=timeout) or [] - current_ids = [m.id for m in current] + # Only fetch public (platform) models — custom models are managed + # separately and must not be included in the project patch payload. + current = await self.get(timeout=timeout, type="public") or [] + current_ids = [str(m.id) for m in current] new_ids = list(dict.fromkeys(current_ids + list(model_ids))) return await self._patch_project_models(new_ids, timeout) @@ -372,9 +382,11 @@ async def remove( timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, ) -> bool: """Remove models from the project by their IDs.""" - current = await self.get(timeout=timeout) or [] + # Only fetch public (platform) models — custom models are managed + # separately and must not be included in the project patch payload. + current = await self.get(timeout=timeout, type="public") or [] remove_set = set(model_ids) - new_ids = [m.id for m in current if m.id not in remove_set] + new_ids = [str(m.id) for m in current if str(m.id) not in remove_set] return await self._patch_project_models(new_ids, timeout) async def _patch_project_models( @@ -389,7 +401,11 @@ async def _patch_project_models( timeout=timeout, cast_to=dict, ) - return isinstance(resp, dict) and "id" in resp + if isinstance(resp, dict): + data = resp.get("data", resp) + if isinstance(data, dict) and "id" in data: + return True + return False async def create_custom( self, diff --git a/src/layerlens/resources/trace_evaluations/trace_evaluations.py b/src/layerlens/resources/trace_evaluations/trace_evaluations.py index 702b20e..255b246 100644 --- a/src/layerlens/resources/trace_evaluations/trace_evaluations.py +++ b/src/layerlens/resources/trace_evaluations/trace_evaluations.py @@ -1,5 +1,7 @@ from __future__ import annotations +import time +import asyncio from typing import Any, Dict, List, Optional import httpx @@ -7,11 +9,18 @@ from ...models import ( TraceEvaluation, CostEstimateResponse, + TraceEvaluationStatus, TraceEvaluationsResponse, TraceEvaluationResultsResponse, ) from ..._resource import SyncAPIResource, AsyncAPIResource from ..._constants import DEFAULT_TIMEOUT +from ..._exceptions import NotFoundError + +_TERMINAL_STATUSES = { + TraceEvaluationStatus.SUCCESS, + TraceEvaluationStatus.FAILURE, +} DEFAULT_PAGE = 1 DEFAULT_PAGE_SIZE = 20 @@ -136,11 +145,14 @@ def get_results( *, timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, ) -> Optional[TraceEvaluationResultsResponse]: - resp = self._get( - f"{self._base_url()}/{id}/results", - timeout=timeout, - cast_to=dict, - ) + try: + resp = self._get( + f"{self._base_url()}/{id}/results", + timeout=timeout, + cast_to=dict, + ) + except NotFoundError: + return None data = _unwrap(resp) if not data or not isinstance(data, dict): return None @@ -150,6 +162,26 @@ def get_results( except Exception: return None + def wait_for_completion( + self, + id: str, + *, + interval_seconds: int = 3, + timeout_seconds: Optional[int] = 300, + ) -> Optional[TraceEvaluationResultsResponse]: + """Poll until the trace evaluation finishes, then return results.""" + start = time.time() + + while True: + te = self.get(id) + if te and te.status in _TERMINAL_STATUSES: + break + if timeout_seconds and (time.time() - start) > timeout_seconds: + raise TimeoutError(f"Trace evaluation {id} did not complete within {timeout_seconds} seconds") + time.sleep(interval_seconds) + + return self.get_results(id) + def estimate_cost( self, *, @@ -283,11 +315,14 @@ async def get_results( *, timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT, ) -> Optional[TraceEvaluationResultsResponse]: - resp = await self._get( - f"{self._base_url()}/{id}/results", - timeout=timeout, - cast_to=dict, - ) + try: + resp = await self._get( + f"{self._base_url()}/{id}/results", + timeout=timeout, + cast_to=dict, + ) + except NotFoundError: + return None data = _unwrap(resp) if not data or not isinstance(data, dict): return None @@ -297,6 +332,26 @@ async def get_results( except Exception: return None + async def wait_for_completion( + self, + id: str, + *, + interval_seconds: int = 3, + timeout_seconds: Optional[int] = 300, + ) -> Optional[TraceEvaluationResultsResponse]: + """Poll until the trace evaluation finishes, then return results.""" + start = asyncio.get_event_loop().time() + + while True: + te = await self.get(id) + if te and te.status in _TERMINAL_STATUSES: + break + if timeout_seconds and (asyncio.get_event_loop().time() - start) > timeout_seconds: + raise TimeoutError(f"Trace evaluation {id} did not complete within {timeout_seconds} seconds") + await asyncio.sleep(interval_seconds) + + return await self.get_results(id) + async def estimate_cost( self, *, diff --git a/tests/resources/test_trace_evaluations.py b/tests/resources/test_trace_evaluations.py index a445bc1..b68b664 100644 --- a/tests/resources/test_trace_evaluations.py +++ b/tests/resources/test_trace_evaluations.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest @@ -10,6 +10,7 @@ TraceEvaluationResultsResponse, ) from layerlens._constants import DEFAULT_TIMEOUT +from layerlens._exceptions import NotFoundError from layerlens.resources.trace_evaluations.trace_evaluations import TraceEvaluations @@ -67,7 +68,11 @@ def sample_result_data(self): "reasoning": "The output meets quality standards", "steps": [ {"tool": "jq", "args": {"query": "."}, "result": "Checked correctness"}, - {"tool": "submit_evaluation", "args": {"score": 0.85}, "result": "Checked style"}, + { + "tool": "submit_evaluation", + "args": {"score": 0.85}, + "result": "Checked style", + }, ], "model": "claude-sonnet-4-20250514", "turns": 3, @@ -437,3 +442,199 @@ def test_estimate_url(self, trace_evals_resource): call_args = trace_evals_resource._post.call_args assert call_args[0][0] == "/organizations/custom-org/projects/custom-proj/trace-evaluations/estimate" + + +class TestGetResultsNotFoundHandling: + """Test that get_results returns None on 404 instead of raising.""" + + @pytest.fixture + def mock_client(self): + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.get_cast = Mock() + client.post_cast = Mock() + client.patch_cast = Mock() + client.delete_cast = Mock() + return client + + @pytest.fixture + def trace_evals_resource(self, mock_client): + return TraceEvaluations(mock_client) + + def test_get_results_returns_none_on_404(self, trace_evals_resource): + """get_results returns None when evaluation results don't exist yet (404).""" + mock_response = Mock() + mock_response.status_code = 404 + mock_response.headers = {} + trace_evals_resource._get.side_effect = NotFoundError("Not found", response=mock_response, body=None) + + result = trace_evals_resource.get_results("te-pending") + + assert result is None + + +class TestWaitForCompletion: + """Test wait_for_completion polling behavior.""" + + @pytest.fixture + def mock_client(self): + client = Mock() + client.organization_id = "org-123" + client.project_id = "proj-456" + client.get_cast = Mock() + client.post_cast = Mock() + client.patch_cast = Mock() + client.delete_cast = Mock() + return client + + @pytest.fixture + def trace_evals_resource(self, mock_client): + return TraceEvaluations(mock_client) + + @pytest.fixture + def sample_result_data(self): + return { + "id": "result-123", + "trace_evaluation_id": "te-123", + "trace_id": "trace-456", + "judge_id": "judge-789", + "score": 0.85, + "passed": True, + "reasoning": "Good output", + "steps": [], + "model": "claude-sonnet-4-20250514", + "turns": 3, + "latency_ms": 2500, + "prompt_tokens": 1500, + "completion_tokens": 300, + "total_cost": 0.0045, + "created_at": "2024-01-01T00:00:05Z", + } + + @patch("layerlens.resources.trace_evaluations.trace_evaluations.time.sleep") + def test_wait_returns_results_on_success(self, mock_sleep, trace_evals_resource, sample_result_data): + """wait_for_completion returns results when evaluation succeeds.""" + pending = { + "id": "te-123", + "trace_id": "t-1", + "judge_id": "j-1", + "status": "pending", + } + success = { + "id": "te-123", + "trace_id": "t-1", + "judge_id": "j-1", + "status": "success", + } + + trace_evals_resource._get.side_effect = [ + pending, # first poll → pending + success, # second poll → success + sample_result_data, # get_results call + ] + + result = trace_evals_resource.wait_for_completion("te-123", interval_seconds=1) + + assert isinstance(result, TraceEvaluationResultsResponse) + assert result.score == 0.85 + assert result.passed is True + assert mock_sleep.call_count == 1 + + @patch("layerlens.resources.trace_evaluations.trace_evaluations.time.sleep") + def test_wait_returns_none_on_failure(self, mock_sleep, trace_evals_resource): + """wait_for_completion returns None when evaluation fails (no results).""" + failure = { + "id": "te-123", + "trace_id": "t-1", + "judge_id": "j-1", + "status": "failure", + } + + mock_response = Mock() + mock_response.status_code = 404 + mock_response.headers = {} + + trace_evals_resource._get.side_effect = [ + failure, # first poll → failure + NotFoundError("Not found", response=mock_response, body=None), # get_results → 404 + ] + + result = trace_evals_resource.wait_for_completion("te-123") + + assert result is None + assert mock_sleep.call_count == 0 + + @patch("layerlens.resources.trace_evaluations.trace_evaluations.time.time") + @patch("layerlens.resources.trace_evaluations.trace_evaluations.time.sleep") + def test_wait_raises_timeout(self, _mock_sleep, mock_time, trace_evals_resource): + """wait_for_completion raises TimeoutError when timeout exceeded.""" + mock_time.side_effect = [ + 0, + 0, + 301, + ] # start, first check ok, second check exceeds 300s + + pending = { + "id": "te-123", + "trace_id": "t-1", + "judge_id": "j-1", + "status": "pending", + } + trace_evals_resource._get.return_value = pending + + with pytest.raises(TimeoutError, match="did not complete within 300 seconds"): + trace_evals_resource.wait_for_completion("te-123", timeout_seconds=300) + + @patch("layerlens.resources.trace_evaluations.trace_evaluations.time.sleep") + def test_wait_polls_through_in_progress(self, mock_sleep, trace_evals_resource, sample_result_data): + """wait_for_completion polls through pending and in_progress states.""" + pending = { + "id": "te-123", + "trace_id": "t-1", + "judge_id": "j-1", + "status": "pending", + } + in_progress = { + "id": "te-123", + "trace_id": "t-1", + "judge_id": "j-1", + "status": "in_progress", + } + success = { + "id": "te-123", + "trace_id": "t-1", + "judge_id": "j-1", + "status": "success", + } + + trace_evals_resource._get.side_effect = [ + pending, + in_progress, + success, + sample_result_data, + ] + + result = trace_evals_resource.wait_for_completion("te-123", interval_seconds=1) + + assert isinstance(result, TraceEvaluationResultsResponse) + assert mock_sleep.call_count == 2 + + @patch("layerlens.resources.trace_evaluations.trace_evaluations.time.sleep") + def test_wait_no_timeout_when_none(self, _mock_sleep, trace_evals_resource, sample_result_data): + """wait_for_completion runs indefinitely when timeout_seconds=None.""" + success = { + "id": "te-123", + "trace_id": "t-1", + "judge_id": "j-1", + "status": "success", + } + + trace_evals_resource._get.side_effect = [ + success, + sample_result_data, + ] + + result = trace_evals_resource.wait_for_completion("te-123", timeout_seconds=None) + + assert isinstance(result, TraceEvaluationResultsResponse) diff --git a/tests/test_base_client.py b/tests/test_base_client.py index 179d23d..2eb6ca6 100644 --- a/tests/test_base_client.py +++ b/tests/test_base_client.py @@ -5,7 +5,7 @@ import pytest from layerlens import _exceptions -from layerlens._base_client import BaseClient +from layerlens._base_client import MAX_RETRIES, BaseClient @dataclass @@ -232,3 +232,145 @@ def test_make_status_error_not_implemented(self, client): with pytest.raises(NotImplementedError): client._make_status_error("test", body=None, response=mock_response) + + def test_default_max_retries(self): + """BaseClient defaults to MAX_RETRIES.""" + client = BaseClient(base_url="https://api.test.com") + + assert client._max_retries == MAX_RETRIES + + def test_custom_max_retries(self): + """BaseClient accepts custom max_retries.""" + client = BaseClient(base_url="https://api.test.com", max_retries=5) + + assert client._max_retries == 5 + + def test_zero_max_retries_disables_retries(self): + """max_retries=0 disables automatic retries.""" + client = BaseClient(base_url="https://api.test.com", max_retries=0) + + assert client._max_retries == 0 + + @patch("layerlens._base_client.time.sleep") + @patch("httpx.Client.request") + def test_retries_on_429(self, mock_request, mock_sleep, client): + """Client retries on 429 and succeeds on subsequent attempt.""" + rate_limited = Mock(spec=httpx.Response) + rate_limited.status_code = 429 + rate_limited.headers = {} + + success = Mock(spec=httpx.Response) + success.status_code = 200 + success.raise_for_status.return_value = None + success.json.return_value = {"name": "ok", "value": 1} + + mock_request.side_effect = [rate_limited, success] + + result = client._request_cast("GET", "/test", cast_to=ResponseModel) + + assert isinstance(result, ResponseModel) + assert mock_request.call_count == 2 + assert mock_sleep.call_count == 1 + + @patch("layerlens._base_client.time.sleep") + @patch("httpx.Client.request") + def test_retries_respect_retry_after_header(self, mock_request, mock_sleep, client): + """Client uses Retry-After header value for sleep duration.""" + rate_limited = Mock(spec=httpx.Response) + rate_limited.status_code = 429 + rate_limited.headers = {"retry-after": "2"} + + success = Mock(spec=httpx.Response) + success.status_code = 200 + success.raise_for_status.return_value = None + + mock_request.side_effect = [rate_limited, success] + + client._request_cast("GET", "/test") + + mock_sleep.assert_called_once_with(2.0) + + @patch("layerlens._base_client.time.sleep") + @patch("httpx.Client.request") + def test_retries_exhaust_then_raise(self, mock_request, _mock_sleep): + """Client raises after exhausting all retries.""" + client = BaseClient(base_url="https://api.test.com", max_retries=1) + + rate_limited = Mock(spec=httpx.Response) + rate_limited.status_code = 429 + rate_limited.headers = {} + rate_limited.text = '{"message": "Too Many Requests"}' + rate_limited.raise_for_status.side_effect = httpx.HTTPStatusError("429", request=Mock(), response=rate_limited) + + mock_request.return_value = rate_limited + + with patch.object(client, "_make_status_error_from_response") as mock_make_error: + mock_make_error.side_effect = _exceptions.RateLimitError("Rate limited", response=rate_limited, body=None) + + with pytest.raises(_exceptions.RateLimitError): + client._request_cast("GET", "/test") + + # 1 initial + 1 retry = 2 calls + assert mock_request.call_count == 2 + + @patch("layerlens._base_client.time.sleep") + @patch("httpx.Client.request") + def test_no_retries_when_max_retries_zero(self, mock_request, _mock_sleep): + """max_retries=0 means no retries at all.""" + client = BaseClient(base_url="https://api.test.com", max_retries=0) + + rate_limited = Mock(spec=httpx.Response) + rate_limited.status_code = 429 + rate_limited.headers = {} + rate_limited.text = '{"message": "Too Many Requests"}' + rate_limited.raise_for_status.side_effect = httpx.HTTPStatusError("429", request=Mock(), response=rate_limited) + + mock_request.return_value = rate_limited + + with patch.object(client, "_make_status_error_from_response") as mock_make_error: + mock_make_error.side_effect = _exceptions.RateLimitError("Rate limited", response=rate_limited, body=None) + + with pytest.raises(_exceptions.RateLimitError): + client._request_cast("GET", "/test") + + assert mock_request.call_count == 1 + + @patch("layerlens._base_client.time.sleep") + @patch("httpx.Client.request") + def test_retries_on_500(self, mock_request, mock_sleep, client): + """Client retries on 500 server errors.""" + server_error = Mock(spec=httpx.Response) + server_error.status_code = 500 + server_error.headers = {} + + success = Mock(spec=httpx.Response) + success.status_code = 200 + success.raise_for_status.return_value = None + + mock_request.side_effect = [server_error, success] + + client._request_cast("GET", "/test") + + assert mock_request.call_count == 2 + assert mock_sleep.call_count == 1 + + @patch("layerlens._base_client.time.sleep") + @patch("httpx.Client.request") + def test_custom_max_retries_allows_more_attempts(self, mock_request, mock_sleep): + """Custom max_retries=4 allows up to 4 retries.""" + client = BaseClient(base_url="https://api.test.com", max_retries=4) + + server_error = Mock(spec=httpx.Response) + server_error.status_code = 502 + server_error.headers = {} + + success = Mock(spec=httpx.Response) + success.status_code = 200 + success.raise_for_status.return_value = None + + mock_request.side_effect = [server_error, server_error, server_error, success] + + client._request_cast("GET", "/test") + + assert mock_request.call_count == 4 + assert mock_sleep.call_count == 3