Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 129 additions & 9 deletions src/harness_sdk/instrumentation/litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
- litellm.completion / litellm.acompletion
- litellm.embedding / litellm.aembedding

Registers a ``TraceableLiteLLMOpenTelemetry`` callback (subclass of LiteLLM's
``OpenTelemetry``) and wraps the public entry points so evaluation runs on an
active span before the provider call. LiteLLM's OTEL callback enriches that
span on success when it is the active parent context.
Wraps the public entry points so evaluation runs on an active span before the
provider call. The wrapper enriches that span with response metadata before it
ends.

Optional: ``pip install harness-sdk[litellm]``
"""
Expand Down Expand Up @@ -43,6 +42,19 @@
("aembedding", True),
)

_PROVIDER_NAME_MAP = {
"azure": "azure.ai.openai",
"azure_ai": "azure.ai.openai",
"azure_ai_openai": "azure.ai.openai",
"azureopenai": "azure.ai.openai",
"bedrock": "aws.bedrock",
"bedrock_converse": "aws.bedrock",
"gemini": "gcp.gemini",
"google": "gcp.gemini",
"vertex_ai": "gcp.vertex_ai",
"vertexai": "gcp.vertex_ai",
}


def _evaluate_span(span: Any) -> None:
"""Run Traceable policy evaluation against the live span; raise if blocked."""
Expand Down Expand Up @@ -111,6 +123,11 @@ def _resolve_provider(model: Optional[str], kwargs: dict[str, Any]) -> str:
return "Unknown"


def _canonical_provider_name(provider: str) -> str:
    """Return the provider name to report on the span for ``provider``.

    A falsy ``provider`` is treated as ``"unknown"``. The value is stripped,
    lowercased and has ``-`` replaced with ``_`` before being looked up in
    ``_PROVIDER_NAME_MAP``; a name with no entry is returned in that
    normalized form.
    """
    raw_name = provider if provider else "unknown"
    lookup_key = raw_name.strip().lower().replace("-", "_")
    if lookup_key in _PROVIDER_NAME_MAP:
        return _PROVIDER_NAME_MAP[lookup_key]
    return lookup_key


def _operation_name(func_name: str) -> str:
if func_name in ("embedding", "aembedding"):
return "embeddings"
Expand Down Expand Up @@ -138,7 +155,9 @@ def _set_pre_call_request_attributes(
otel_logger.safe_set_attribute(
span, "gen_ai.operation.name", _operation_name(pre_call.call_type)
)
otel_logger.safe_set_attribute(span, "gen_ai.system", provider)
otel_logger.safe_set_attribute(
span, "gen_ai.provider.name", _canonical_provider_name(provider)
)
otel_logger.safe_set_attribute(span, "gen_ai.framework", "litellm")
otel_logger.safe_set_attribute(
span,
Expand Down Expand Up @@ -182,6 +201,105 @@ def _set_pre_call_request_attributes(
logger.debug("LiteLLM: failed to set input attributes on span: %s", err)


def _get_value(obj: Any, key: str) -> Any:
if obj is None:
return None
if isinstance(obj, dict):
return obj.get(key)
return getattr(obj, key, None)


def _get_usage(response: Any) -> Any:
    """Return the ``usage`` payload of a response, or ``None`` if absent.

    ``_get_value`` already covers both dict and attribute-style responses, so
    a separate dict fallback would only repeat the same lookup and return the
    same ``None``.
    """
    return _get_value(response, "usage")


def _set_if_present(otel_logger: Any, span: Any, key: str, value: Any) -> None:
if value is not None:
otel_logger.safe_set_attribute(span, key, value)


def _get_choices(response: Any) -> list[Any]:
    """Return the response's ``choices`` as a new list.

    Yields an empty list when the response carries no ``choices`` value.
    """
    raw_choices = _get_value(response, "choices")
    return [] if raw_choices is None else list(raw_choices)


def _get_finish_reasons(response: Any) -> list[str]:
    """Collect the truthy ``finish_reason`` of each choice, as strings.

    Choices whose ``finish_reason`` is missing or falsy are skipped; order
    follows the order of the response's choices.
    """
    reasons = (
        _get_value(choice, "finish_reason") for choice in _get_choices(response)
    )
    return [str(reason) for reason in reasons if reason]


def _first_not_none(*candidates: Any) -> Any:
    """Return the first candidate that is not ``None``, or ``None`` if all are."""
    return next((value for value in candidates if value is not None), None)


def _set_response_attributes(otel_logger: Any, span: Any, response: Any) -> None:
    """Copy LiteLLM response metadata before the wrapper-owned span ends.

    Writes the response model, id and finish reasons, then the token usage
    counters, onto ``span`` through ``otel_logger.safe_set_attribute``. Any
    attribute whose source value is ``None`` is left unset, and nothing
    usage-related is written when the response has no ``usage``.

    Args:
        otel_logger: Object exposing ``safe_set_attribute(span, key, value)``.
        span: The span to enrich.
        response: Dict or attribute-style response returned by the wrapped
            LiteLLM call.
    """
    _set_if_present(
        otel_logger, span, "gen_ai.response.model", _get_value(response, "model")
    )
    _set_if_present(otel_logger, span, "gen_ai.response.id", _get_value(response, "id"))

    finish_reasons = _get_finish_reasons(response)
    if finish_reasons:
        otel_logger.safe_set_attribute(
            span, "gen_ai.response.finish_reasons", finish_reasons
        )

    usage = _get_usage(response)
    if usage is None:
        return

    prompt_details = _get_value(usage, "prompt_tokens_details")
    completion_details = _get_value(usage, "completion_tokens_details")

    # Each attribute takes the first source that is not None. An explicit
    # None check (rather than ``or``) keeps a reported zero instead of
    # replacing it with the fallback, or dropping it when the fallback is
    # missing, and matches how the input/output token fallbacks behave.
    usage_attributes = (
        (
            "gen_ai.usage.input_tokens",
            _first_not_none(
                _get_value(usage, "prompt_tokens"),
                _get_value(usage, "input_tokens"),
            ),
        ),
        (
            "gen_ai.usage.output_tokens",
            _first_not_none(
                _get_value(usage, "completion_tokens"),
                _get_value(usage, "output_tokens"),
            ),
        ),
        ("gen_ai.usage.total_tokens", _get_value(usage, "total_tokens")),
        (
            "gen_ai.usage.cache_read.input_tokens",
            _first_not_none(
                _get_value(usage, "cache_read_input_tokens"),
                _get_value(prompt_details, "cached_tokens"),
            ),
        ),
        (
            "gen_ai.usage.cache_creation.input_tokens",
            _first_not_none(
                _get_value(usage, "cache_creation_input_tokens"),
                _get_value(prompt_details, "cache_creation_tokens"),
            ),
        ),
        (
            "gen_ai.usage.reasoning.output_tokens",
            _get_value(completion_details, "reasoning_tokens"),
        ),
    )
    for attribute_key, attribute_value in usage_attributes:
        _set_if_present(otel_logger, span, attribute_key, attribute_value)


def _build_traceable_otel_class() -> type:
from litellm.integrations.opentelemetry import ( # pylint: disable=import-outside-toplevel
OpenTelemetry,
Expand Down Expand Up @@ -304,7 +422,9 @@ def _sync_wrapper(
span = _start_evaluated_span(otel_logger, func_name, args, kwargs)
token = _activate_span(span)
try:
return wrapped(*args, **kwargs)
response = wrapped(*args, **kwargs)
_set_response_attributes(otel_logger, span, response)
return response
except Exception as exc: # pylint: disable=broad-except
span.record_exception(exc)
span.set_status(Status(StatusCode.ERROR, str(exc)))
Expand All @@ -324,7 +444,9 @@ async def _async_wrapper(
span = _start_evaluated_span(otel_logger, func_name, args, kwargs)
token = _activate_span(span)
try:
return await wrapped(*args, **kwargs)
response = await wrapped(*args, **kwargs)
_set_response_attributes(otel_logger, span, response)
return response
except Exception as exc: # pylint: disable=broad-except
span.record_exception(exc)
span.set_status(Status(StatusCode.ERROR, str(exc)))
Expand Down Expand Up @@ -353,8 +475,6 @@ def instrument(self, **_kwargs: Any) -> None:
try:
import litellm # pylint: disable=import-outside-toplevel

otel_logger = _get_otel_logger()
_register_otel_callback(otel_logger)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you test this? I guess this callback registration is required for the instrumentation.

main_mod = __import__(_LITELLM_MAIN, fromlist=["*"])
for func_name, is_async in _WRAPPED_FUNCTIONS:
wrapt.wrap_function_wrapper(
Expand Down
48 changes: 41 additions & 7 deletions test/instrumentation/litellm/litellm_instrumentation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,16 @@ def _fake_model_response(*_args, **_kwargs):
}
],
model="gpt-4o-mini",
usage={"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8},
usage={
"prompt_tokens": 3,
"completion_tokens": 5,
"total_tokens": 8,
"prompt_tokens_details": {
"cached_tokens": 1,
"cache_creation_tokens": 2,
},
"completion_tokens_details": {"reasoning_tokens": 1},
},
)


Expand Down Expand Up @@ -66,11 +75,21 @@ def test_litellm_completion_span_has_gen_ai_attributes(agent, exporter, litellm_
spans = exporter.get_finished_spans()
exporter.clear()
assert len(spans) >= 1
attrs = spans[0].attributes
attrs = _request_span(spans).attributes
assert attrs.get("gen_ai.request.model") == "gpt-4o-mini"
assert attrs.get("gen_ai.operation.name") == "chat"
assert attrs.get("gen_ai.system") == "openai"
assert attrs.get("gen_ai.provider.name") == "openai"
assert "gen_ai.system" not in attrs
assert attrs.get("gen_ai.framework") == "litellm"
assert attrs.get("gen_ai.response.model") == "gpt-4o-mini"
assert attrs.get("gen_ai.response.id") == "chatcmpl-test"
assert attrs.get("gen_ai.response.finish_reasons") == "['stop']"
assert attrs.get("gen_ai.usage.input_tokens") == 3
assert attrs.get("gen_ai.usage.output_tokens") == 5
assert attrs.get("gen_ai.usage.total_tokens") == 8
assert attrs.get("gen_ai.usage.cache_read.input_tokens") == 1
assert attrs.get("gen_ai.usage.cache_creation.input_tokens") == 2
assert attrs.get("gen_ai.usage.reasoning.output_tokens") == 1


def test_litellm_evaluate_blocks_before_wrapped(agent, exporter, litellm_instrumentor): # pylint: disable=unused-argument
Expand Down Expand Up @@ -105,10 +124,15 @@ def test_litellm_embedding_span_has_gen_ai_attributes(agent, exporter, litellm_i
spans = exporter.get_finished_spans()
exporter.clear()
assert len(spans) >= 1
attrs = spans[0].attributes
attrs = _request_span(spans).attributes
assert attrs.get("gen_ai.request.model") == "text-embedding-3-small"
assert attrs.get("gen_ai.operation.name") == "embeddings"
assert attrs.get("gen_ai.provider.name") == "openai"
assert "gen_ai.system" not in attrs
assert attrs.get("gen_ai.framework") == "litellm"
assert attrs.get("gen_ai.response.model") == "text-embedding-3-small"
assert attrs.get("gen_ai.usage.input_tokens") == 4
assert attrs.get("gen_ai.usage.total_tokens") == 4


@pytest.mark.asyncio
Expand All @@ -126,7 +150,16 @@ async def _fake_async(*_args, **_kwargs):
spans = exporter.get_finished_spans()
exporter.clear()
assert len(spans) >= 1
assert spans[0].attributes.get("gen_ai.operation.name") == "chat"
attrs = _request_span(spans).attributes
assert attrs.get("gen_ai.operation.name") == "chat"
assert attrs.get("gen_ai.provider.name") == "openai"
assert "gen_ai.system" not in attrs
assert attrs.get("gen_ai.response.model") == "gpt-4o-mini"
assert attrs.get("gen_ai.response.id") == "chatcmpl-test"
assert attrs.get("gen_ai.response.finish_reasons") == "['stop']"
assert attrs.get("gen_ai.usage.input_tokens") == 3
assert attrs.get("gen_ai.usage.output_tokens") == 5
assert attrs.get("gen_ai.usage.total_tokens") == 8


def test_litellm_double_instrument_is_noop(agent, exporter, litellm_instrumentor): # pylint: disable=unused-argument
Expand Down Expand Up @@ -172,7 +205,7 @@ def counting_fake(*_a, **_k):
assert len(spans) == 0


def test_litellm_mock_response_with_otel_callback(agent, exporter, litellm_instrumentor): # pylint: disable=unused-argument
def test_litellm_mock_response_with_wrapper_enrichment(agent, exporter, litellm_instrumentor): # pylint: disable=unused-argument
litellm_instrumentor.instrument()
litellm.completion(
model="gpt-4o-mini",
Expand All @@ -186,5 +219,6 @@ def test_litellm_mock_response_with_otel_callback(agent, exporter, litellm_instr
attrs = _request_span(spans).attributes
assert attrs.get("gen_ai.request.model") == "gpt-4o-mini"
assert attrs.get("gen_ai.operation.name") == "chat"
assert attrs.get("gen_ai.system") == "openai"
assert attrs.get("gen_ai.provider.name") == "openai"
assert "gen_ai.system" not in attrs
assert attrs.get("gen_ai.framework") == "litellm"
Loading