60 changes: 39 additions & 21 deletions sentry_sdk/integrations/openai.py
@@ -50,7 +50,7 @@
     from sentry_sdk.tracing import Span
     from sentry_sdk._types import TextPart
 
-    from openai.types.responses import ResponseInputParam
+    from openai.types.responses import ResponseInputParam, SequenceNotStr
     from openai import Omit
 
 try:
@@ -220,20 +220,6 @@ def _calculate_token_usage(
     )
 
 
-def _get_input_messages(
-    kwargs: "dict[str, Any]",
-) -> "Optional[Union[Iterable[Any], list[str]]]":
-    # Input messages (the prompt or data sent to the model)
-    messages = kwargs.get("messages")
-    if messages is None:
-        messages = kwargs.get("input")
-
-    if isinstance(messages, str):
-        messages = [messages]
-
-    return messages
-
-
 def _commmon_set_input_data(
     span: "Span",
     kwargs: "dict[str, Any]",
@@ -413,15 +399,47 @@ def _set_embeddings_input_data(
     kwargs: "dict[str, Any]",
     integration: "OpenAIIntegration",
 ) -> None:
-    messages = _get_input_messages(kwargs)
+    messages: "Union[str, SequenceNotStr[str], Iterable[int], Iterable[Iterable[int]]]" = kwargs.get(
+        "input"
+    )
 
     if (
-        messages is not None
-        and len(messages) > 0  # type: ignore
-        and should_send_default_pii()
-        and integration.include_prompts
+        not should_send_default_pii()
+        or not integration.include_prompts
+        or messages is None
     ):
-        normalized_messages = normalize_message_roles(messages)  # type: ignore
+        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
+        _commmon_set_input_data(span, kwargs)
+
+        return
+
+    if isinstance(messages, str):
+        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
+        _commmon_set_input_data(span, kwargs)
+
+        normalized_messages = normalize_message_roles([messages])  # type: ignore
+        scope = sentry_sdk.get_current_scope()
+        messages_data = truncate_and_annotate_embedding_inputs(
+            normalized_messages, span, scope
+        )
+        if messages_data is not None:
+            set_data_normalized(
+                span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, messages_data, unpack=False
+            )
+
+        return
+
+    # dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197
+    if not isinstance(messages, Iterable) or isinstance(messages, dict):
+        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
+        _commmon_set_input_data(span, kwargs)
+        return
+
+    messages = list(messages)
+    kwargs["input"] = messages
+
+    if len(messages) > 0:
+        normalized_messages = normalize_message_roles(messages)
        scope = sentry_sdk.get_current_scope()
        messages_data = truncate_and_annotate_embedding_inputs(
            normalized_messages, span, scope
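One detail in the rewritten _set_embeddings_input_data above is worth spelling out: an iterator passed as input can be consumed only once, so the integration materializes it with list(messages) and then writes the list back into kwargs["input"], ensuring the actual API call still receives the data after the span has read it. A minimal standalone sketch of that pattern (plain Python with illustrative names, not the SDK code):

    from collections.abc import Iterable

    def materialize_input(kwargs):
        # Sketch of the materialize-and-write-back idea; names are hypothetical.
        messages = kwargs.get("input")
        if messages is None or isinstance(messages, (str, dict)):
            return messages  # strings and dicts take separate branches in the real code
        if isinstance(messages, Iterable):
            messages = list(messages)   # consume a one-shot iterator exactly once
            kwargs["input"] = messages  # write back so the API call still sees the data
        return messages

    kwargs = {"input": iter([[5, 8, 13], [21, 34, 55]])}
    captured = materialize_input(kwargs)
    assert captured == [[5, 8, 13], [21, 34, 55]]
    assert kwargs["input"] == captured  # the exhausted iterator was replaced by a list

Without the write-back, the span would capture the tokens but the OpenAI client would receive an already-exhausted iterator.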
229 changes: 219 additions & 10 deletions tests/integrations/openai/test_openai.py
@@ -930,9 +930,13 @@
 
 @pytest.mark.parametrize(
     "send_default_pii, include_prompts",
-    [(True, True), (True, False), (False, True), (False, False)],
+    [
+        (True, False),
+        (False, True),
+        (False, False),
+    ],
 )
-def test_embeddings_create(
+def test_embeddings_create_no_pii(
     sentry_init, capture_events, send_default_pii, include_prompts
 ):
     sentry_init(
@@ -966,10 +970,109 @@
     assert tx["type"] == "transaction"
     span = tx["spans"][0]
     assert span["op"] == "gen_ai.embeddings"
-    if send_default_pii and include_prompts:
-        assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+
+    assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
 
     assert span["data"]["gen_ai.usage.input_tokens"] == 20
     assert span["data"]["gen_ai.usage.total_tokens"] == 30
 
 
+@pytest.mark.parametrize(
+    "input",
+    [
+        pytest.param(
+            "hello",
+            id="string",
+        ),
+        pytest.param(
+            ["First text", "Second text", "Third text"],
+            id="string_sequence",
+        ),
+        pytest.param(
+            iter(["First text", "Second text", "Third text"]),
+            id="string_iterable",
+        ),
+        pytest.param(
+            [5, 8, 13, 21, 34],
+            id="tokens",
+        ),
+        pytest.param(
+            iter(
+                [5, 8, 13, 21, 34],
+            ),
+            id="token_iterable",
+        ),
+        pytest.param(
+            [
+                [5, 8, 13, 21, 34],
+                [8, 13, 21, 34, 55],
+            ],
+            id="tokens_sequence",
+        ),
+        pytest.param(
+            iter(
+                [
+                    [5, 8, 13, 21, 34],
+                    [8, 13, 21, 34, 55],
+                ]
+            ),
+            id="tokens_sequence_iterable",
+        ),
+    ],
+)
+def test_embeddings_create(sentry_init, capture_events, input, request):
+    sentry_init(
+        integrations=[OpenAIIntegration(include_prompts=True)],
+        traces_sample_rate=1.0,
+        send_default_pii=True,
+    )
+    events = capture_events()
+
+    client = OpenAI(api_key="z")
+
+    returned_embedding = CreateEmbeddingResponse(
+        data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
+        model="some-model",
+        object="list",
+        usage=EmbeddingTokenUsage(
+            prompt_tokens=20,
+            total_tokens=30,
+        ),
+    )
+
+    client.embeddings._post = mock.Mock(return_value=returned_embedding)
+    with start_transaction(name="openai tx"):
+        response = client.embeddings.create(input=input, model="text-embedding-3-large")
+
+    assert len(response.data[0].embedding) == 3
+
+    tx = events[0]
+    assert tx["type"] == "transaction"
+    span = tx["spans"][0]
+    assert span["op"] == "gen_ai.embeddings"
+
+    param_id = request.node.callspec.id
+    if param_id == "string":
+        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"]
+    elif param_id == "string_sequence" or param_id == "string_iterable":
+        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+            "First text",
+            "Second text",
+            "Third text",
+        ]
+    elif param_id == "tokens" or param_id == "token_iterable":
+        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+            5,
+            8,
+            13,
+            21,
+            34,
+        ]
+    else:
+        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+            [5, 8, 13, 21, 34],
+            [8, 13, 21, 34, 55],
+        ]
+
+    assert span["data"]["gen_ai.usage.input_tokens"] == 20
+    assert span["data"]["gen_ai.usage.total_tokens"] == 30
@@ -978,9 +1081,13 @@
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "send_default_pii, include_prompts",
-    [(True, True), (True, False), (False, True), (False, False)],
+    [
+        (True, False),
+        (False, True),
+        (False, False),
+    ],
 )
-async def test_embeddings_create_async(
+async def test_embeddings_create_async_no_pii(
     sentry_init, capture_events, send_default_pii, include_prompts
 ):
     sentry_init(
@@ -1014,10 +1121,112 @@
     assert tx["type"] == "transaction"
     span = tx["spans"][0]
     assert span["op"] == "gen_ai.embeddings"
-    if send_default_pii and include_prompts:
-        assert "hello" in span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
+
+    assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
 
     assert span["data"]["gen_ai.usage.input_tokens"] == 20
     assert span["data"]["gen_ai.usage.total_tokens"] == 30
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "input",
+    [
+        pytest.param(
+            "hello",
+            id="string",
+        ),
+        pytest.param(
+            ["First text", "Second text", "Third text"],
+            id="string_sequence",
+        ),
+        pytest.param(
+            iter(["First text", "Second text", "Third text"]),
+            id="string_iterable",

[Check warning on line 1145 in tests/integrations/openai/test_openai.py (GitHub Actions / warden: find-bugs)]

Iterator created at test collection time will be exhausted before test runs

The test uses `iter([...])` directly in `pytest.param()`, which creates the iterator at module load/collection time, not at test execution time. In pytest, parameters are evaluated once during collection. If pytest introspects the parameters (e.g., for --collect-only, verbose output, or test reporting), or if tests are re-run, the iterator may be partially or fully consumed before the actual test executes. This makes the test unreliable and may cause false positives/negatives. (A sketch of one way to avoid this follows the diff below.)
+        ),
+        pytest.param(
+            [5, 8, 13, 21, 34],
+            id="tokens",
+        ),
+        pytest.param(
+            iter(
+                [5, 8, 13, 21, 34],
+            ),
+            id="token_iterable",
+        ),
+        pytest.param(
+            [
+                [5, 8, 13, 21, 34],
+                [8, 13, 21, 34, 55],
+            ],
+            id="tokens_sequence",
+        ),
+        pytest.param(
+            iter(
+                [
+                    [5, 8, 13, 21, 34],
+                    [8, 13, 21, 34, 55],
+                ]
+            ),
+            id="tokens_sequence_iterable",
+        ),
+    ],
+)
+async def test_embeddings_create_async(sentry_init, capture_events, input, request):
+    sentry_init(
+        integrations=[OpenAIIntegration(include_prompts=True)],
+        traces_sample_rate=1.0,
+        send_default_pii=True,
+    )
+    events = capture_events()
+
+    client = AsyncOpenAI(api_key="z")
+
+    returned_embedding = CreateEmbeddingResponse(
+        data=[Embedding(object="embedding", index=0, embedding=[1.0, 2.0, 3.0])],
+        model="some-model",
+        object="list",
+        usage=EmbeddingTokenUsage(
+            prompt_tokens=20,
+            total_tokens=30,
+        ),
+    )
+
+    client.embeddings._post = AsyncMock(return_value=returned_embedding)
+    with start_transaction(name="openai tx"):
+        response = await client.embeddings.create(
+            input=input, model="text-embedding-3-large"
+        )
+
+    assert len(response.data[0].embedding) == 3
+
+    tx = events[0]
+    assert tx["type"] == "transaction"
+    span = tx["spans"][0]
+    assert span["op"] == "gen_ai.embeddings"
+
+    param_id = request.node.callspec.id
+    if param_id == "string":
+        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == ["hello"]
+    elif param_id == "string_sequence" or param_id == "string_iterable":
+        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+            "First text",
+            "Second text",
+            "Third text",
+        ]
+    elif param_id == "tokens" or param_id == "token_iterable":
+        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+            5,
+            8,
+            13,
+            21,
+            34,
+        ]
+    else:
+        assert json.loads(span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]) == [
+            [5, 8, 13, 21, 34],
+            [8, 13, 21, 34, 55],
+        ]
+
+    assert span["data"]["gen_ai.usage.input_tokens"] == 20
+    assert span["data"]["gen_ai.usage.total_tokens"] == 30
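The warden warning above points at a general pytest pitfall: arguments to pytest.param() are evaluated once, at collection time, so iter(...) produces a single iterator shared across collection, reporting, and any re-runs. A sketch of one way around it (illustrative only, not part of this PR): parametrize with zero-argument factories and build the value inside the test body, so each execution gets a fresh iterator.

    import pytest

    @pytest.mark.parametrize(
        "make_input",
        [
            pytest.param(lambda: "hello", id="string"),
            pytest.param(
                lambda: iter(["First text", "Second text", "Third text"]),
                id="string_iterable",
            ),
            pytest.param(
                lambda: iter([[5, 8, 13, 21, 34], [8, 13, 21, 34, 55]]),
                id="tokens_sequence_iterable",
            ),
        ],
    )
    def test_fresh_iterators(make_input):
        # The factory is called at test run time, so collection-time
        # introspection and re-runs can no longer exhaust a shared iterator.
        value = make_input()
        materialized = [value] if isinstance(value, str) else list(value)
        assert len(materialized) > 0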