Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚩 Pre-existing: user-provided extra_headers dict is mutated in-place across chat() calls

In livekit-agents/livekit/agents/inference/llm.py:412-413, the _run() method does extra_headers = self._extra_kwargs.setdefault("extra_headers", {}) followed by extra_headers.update(get_inference_headers()). Since chat() at livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py:979-980 assigns extra["extra_headers"] = self._opts.extra_headers (a direct reference, not a copy), repeated chat() calls will accumulate inference headers into the user's original dict object stored in _opts. This is a pre-existing issue (not introduced by this PR) but is now more likely to be encountered since with_azure() users can now pass extra_headers.

(Refers to lines 979-980)

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ def with_azure(
top_p: NotGivenOr[float] = NOT_GIVEN,
verbosity: NotGivenOr[Verbosity] = NOT_GIVEN,
max_completion_tokens: NotGivenOr[int] = NOT_GIVEN,
store: NotGivenOr[bool] = NOT_GIVEN,
metadata: NotGivenOr[dict[str, str]] = NOT_GIVEN,
prompt_cache_retention: NotGivenOr[PromptCacheRetention] = NOT_GIVEN,
extra_body: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
extra_headers: NotGivenOr[dict[str, str]] = NOT_GIVEN,
extra_query: NotGivenOr[dict[str, str]] = NOT_GIVEN,
) -> LLM:
"""
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
Expand All @@ -220,6 +226,12 @@ def with_azure(
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
- `api_version` from `OPENAI_API_VERSION`
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`

The request-tuning arguments (`store`, `metadata`, `prompt_cache_retention`,
`extra_body`, `extra_headers`, `extra_query`, ...) are forwarded unchanged to the
underlying chat-completions request, mirroring ``LLM.__init__``. They default to
``NOT_GIVEN`` and are only sent when explicitly set. Use `extra_body` to pass
Azure-specific request fields such as ``data_sources`` (Azure OpenAI "On Your Data").
""" # noqa: E501

azure_client = openai.AsyncAzureOpenAI(
Expand Down Expand Up @@ -251,6 +263,12 @@ def with_azure(
top_p=top_p,
verbosity=verbosity,
max_completion_tokens=max_completion_tokens,
store=store,
Comment thread
devin-ai-integration[bot] marked this conversation as resolved.
metadata=metadata,
prompt_cache_retention=prompt_cache_retention,
extra_body=extra_body,
extra_headers=extra_headers,
extra_query=extra_query,
)
llm._owns_client = True
return llm
Expand Down Expand Up @@ -967,6 +985,9 @@ def chat(
if is_given(self._opts.metadata):
extra["metadata"] = self._opts.metadata

if is_given(self._opts.store):
extra["store"] = self._opts.store

if is_given(self._opts.user):
extra["user"] = self._opts.user

Expand Down
90 changes: 90 additions & 0 deletions tests/test_openai_with_azure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Hermetic unit tests for ``openai.LLM.with_azure`` request-parameter forwarding.

These construct the LLM object only and assert on its ``_opts`` — no network or Azure
credentials are required, so the module runs in the ``--unit`` gate.
"""

from __future__ import annotations

import pytest

from livekit.agents.llm import ChatContext
from livekit.agents.types import NOT_GIVEN
from livekit.plugins import openai

pytestmark = pytest.mark.unit

# Dummy Azure connection details. ``AsyncAzureOpenAI`` validates that these are present at
# construction time but does not open a connection, so any non-empty values work here.
_AZURE_ENDPOINT = "https://example.openai.azure.com"
_API_VERSION = "2024-10-21"
_API_KEY = "test-key"


def test_with_azure_forwards_request_params() -> None:
"""Params shared with ``LLM.__init__`` must reach ``_opts`` instead of being dropped."""
extra_body = {"data_sources": [{"type": "azure_search"}]}
extra_headers = {"x-ms-custom": "1"}
extra_query = {"foo": "bar"}
metadata = {"team": "voice"}

azure_llm = openai.LLM.with_azure(
azure_endpoint=_AZURE_ENDPOINT,
api_version=_API_VERSION,
api_key=_API_KEY,
azure_deployment="gpt-4o",
store=True,
metadata=metadata,
prompt_cache_retention="24h",
extra_body=extra_body,
extra_headers=extra_headers,
extra_query=extra_query,
)

opts = azure_llm._opts
assert opts.store is True
assert opts.metadata == metadata
assert opts.prompt_cache_retention == "24h"
assert opts.extra_body == extra_body
assert opts.extra_headers == extra_headers
assert opts.extra_query == extra_query


def test_with_azure_request_params_default_to_not_given() -> None:
"""When omitted, the forwarded params stay ``NOT_GIVEN`` so nothing extra is sent."""
azure_llm = openai.LLM.with_azure(
azure_endpoint=_AZURE_ENDPOINT,
api_version=_API_VERSION,
api_key=_API_KEY,
azure_deployment="gpt-4o",
)

opts = azure_llm._opts
assert opts.store is NOT_GIVEN
assert opts.metadata is NOT_GIVEN
assert opts.prompt_cache_retention is NOT_GIVEN
assert opts.extra_body is NOT_GIVEN
assert opts.extra_headers is NOT_GIVEN
assert opts.extra_query is NOT_GIVEN


@pytest.mark.concurrent
async def test_store_is_forwarded_to_chat_request() -> None:
"""``store`` set on the LLM must actually reach the chat-completions request kwargs."""
azure_llm = openai.LLM(api_key="test-key", store=True)
stream = azure_llm.chat(chat_ctx=ChatContext.empty())
try:
assert stream._extra_kwargs.get("store") is True
finally:
await stream.aclose()


@pytest.mark.concurrent
async def test_store_absent_from_chat_request_when_unset() -> None:
"""When ``store`` is not set, it must not be injected into the request kwargs."""
azure_llm = openai.LLM(api_key="test-key")
stream = azure_llm.chat(chat_ctx=ChatContext.empty())
try:
assert "store" not in stream._extra_kwargs
finally:
await stream.aclose()
Loading