Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions examples/get_benchmark_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
"""Fetch prompts from a benchmark (custom or public)."""

from layerlens import Stratix


def main():
client = Stratix()

# Find a benchmark with prompts
benchmarks = client.benchmarks.get()
benchmark = next((b for b in benchmarks if b.prompt_count and b.prompt_count > 0), None)
if benchmark is None:
print("No benchmarks with prompts found.")
return

print(f"Benchmark: {benchmark.name} ({benchmark.key})")
print(f"Total prompts: {benchmark.prompt_count}\n")

# --- Get a single page of prompts
page = client.benchmarks.get_prompts(benchmark.id, page=1, page_size=5)
if page:
print(f"Page 1 ({len(page.prompts)} of {page.count}):")
for p in page.prompts:
inp = str(p.input)[:80]
print(f" [{p.id}] {inp}")

# --- Get all prompts (auto-paginated)
all_prompts = client.benchmarks.get_all_prompts(benchmark.id)
print(f"\nAll prompts fetched: {len(all_prompts)}")

# --- Search and sort
results = client.benchmarks.get_prompts(
benchmark.id,
search_field="truth",
search_value="the",
sort_by="id",
sort_order="asc",
page_size=3,
)
if results:
print(f"\nSearch results ({results.count} matches):")
for p in results.prompts:
print(f" [{p.id}] truth: {p.truth[:60]}")


if __name__ == "__main__":
main()
174 changes: 174 additions & 0 deletions src/layerlens/resources/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
from __future__ import annotations

import os
import math
import mimetypes
from typing import Any, Dict, List, Literal, Optional

import httpx

from ...models import (
Benchmark,
BenchmarkPrompt,
CustomBenchmark,
PublicBenchmark,
BenchmarksResponse,
BenchmarkPromptsData,
CreateBenchmarkResponse,
)
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._constants import DEFAULT_TIMEOUT

DEFAULT_PROMPTS_PAGE_SIZE = 100

MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50 MB


Expand Down Expand Up @@ -163,6 +168,99 @@ def remove(
new_ids = [b.id for b in current if b.id not in remove_set]
return self._patch_project_benchmarks(new_ids, timeout)

def get_prompts(
self,
benchmark_id: str,
*,
page: Optional[int] = None,
page_size: Optional[int] = None,
search_field: Optional[Literal["id", "input", "truth"]] = None,
search_value: Optional[str] = None,
sort_by: Optional[Literal["id", "input", "truth"]] = None,
sort_order: Optional[Literal["asc", "desc"]] = None,
timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT,
) -> Optional[BenchmarkPromptsData]:
"""Fetch a page of prompts for a benchmark.

Uses the org-scoped endpoint:
GET /organizations/{org}/projects/{proj}/benchmarks/{id}/prompts

Args:
benchmark_id: The benchmark / dataset ID.
page: Page number (1-based).
page_size: Number of prompts per page.
search_field: Field to search in.
search_value: Value to search for.
sort_by: Field to sort by.
sort_order: Sort direction.
timeout: Request timeout override.

Returns:
BenchmarkPromptsData with prompts list and count, or None on failure.
"""
params: Dict[str, str] = {}
if page is not None:
params["page"] = str(page)
if page_size is not None:
params["page_size"] = str(page_size)
if search_field:
params["search"] = search_field
if search_value:
params["search_value"] = search_value
if sort_by:
params["sort_by"] = sort_by
if sort_order:
params["sort_order"] = sort_order

url = f"/organizations/{self._client.organization_id}/projects/{self._client.project_id}/benchmarks/{benchmark_id}/prompts"
resp = self._get(
url,
params=params,
timeout=timeout,
cast_to=dict,
)

if not isinstance(resp, dict):
return None

# Unwrap {"status": ..., "data": {...}} envelope if present
if "data" in resp and "status" in resp:
resp = resp["data"]

return BenchmarkPromptsData.model_validate(resp)

def get_all_prompts(
self,
benchmark_id: str,
*,
timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT,
) -> List[BenchmarkPrompt]:
"""Fetch all prompts for a benchmark, automatically paginating."""
all_prompts: List[BenchmarkPrompt] = []
page = 1
page_size = DEFAULT_PROMPTS_PAGE_SIZE

while True:
resp = self.get_prompts(
benchmark_id,
page=page,
page_size=page_size,
timeout=timeout,
)
if resp is None or not resp.prompts:
break

all_prompts.extend(resp.prompts)

total_count = resp.count
total_pages = math.ceil(total_count / page_size) if total_count > 0 else 0
if page >= total_pages:
break

page += 1

return all_prompts

def _patch_project_benchmarks(
self,
dataset_ids: List[str],
Expand Down Expand Up @@ -452,6 +550,82 @@ async def remove(
new_ids = [b.id for b in current if b.id not in remove_set]
return await self._patch_project_benchmarks(new_ids, timeout)

async def get_prompts(
self,
benchmark_id: str,
*,
page: Optional[int] = None,
page_size: Optional[int] = None,
search_field: Optional[Literal["id", "input", "truth"]] = None,
search_value: Optional[str] = None,
sort_by: Optional[Literal["id", "input", "truth"]] = None,
sort_order: Optional[Literal["asc", "desc"]] = None,
timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT,
) -> Optional[BenchmarkPromptsData]:
"""Fetch a page of prompts for a benchmark."""
params: Dict[str, str] = {}
if page is not None:
params["page"] = str(page)
if page_size is not None:
params["page_size"] = str(page_size)
if search_field:
params["search"] = search_field
if search_value:
params["search_value"] = search_value
if sort_by:
params["sort_by"] = sort_by
if sort_order:
params["sort_order"] = sort_order

url = f"/organizations/{self._client.organization_id}/projects/{self._client.project_id}/benchmarks/{benchmark_id}/prompts"
resp = await self._get(
url,
params=params,
timeout=timeout,
cast_to=dict,
)

if not isinstance(resp, dict):
return None

# Unwrap {"status": ..., "data": {...}} envelope if present
if "data" in resp and "status" in resp:
resp = resp["data"]

return BenchmarkPromptsData.model_validate(resp)

async def get_all_prompts(
self,
benchmark_id: str,
*,
timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT,
) -> List[BenchmarkPrompt]:
"""Fetch all prompts for a benchmark, automatically paginating."""
all_prompts: List[BenchmarkPrompt] = []
page = 1
page_size = DEFAULT_PROMPTS_PAGE_SIZE

while True:
resp = await self.get_prompts(
benchmark_id,
page=page,
page_size=page_size,
timeout=timeout,
)
if resp is None or not resp.prompts:
break

all_prompts.extend(resp.prompts)

total_count = resp.count
total_pages = math.ceil(total_count / page_size) if total_count > 0 else 0
if page >= total_pages:
break

page += 1

return all_prompts

async def _patch_project_benchmarks(
self,
dataset_ids: List[str],
Expand Down
Loading
Loading