Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed

- Capture Claude Agent SDK session IDs on agent, LLM, and tool spans, and
preserve active caller context so SDK traces attach to existing caller spans
instead of being forced to independent roots.

## Version 0.6.0 (2026-06-03)

There are no changelog entries for this release.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import time
from typing import Any, Dict, List, Optional

from opentelemetry import baggage
from opentelemetry import context as otel_context
from opentelemetry.instrumentation.claude_agent_sdk.utils import (
extract_usage_from_result_message,
Expand All @@ -27,7 +28,10 @@
from opentelemetry.trace import set_span_in_context
from opentelemetry.util.genai.extended_handler import (
ExtendedTelemetryHandler,
get_extended_telemetry_handler,
)
from opentelemetry.util.genai.extended_semconv.gen_ai_extended_attributes import (
GEN_AI_SESSION_ID,
GEN_AI_USER_ID,
)
from opentelemetry.util.genai.extended_types import (
ExecuteToolInvocation,
Expand All @@ -46,29 +50,72 @@

logger = logging.getLogger(__name__)

# Storage for tool runs managed by client (created from response stream)
# Key: tool_use_id, Value: tool_invocation
_client_managed_runs: Dict[str, ExecuteToolInvocation] = {}

def _current_baggage_value(key: str) -> Optional[str]:
    """Look up ``key`` in the current baggage and return it as clean text.

    Args:
        key: Baggage entry name to read.

    Returns:
        The entry converted with ``str()`` and stripped of surrounding
        whitespace, or ``None`` when the lookup raises, the entry is absent,
        or the stripped text is empty.
    """
    try:
        raw = baggage.get_baggage(key)
    except Exception:
        # Best effort: a failing baggage lookup is treated as "no value".
        return None

    if raw is not None:
        cleaned = str(raw).strip()
        if cleaned:
            return cleaned
    return None


def _entry_baggage_identity_attributes() -> Dict[str, str]:
    """Collect the session and user identity present in the current baggage.

    Returns:
        A dict keyed by ``GEN_AI_SESSION_ID`` and/or ``GEN_AI_USER_ID`` that
        contains only the entries for which a non-empty baggage value was
        found. The dict is empty when neither is set.
    """
    identity: Dict[str, str] = {}
    # Session id first, then user id, so insertion order stays stable.
    for attribute_key in (GEN_AI_SESSION_ID, GEN_AI_USER_ID):
        found = _current_baggage_value(attribute_key)
        if found:
            identity[attribute_key] = found
    return identity


def _apply_session_identity(
    invocation: Any, session_id: Optional[str]
) -> None:
    """Stamp session/user identity onto ``invocation``.

    The session id found in the current baggage takes precedence; the
    ``session_id`` argument is only used when the baggage carries none.

    Args:
        invocation: Object exposing an ``attributes`` mapping and, optionally,
            a ``conversation_id`` attribute.
        session_id: Fallback session id used when baggage has no session id.
    """
    baggage_identity = _entry_baggage_identity_attributes()
    resolved_session_id = (
        baggage_identity.get(GEN_AI_SESSION_ID) or session_id
    )

    if resolved_session_id:
        # Only invocations that declare conversation_id get the field set;
        # the span attribute is written for every invocation kind.
        if hasattr(invocation, "conversation_id"):
            invocation.conversation_id = resolved_session_id
        invocation.attributes[GEN_AI_SESSION_ID] = resolved_session_id

    # Copy every baggage-provided identity entry (session and/or user id).
    for attribute_key, attribute_value in baggage_identity.items():
        invocation.attributes[attribute_key] = attribute_value


def _clear_client_managed_runs() -> None:
def _set_session_id(
    agent_invocation: InvokeAgentInvocation, session_id: Optional[str]
) -> None:
    """Record the effective session id on an agent invocation.

    Thin wrapper that delegates to ``_apply_session_identity``.

    Args:
        agent_invocation: Agent invocation to update.
        session_id: Session id passed through to ``_apply_session_identity``.
    """
    _apply_session_identity(
        invocation=agent_invocation, session_id=session_id
    )


def _set_llm_session_id(
    llm_invocation: LLMInvocation, session_id: Optional[str]
) -> None:
    """Record the effective session id on an LLM invocation.

    Thin wrapper that delegates to ``_apply_session_identity``.

    Args:
        llm_invocation: LLM invocation to update.
        session_id: Session id passed through to ``_apply_session_identity``.
    """
    _apply_session_identity(invocation=llm_invocation, session_id=session_id)


def _clear_client_managed_runs(
handler: ExtendedTelemetryHandler,
client_managed_runs: Dict[str, ExecuteToolInvocation],
) -> None:
"""Clear all client-managed tool runs.

This should be called when a conversation ends to avoid memory leaks
and to clean up any orphaned tool runs.
"""
global _client_managed_runs

try:
handler = get_extended_telemetry_handler()
except Exception:
# If we can't get the handler (e.g., instrumentation not initialized),
# we still need to clear the tracking dictionary to prevent memory leaks.
_client_managed_runs.clear()
return

# End any orphaned tool runs
for tool_use_id, tool_invocation in list(_client_managed_runs.items()):
for tool_use_id, tool_invocation in list(client_managed_runs.items()):
try:
handler.fail_execute_tool(
tool_invocation,
Expand All @@ -83,7 +130,7 @@ def _clear_client_managed_runs() -> None:
# Best effort cleanup: continue processing remaining tools.
pass

_client_managed_runs.clear()
client_managed_runs.clear()


def _extract_message_parts(msg: Any) -> List[Any]:
Expand Down Expand Up @@ -112,6 +159,7 @@ def _create_tool_spans_from_message(
handler: ExtendedTelemetryHandler,
agent_invocation: InvokeAgentInvocation,
active_task_stack: List[Any],
client_managed_runs: Dict[str, ExecuteToolInvocation],
exclude_tool_names: Optional[List[str]] = None,
) -> None:
"""Create tool execution spans from ToolUseBlocks in an AssistantMessage.
Expand Down Expand Up @@ -163,8 +211,11 @@ def _create_tool_spans_from_message(
tool_call_arguments=tool_input,
tool_description=tool_name,
)
_apply_session_identity(
tool_invocation, agent_invocation.conversation_id
)
handler.start_execute_tool(tool_invocation)
_client_managed_runs[tool_use_id] = tool_invocation
client_managed_runs[tool_use_id] = tool_invocation

# If this is a Task tool, create a SubAgent span under it
# https://platform.claude.com/docs/en/agent-sdk/python#task
Expand Down Expand Up @@ -203,6 +254,10 @@ def _create_tool_spans_from_message(
agent_description=task_description,
input_messages=input_messages,
)
_set_session_id(
subagent_invocation,
agent_invocation.conversation_id,
)

# Start SubAgent span
handler.start_invoke_agent(subagent_invocation)
Expand Down Expand Up @@ -271,6 +326,7 @@ def _process_assistant_message(
handler: ExtendedTelemetryHandler,
collected_messages: List[Dict[str, Any]],
active_task_stack: List[Any],
client_managed_runs: Dict[str, ExecuteToolInvocation],
) -> None:
"""Process AssistantMessage: create LLM turn, extract parts, create tool spans."""
parts = _extract_message_parts(msg)
Expand Down Expand Up @@ -353,7 +409,11 @@ def _process_assistant_message(
turn_tracker.close_llm_turn()

_create_tool_spans_from_message(
msg, handler, agent_invocation, active_task_stack
msg,
handler,
agent_invocation,
active_task_stack,
client_managed_runs,
)


Expand All @@ -363,6 +423,7 @@ def _process_user_message(
handler: ExtendedTelemetryHandler,
collected_messages: List[Dict[str, Any]],
active_task_stack: List[Any],
client_managed_runs: Dict[str, ExecuteToolInvocation],
) -> None:
"""Process UserMessage: close tool spans, collect message content, mark next LLM start."""
user_parts: List[MessagePart] = []
Expand All @@ -376,8 +437,8 @@ def _process_user_message(

if block_type == "ToolResultBlock":
tool_use_id = getattr(block, "tool_use_id", None)
if tool_use_id and tool_use_id in _client_managed_runs:
tool_invocation = _client_managed_runs.pop(tool_use_id)
if tool_use_id and tool_use_id in client_managed_runs:
tool_invocation = client_managed_runs.pop(tool_use_id)

# Set tool response
tool_content = getattr(block, "content", None)
Expand Down Expand Up @@ -533,7 +594,25 @@ def _process_system_message(
if hasattr(msg, "data") and isinstance(msg.data, dict):
session_id = msg.data.get("session_id")
if session_id:
agent_invocation.conversation_id = session_id
_set_session_id(agent_invocation, session_id)


def _process_stream_event_message(
    msg: Any,
    agent_invocation: InvokeAgentInvocation,
) -> None:
    """Pick up a session id from a StreamEvent and apply it to the agent.

    The id is read from ``msg.session_id`` first; if that is missing or
    empty, ``msg.event["session_id"]`` is tried when ``msg.event`` is a dict.

    Args:
        msg: Stream event message to inspect.
        agent_invocation: Agent invocation that receives the session id.
    """
    stream_session_id = getattr(msg, "session_id", None)
    if not stream_session_id:
        payload = getattr(msg, "event", None)
        stream_session_id = (
            payload.get("session_id") if isinstance(payload, dict) else None
        )

    # Entry baggage is already applied when the agent invocation starts, so
    # an event without a session id needs no further work.
    if stream_session_id:
        _set_session_id(agent_invocation, stream_session_id)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Info] _set_session_id is called on every StreamEvent even when session_id is None. When there is no session and no entry baggage, this is effectively a no-op but still triggers a baggage lookup via _entry_baggage_identity_attributes(). For high-throughput streaming paths, consider adding an early return when session_id is None and no entry baggage session is present.



def _process_result_message(
Expand All @@ -543,6 +622,8 @@ def _process_result_message(
) -> None:
"""Process ResultMessage: update session_id (fallback), token usage, and close any open LLM turn."""

_set_session_id(agent_invocation, getattr(msg, "session_id", None))
turn_tracker.set_session_id(agent_invocation.conversation_id)
_update_token_usage(agent_invocation, turn_tracker, msg)

if turn_tracker.current_llm_invocation:
Expand All @@ -554,6 +635,7 @@ async def _process_agent_invocation_stream(
handler: ExtendedTelemetryHandler,
model: str,
prompt: str,
session_id: Optional[str] = None,
) -> Any:
"""Unified handler for processing agent invocation stream.

Expand All @@ -564,18 +646,15 @@ async def _process_agent_invocation_stream(
provider=infer_provider_from_base_url(),
agent_name="claude-agent",
request_model=model,
conversation_id="",
conversation_id=None,
input_messages=[
InputMessage(role="user", parts=[Text(content=prompt)])
]
if prompt
else [],
)
_set_session_id(agent_invocation, session_id)

# Attach empty context to clear any previous context, ensuring each query
# creates an independent root trace. This is important for scenarios where
# multiple queries are called in the same script - each should have its own trace_id.
empty_context_token = otel_context.attach(otel_context.Context())
handler.start_invoke_agent(agent_invocation)

query_start_time = time.time()
Expand All @@ -589,13 +668,16 @@ async def _process_agent_invocation_stream(
# When a Task tool is created, it's pushed here
# When its ToolResultBlock is received, it's popped
active_task_stack: List[Any] = []
client_managed_runs: Dict[str, ExecuteToolInvocation] = {}

try:
async for msg in wrapped_stream:
msg_type = type(msg).__name__

if msg_type == "SystemMessage":
_process_system_message(msg, agent_invocation)
elif msg_type == "StreamEvent":
_process_stream_event_message(msg, agent_invocation)
elif msg_type == "AssistantMessage":
_process_assistant_message(
msg,
Expand All @@ -606,6 +688,7 @@ async def _process_agent_invocation_stream(
handler,
collected_messages,
active_task_stack,
client_managed_runs,
)
elif msg_type == "UserMessage":
_process_user_message(
Expand All @@ -614,6 +697,7 @@ async def _process_agent_invocation_stream(
handler,
collected_messages,
active_task_stack,
client_managed_runs,
)
elif msg_type == "ResultMessage":
_process_result_message(msg, agent_invocation, turn_tracker)
Expand Down Expand Up @@ -648,11 +732,7 @@ async def _process_agent_invocation_stream(
# Span closure failure should not break the application
pass

# Detach empty context token to restore the original context.
# Note: stop_invoke_agent/fail_invoke_agent already detached invocation.context_token,
# which restored to empty context. Now we detach empty_context_token to restore further.
otel_context.detach(empty_context_token)
_clear_client_managed_runs()
_clear_client_managed_runs(handler, client_managed_runs)


class AssistantTurnTracker:
Expand Down Expand Up @@ -728,8 +808,8 @@ def start_llm_turn(
# Add conversation_id (session_id) to LLM span attributes
# This is a custom extension beyond standard GenAI semantic conventions
if agent_invocation and agent_invocation.conversation_id:
llm_invocation.attributes["gen_ai.conversation.id"] = (
agent_invocation.conversation_id
_set_llm_session_id(
llm_invocation, agent_invocation.conversation_id
)

self.handler.start_llm(llm_invocation)
Expand Down Expand Up @@ -774,6 +854,12 @@ def update_usage(
if output_tokens is not None:
target_invocation.output_tokens = output_tokens

def set_session_id(self, session_id: Optional[str]) -> None:
    """Apply a late-arriving session id to the currently open LLM invocation.

    Does nothing when no LLM invocation is open.

    Args:
        session_id: Session id forwarded to ``_set_llm_session_id``.
    """
    open_invocation = self.current_llm_invocation
    if not open_invocation:
        return
    _set_llm_session_id(open_invocation, session_id)

def close_llm_turn(self) -> None:
"""Close the current LLM invocation span."""
if self.current_llm_invocation:
Expand All @@ -798,6 +884,7 @@ def wrap_claude_client_init(wrapped, instance, args, kwargs, handler=None):

instance._otel_handler = handler
instance._otel_prompt = None
instance._otel_session_id = None

return result

Expand All @@ -808,6 +895,10 @@ def wrap_claude_client_query(wrapped, instance, args, kwargs, handler=None):
instance._otel_prompt = str(
kwargs.get("prompt") or (args[0] if args else "")
)
session_id = kwargs.get("session_id")
if session_id is None and len(args) > 1:
session_id = args[1]
instance._otel_session_id = session_id

return wrapped(*args, **kwargs)

Expand Down Expand Up @@ -835,6 +926,7 @@ async def wrap_claude_client_receive_response(
handler=handler,
model=model,
prompt=prompt,
session_id=getattr(instance, "_otel_session_id", None),
):
yield msg

Expand All @@ -852,11 +944,13 @@ async def wrap_query(wrapped, instance, args, kwargs, handler=None):

model = get_model_from_options_or_env(options)
prompt_str = str(prompt) if isinstance(prompt, str) else ""
session_id = getattr(options, "resume", None) if options else None

async for message in _process_agent_invocation_stream(
wrapped(*args, **kwargs),
handler=handler,
model=model,
prompt=prompt_str,
session_id=session_id,
):
yield message
Loading
Loading