From b9ccc185dab7c9f905bb5d7ff7fa9578d95205f4 Mon Sep 17 00:00:00 2001 From: DB825 Date: Sat, 6 Jun 2026 21:57:13 -0500 Subject: [PATCH] Count ActionEvent tool calls in efficiency metrics --- src/metrics/efficiency_metrics.py | 15 ++++- tests/metrics/test_efficiency_metrics.py | 72 +++++++++++++++++------- 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/src/metrics/efficiency_metrics.py b/src/metrics/efficiency_metrics.py index 3074eb2..4f4bdd8 100644 --- a/src/metrics/efficiency_metrics.py +++ b/src/metrics/efficiency_metrics.py @@ -12,11 +12,14 @@ def compute_token_metrics(messages: List[Dict[str, Any]]) -> Dict[str, float]: Returns: Dictionary with keys: - - total_tokens: Total tokens (prompt + response) - - total_prompt_tokens: Total prompt tokens + - total_tokens: Total tokens (latest prompt + all responses) + - total_prompt_tokens: Prompt tokens from the latest TokenEvent - total_response_tokens: Total response/completion tokens - avg_prompt_tokens_per_step: Average prompt tokens per TokenEvent - avg_response_tokens_per_step: Average response tokens per TokenEvent + + TokenEvent prompt_token_ids are cumulative across turns, so summing every + prompt would double count earlier trajectory context. """ token_messages = [msg for msg in messages if msg.get("kind") == "TokenEvent"] @@ -77,7 +80,7 @@ def compute_tool_call_metrics(messages: List[Dict[str, Any]]) -> Dict[str, Any]: - avg_tool_calls_per_step: Average tool calls per step - tool_call_breakdown: Dict mapping tool names to counts """ - # Find all assistant messages with tool calls + # Count SDK ActionEvents and OpenAI-style assistant tool calls. tool_call_count = 0 tool_breakdown = {} if not messages: @@ -88,6 +91,12 @@ def compute_tool_call_metrics(messages: List[Dict[str, Any]]) -> Dict[str, Any]: } for msg in messages: + if msg.get("kind") == "ActionEvent": + tool_name = msg.get("tool_name") or "unknown" + tool_call_count += 1 + tool_breakdown[tool_name] = tool_breakdown.get(tool_name, 0) + 1 + continue + # Assistant messages contain tool_calls field if msg.get("role") == "assistant" and "tool_calls" in msg: tool_calls = msg.get("tool_calls", []) diff --git a/tests/metrics/test_efficiency_metrics.py b/tests/metrics/test_efficiency_metrics.py index ad80299..7c3ef28 100644 --- a/tests/metrics/test_efficiency_metrics.py +++ b/tests/metrics/test_efficiency_metrics.py @@ -51,6 +51,23 @@ "content": "Search results here", } +MOCK_ACTION_EVENT_BASH = { + "kind": "ActionEvent", + "tool_name": "bash", + "llm_response_id": "turn-1", +} + +MOCK_ACTION_EVENT_RESULT = { + "kind": "ActionEvent", + "tool_name": "result", + "llm_response_id": "turn-1", +} + +MOCK_ACTION_EVENT_UNKNOWN = { + "kind": "ActionEvent", + "llm_response_id": "turn-1", +} + class TestComputeTokenMetrics: """Tests for compute_token_metrics function.""" @@ -84,10 +101,10 @@ def test_multiple_token_events(self): """Test with multiple TokenEvents.""" messages = [MOCK_TOKEN_MESSAGE_1, MOCK_TOKEN_MESSAGE_2] result = compute_token_metrics(messages) - assert result["total_tokens"] == 15 # (4+3) + (6+2) - assert result["total_prompt_tokens"] == 10 # 4 + 6 + assert result["total_tokens"] == 11 # latest prompt + all responses + assert result["total_prompt_tokens"] == 6 assert result["total_response_tokens"] == 5 # 3 + 2 - assert result["avg_prompt_tokens_per_step"] == 5.0 # 10/2 + assert result["avg_prompt_tokens_per_step"] == 3.0 # 6/2 assert result["avg_response_tokens_per_step"] == 2.5 # 5/2 def test_mixed_messages(self): @@ -100,8 +117,8 @@ def test_mixed_messages(self): ] result = compute_token_metrics(messages) # Should only count the two TokenEvents - assert result["total_tokens"] == 15 - assert result["total_prompt_tokens"] == 10 + assert result["total_tokens"] == 11 + assert result["total_prompt_tokens"] == 6 assert result["total_response_tokens"] == 5 @@ -190,6 +207,26 @@ def test_mixed_assistant_messages(self): assert result["total_tool_calls"] == 2 assert result["avg_tool_calls_per_step"] == 1.0 # 2 tools / 2 steps + def test_action_event_tool_calls(self): + """Test tool counting from SDK ActionEvent trajectory messages.""" + messages = [ + MOCK_TOKEN_MESSAGE_1, + MOCK_ACTION_EVENT_BASH, + MOCK_ACTION_EVENT_RESULT, + ] + result = compute_tool_call_metrics(messages) + assert result["total_tool_calls"] == 2 + assert result["avg_tool_calls_per_step"] == 2.0 + assert result["tool_call_breakdown"]["bash"] == 1 + assert result["tool_call_breakdown"]["result"] == 1 + + def test_action_event_without_tool_name(self): + """Test ActionEvent fallback when tool_name is absent.""" + messages = [MOCK_TOKEN_MESSAGE_1, MOCK_ACTION_EVENT_UNKNOWN] + result = compute_tool_call_metrics(messages) + assert result["total_tool_calls"] == 1 + assert result["tool_call_breakdown"]["unknown"] == 1 + class TestComputeAllEfficiencyMetrics: """Tests for compute_all_efficiency_metrics function.""" @@ -209,9 +246,11 @@ def test_complete_trajectory(self): """Test with complete trajectory including tokens, steps, and tools.""" messages = [ MOCK_TOKEN_MESSAGE_1, # Step 1: 7 tokens - MOCK_ASSISTANT_MESSAGE_WITH_TOOLS, # 2 tool calls + MOCK_ACTION_EVENT_BASH, + MOCK_ACTION_EVENT_RESULT, MOCK_TOKEN_MESSAGE_2, # Step 2: 8 tokens - MOCK_ASSISTANT_MESSAGE_WITH_TOOLS, # 2 tool calls + MOCK_ACTION_EVENT_BASH, + MOCK_ACTION_EVENT_RESULT, ] result = compute_all_efficiency_metrics( messages=messages, @@ -221,23 +260,16 @@ def test_complete_trajectory(self): ) # Check core metrics - assert result["tokens"] == 15 # 7 + 8 + assert result["tokens"] == 11 # latest prompt + all responses assert result["steps"] == 2 assert result["avg_tool_calls_per_step"] == 2.0 # 4 tools / 2 steps assert result["wall_clock_duration"] == 15.5 - # Check timestamps - assert result["start_timestamp"] == "2025-01-01T10:00:00" - assert result["end_timestamp"] == "2025-01-01T10:00:15" - - # Check token breakdown - assert result["token_breakdown"]["total_prompt_tokens"] == 10 - assert result["token_breakdown"]["total_response_tokens"] == 5 - - # Check tool breakdown - assert result["tool_breakdown"]["total_tool_calls"] == 4 - assert result["tool_breakdown"]["by_tool_type"]["bash"] == 2 - assert result["tool_breakdown"]["by_tool_type"]["result"] == 2 + # Check flattened token fields + assert result["total_prompt_tokens"] == 6 + assert result["total_response_tokens"] == 5 + assert result["avg_prompt_tokens_per_step"] == 3.0 + assert result["avg_response_tokens_per_step"] == 2.5 def test_without_timestamps(self): """Test that timestamps are optional."""