Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/metrics/efficiency_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ def compute_token_metrics(messages: List[Dict[str, Any]]) -> Dict[str, float]:

Returns:
Dictionary with keys:
- total_tokens: Total tokens (prompt + response)
- total_prompt_tokens: Total prompt tokens
- total_tokens: Total tokens (latest prompt + all responses)
- total_prompt_tokens: Prompt tokens from the latest TokenEvent
- total_response_tokens: Total response/completion tokens
- avg_prompt_tokens_per_step: Average prompt tokens per TokenEvent
- avg_response_tokens_per_step: Average response tokens per TokenEvent

TokenEvent prompt_token_ids are cumulative across turns, so summing every
prompt would double count earlier trajectory context.
"""
token_messages = [msg for msg in messages if msg.get("kind") == "TokenEvent"]

Expand Down Expand Up @@ -77,7 +80,7 @@ def compute_tool_call_metrics(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
- avg_tool_calls_per_step: Average tool calls per step
- tool_call_breakdown: Dict mapping tool names to counts
"""
# Find all assistant messages with tool calls
# Count SDK ActionEvents and OpenAI-style assistant tool calls.
tool_call_count = 0
tool_breakdown = {}
if not messages:
Expand All @@ -88,6 +91,12 @@ def compute_tool_call_metrics(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
}

for msg in messages:
if msg.get("kind") == "ActionEvent":
tool_name = msg.get("tool_name") or "unknown"
tool_call_count += 1
tool_breakdown[tool_name] = tool_breakdown.get(tool_name, 0) + 1
continue

# Assistant messages contain tool_calls field
if msg.get("role") == "assistant" and "tool_calls" in msg:
tool_calls = msg.get("tool_calls", [])
Expand Down
72 changes: 52 additions & 20 deletions tests/metrics/test_efficiency_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@
"content": "Search results here",
}

MOCK_ACTION_EVENT_BASH = {
"kind": "ActionEvent",
"tool_name": "bash",
"llm_response_id": "turn-1",
}

MOCK_ACTION_EVENT_RESULT = {
"kind": "ActionEvent",
"tool_name": "result",
"llm_response_id": "turn-1",
}

MOCK_ACTION_EVENT_UNKNOWN = {
"kind": "ActionEvent",
"llm_response_id": "turn-1",
}


class TestComputeTokenMetrics:
"""Tests for compute_token_metrics function."""
Expand Down Expand Up @@ -84,10 +101,10 @@ def test_multiple_token_events(self):
"""Test with multiple TokenEvents."""
messages = [MOCK_TOKEN_MESSAGE_1, MOCK_TOKEN_MESSAGE_2]
result = compute_token_metrics(messages)
assert result["total_tokens"] == 15 # (4+3) + (6+2)
assert result["total_prompt_tokens"] == 10 # 4 + 6
assert result["total_tokens"] == 11 # latest prompt + all responses
assert result["total_prompt_tokens"] == 6
assert result["total_response_tokens"] == 5 # 3 + 2
assert result["avg_prompt_tokens_per_step"] == 5.0 # 10/2
assert result["avg_prompt_tokens_per_step"] == 3.0 # 6/2
assert result["avg_response_tokens_per_step"] == 2.5 # 5/2

def test_mixed_messages(self):
Expand All @@ -100,8 +117,8 @@ def test_mixed_messages(self):
]
result = compute_token_metrics(messages)
# Should only count the two TokenEvents
assert result["total_tokens"] == 15
assert result["total_prompt_tokens"] == 10
assert result["total_tokens"] == 11
assert result["total_prompt_tokens"] == 6
assert result["total_response_tokens"] == 5


Expand Down Expand Up @@ -190,6 +207,26 @@ def test_mixed_assistant_messages(self):
assert result["total_tool_calls"] == 2
assert result["avg_tool_calls_per_step"] == 1.0 # 2 tools / 2 steps

def test_action_event_tool_calls(self):
"""Test tool counting from SDK ActionEvent trajectory messages."""
messages = [
MOCK_TOKEN_MESSAGE_1,
MOCK_ACTION_EVENT_BASH,
MOCK_ACTION_EVENT_RESULT,
]
result = compute_tool_call_metrics(messages)
assert result["total_tool_calls"] == 2
assert result["avg_tool_calls_per_step"] == 2.0
assert result["tool_call_breakdown"]["bash"] == 1
assert result["tool_call_breakdown"]["result"] == 1

def test_action_event_without_tool_name(self):
"""Test ActionEvent fallback when tool_name is absent."""
messages = [MOCK_TOKEN_MESSAGE_1, MOCK_ACTION_EVENT_UNKNOWN]
result = compute_tool_call_metrics(messages)
assert result["total_tool_calls"] == 1
assert result["tool_call_breakdown"]["unknown"] == 1


class TestComputeAllEfficiencyMetrics:
"""Tests for compute_all_efficiency_metrics function."""
Expand All @@ -209,9 +246,11 @@ def test_complete_trajectory(self):
"""Test with complete trajectory including tokens, steps, and tools."""
messages = [
MOCK_TOKEN_MESSAGE_1, # Step 1: 7 tokens
MOCK_ASSISTANT_MESSAGE_WITH_TOOLS, # 2 tool calls
MOCK_ACTION_EVENT_BASH,
MOCK_ACTION_EVENT_RESULT,
MOCK_TOKEN_MESSAGE_2, # Step 2: 8 tokens
MOCK_ASSISTANT_MESSAGE_WITH_TOOLS, # 2 tool calls
MOCK_ACTION_EVENT_BASH,
MOCK_ACTION_EVENT_RESULT,
]
result = compute_all_efficiency_metrics(
messages=messages,
Expand All @@ -221,23 +260,16 @@ def test_complete_trajectory(self):
)

# Check core metrics
assert result["tokens"] == 15 # 7 + 8
assert result["tokens"] == 11 # latest prompt + all responses
assert result["steps"] == 2
assert result["avg_tool_calls_per_step"] == 2.0 # 4 tools / 2 steps
assert result["wall_clock_duration"] == 15.5

# Check timestamps
assert result["start_timestamp"] == "2025-01-01T10:00:00"
assert result["end_timestamp"] == "2025-01-01T10:00:15"

# Check token breakdown
assert result["token_breakdown"]["total_prompt_tokens"] == 10
assert result["token_breakdown"]["total_response_tokens"] == 5

# Check tool breakdown
assert result["tool_breakdown"]["total_tool_calls"] == 4
assert result["tool_breakdown"]["by_tool_type"]["bash"] == 2
assert result["tool_breakdown"]["by_tool_type"]["result"] == 2
# Check flattened token fields
assert result["total_prompt_tokens"] == 6
assert result["total_response_tokens"] == 5
assert result["avg_prompt_tokens_per_step"] == 3.0
assert result["avg_response_tokens_per_step"] == 2.5

def test_without_timestamps(self):
"""Test that timestamps are optional."""
Expand Down