Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions src/guidellm/backends/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def __init__(
self,
target: str,
model: str = "",
api_key: str | None = None,
bearer_token: str | None = None,
headers: dict[str, str] | None = None,
api_routes: dict[str, str] | None = None,
response_handlers: dict[str, Any] | None = None,
timeout: float = 60.0,
Expand All @@ -65,6 +68,9 @@ def __init__(

:param target: Base URL of the OpenAI-compatible server
:param model: Model identifier for generation requests
:param api_key: API key for authentication (used as Bearer token)
:param bearer_token: Bearer token for authentication (alternative to api_key)
:param headers: Additional headers to include in all requests
:param api_routes: Custom API endpoint routes mapping
:param response_handlers: Custom response handlers for different request types
:param timeout: Request timeout in seconds
Expand All @@ -79,6 +85,29 @@ def __init__(
self.target = target.rstrip("/").removesuffix("/v1")
self.model = model

# Build default headers with authentication
from guidellm.settings import settings
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't late import. This should be in import section at top.


self._default_headers: dict[str, str] = {}

# Merge headers from settings first (lowest priority)
if settings.openai.headers:
self._default_headers.update(settings.openai.headers)

# Add explicit headers parameter (medium priority)
if headers:
self._default_headers.update(headers)

# Resolve API key (highest priority): explicit param > settings
resolved_api_key = api_key or settings.openai.api_key
resolved_bearer_token = bearer_token or settings.openai.bearer_token
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you actually remove all references to bearer_token. I don't think we need both.


# Set Authorization header if we have credentials
if resolved_api_key:
self._default_headers["Authorization"] = f"Bearer {resolved_api_key}"
elif resolved_bearer_token:
self._default_headers["Authorization"] = f"Bearer {resolved_bearer_token}"

# Store configuration
self.api_routes = api_routes or {
"health": "health",
Expand Down Expand Up @@ -184,7 +213,7 @@ async def available_models(self) -> list[str]:
raise RuntimeError("Backend not started up for process.")

target = f"{self.target}/{self.api_routes['models']}"
response = await self._async_client.get(target)
response = await self._async_client.get(target, headers=self._default_headers)
response.raise_for_status()

return [item["id"] for item in response.json()["data"]]
Expand Down Expand Up @@ -245,13 +274,19 @@ async def resolve( # type: ignore[override]
request.request_type, handler_overrides=self.response_handlers
)

# Merge default headers with request-specific headers
merged_headers = {
**self._default_headers,
**(request.arguments.headers or {}),
}

if not request.arguments.stream:
request_info.timings.request_start = time.time()
response = await self._async_client.request(
request.arguments.method or "POST",
request_url,
params=request.arguments.params,
headers=request.arguments.headers,
headers=merged_headers,
json=request_json,
data=request_data,
files=request_files,
Expand All @@ -269,7 +304,7 @@ async def resolve( # type: ignore[override]
request.arguments.method or "POST",
request_url,
params=request.arguments.params,
headers=request.arguments.headers,
headers=merged_headers,
json=request_json,
data=request_data,
files=request_files,
Expand Down Expand Up @@ -331,4 +366,9 @@ def _resolve_validate_kwargs(
if "method" not in validate_kwargs:
validate_kwargs["method"] = "GET"

# Include default headers (with auth) in validation request
if self._default_headers:
existing_headers = validate_kwargs.get("headers", {})
validate_kwargs["headers"] = {**self._default_headers, **existing_headers}

return validate_kwargs