diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java index a6dccc44a..7ce69b6da 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java @@ -19,7 +19,6 @@ package org.apache.flink.agents.api.chat.model; import org.apache.flink.agents.api.chat.messages.ChatMessage; -import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; @@ -56,22 +55,4 @@ public ResourceType getResourceType() { */ public abstract ChatMessage chat( List messages, List tools, Map arguments); - - /** - * Record token usage metrics for the given model. - * - * @param modelName the name of the model used - * @param promptTokens the number of prompt tokens - * @param completionTokens the number of completion tokens - */ - protected void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) { - FlinkAgentsMetricGroup metricGroup = getMetricGroup(); - if (metricGroup == null) { - return; - } - - FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName); - modelGroup.getCounter("promptTokens").inc(promptTokens); - modelGroup.getCounter("completionTokens").inc(completionTokens); - } } diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java index af7ed10b2..07c2c81c6 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java @@ -20,6 +20,7 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceContext; @@ -107,6 +108,23 @@ public void open() throws Exception { public abstract Map getParameters(); + /** + * Record token usage metrics for the given model on this setup's bound metric group. + * + * @param modelName the name of the model used + * @param promptTokens the number of prompt tokens + * @param completionTokens the number of completion tokens + */ + public void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) { + FlinkAgentsMetricGroup metricGroup = getMetricGroup(); + if (metricGroup == null) { + return; + } + FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName); + modelGroup.getCounter("promptTokens").inc(promptTokens); + modelGroup.getCounter("completionTokens").inc(completionTokens); + } + public ChatMessage chat(List messages) { return this.chat(messages, Collections.emptyMap()); } @@ -115,8 +133,6 @@ public ChatMessage chat(List messages, Map paramete Preconditions.checkNotNull( connection, "Connection is not initialized. Ensure open() is called before chat()."); - // Pass metric group to connection for token usage tracking - connection.setMetricGroup(getMetricGroup()); // Format input messages if set prompt. if (this.prompt != null) { diff --git a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java index 06cd2b385..c3e5d19ba 100644 --- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java +++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java @@ -67,6 +67,10 @@ public interface RunnerContext { /** * Gets the metric group for Flink Agents. * + *

The returned group must only be accessed from the operator/mailbox (action) thread, not + * from inside a {@link #durableExecute} or {@link #durableExecuteAsync} callable, which runs on + * a separate thread pool. + * * @return the metric group shared across all actions. */ FlinkAgentsMetricGroup getAgentMetricGroup(); @@ -74,6 +78,10 @@ public interface RunnerContext { /** * Gets the individual metric group dedicated for each action. * + *

The returned group must only be accessed from the operator/mailbox (action) thread, not + * from inside a {@link #durableExecute} or {@link #durableExecuteAsync} callable, which runs on + * a separate thread pool. + * * @return the individual metric group specific to the current action. */ FlinkAgentsMetricGroup getActionMetricGroup(); diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java similarity index 59% rename from api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java rename to api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java index 43654944a..cde9f683f 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java @@ -18,71 +18,60 @@ package org.apache.flink.agents.api.chat.model; -import org.apache.flink.agents.api.chat.messages.ChatMessage; -import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; -import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.metrics.Counter; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import java.util.Collections; -import java.util.List; import java.util.Map; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.Mockito.*; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; -/** Test cases for BaseChatModelConnection token metrics functionality. */ -class BaseChatModelConnectionTokenMetricsTest { +/** Test cases for BaseChatModelSetup token metrics functionality. */ +class BaseChatModelSetupTokenMetricsTest { - private TestChatModelConnection connection; + private TestChatModelSetup setup; private FlinkAgentsMetricGroup mockMetricGroup; private FlinkAgentsMetricGroup mockModelGroup; private Counter mockPromptTokensCounter; private Counter mockCompletionTokensCounter; - /** Test implementation of BaseChatModelConnection for testing purposes. */ - private static class TestChatModelConnection extends BaseChatModelConnection { + /** Test implementation of BaseChatModelSetup for testing purposes. */ + private static class TestChatModelSetup extends BaseChatModelSetup { - public TestChatModelConnection( - ResourceDescriptor descriptor, ResourceContext resourceContext) { + public TestChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) { super(descriptor, resourceContext); } @Override - public ChatMessage chat( - List messages, List tools, Map arguments) { - // Simple test implementation - return new ChatMessage(MessageRole.ASSISTANT, "Test response"); - } - - // Expose protected method for testing - public void testRecordTokenMetrics( - String modelName, long promptTokens, long completionTokens) { - recordTokenMetrics(modelName, promptTokens, completionTokens); + public Map getParameters() { + return Collections.emptyMap(); } } @BeforeEach void setUp() { - connection = - new TestChatModelConnection( + setup = + new TestChatModelSetup( new ResourceDescriptor( - TestChatModelConnection.class.getName(), Collections.emptyMap()), + TestChatModelSetup.class.getName(), Collections.emptyMap()), null); - // Create mock objects mockMetricGroup = mock(FlinkAgentsMetricGroup.class); mockModelGroup = mock(FlinkAgentsMetricGroup.class); mockPromptTokensCounter = mock(Counter.class); mockCompletionTokensCounter = mock(Counter.class); - // Set up mock behavior when(mockMetricGroup.getSubGroup("gpt-4")).thenReturn(mockModelGroup); when(mockModelGroup.getCounter("promptTokens")).thenReturn(mockPromptTokensCounter); when(mockModelGroup.getCounter("completionTokens")).thenReturn(mockCompletionTokensCounter); @@ -91,13 +80,10 @@ void setUp() { @Test @DisplayName("Test token metrics are recorded when metric group is set") void testRecordTokenMetricsWithMetricGroup() { - // Set the metric group - connection.setMetricGroup(mockMetricGroup); + setup.setMetricGroup(mockMetricGroup); - // Record token metrics - connection.testRecordTokenMetrics("gpt-4", 100, 50); + setup.recordTokenMetrics("gpt-4", 100, 50); - // Verify the metrics were recorded verify(mockMetricGroup).getSubGroup("gpt-4"); verify(mockModelGroup).getCounter("promptTokens"); verify(mockModelGroup).getCounter("completionTokens"); @@ -108,22 +94,16 @@ void testRecordTokenMetricsWithMetricGroup() { @Test @DisplayName("Test token metrics are not recorded when metric group is null") void testRecordTokenMetricsWithoutMetricGroup() { - // Do not set metric group (should be null by default) + assertDoesNotThrow(() -> setup.recordTokenMetrics("gpt-4", 100, 50)); - // Record token metrics - should not throw - assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4", 100, 50)); - - // No metrics should be recorded verifyNoInteractions(mockMetricGroup); } @Test - @DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName -> counters") + @DisplayName("Test token metrics hierarchy: metricGroup -> modelName -> counters") void testTokenMetricsHierarchy() { - // Set the metric group - connection.setMetricGroup(mockMetricGroup); + setup.setMetricGroup(mockMetricGroup); - // Record token metrics for different models FlinkAgentsMetricGroup mockGpt35Group = mock(FlinkAgentsMetricGroup.class); Counter mockGpt35PromptCounter = mock(Counter.class); Counter mockGpt35CompletionCounter = mock(Counter.class); @@ -132,13 +112,9 @@ void testTokenMetricsHierarchy() { when(mockGpt35Group.getCounter("promptTokens")).thenReturn(mockGpt35PromptCounter); when(mockGpt35Group.getCounter("completionTokens")).thenReturn(mockGpt35CompletionCounter); - // Record for gpt-4 - connection.testRecordTokenMetrics("gpt-4", 100, 50); - - // Record for gpt-3.5-turbo - connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100); + setup.recordTokenMetrics("gpt-4", 100, 50); + setup.recordTokenMetrics("gpt-3.5-turbo", 200, 100); - // Verify each model has its own counters verify(mockMetricGroup).getSubGroup("gpt-4"); verify(mockMetricGroup).getSubGroup("gpt-3.5-turbo"); verify(mockPromptTokensCounter).inc(100); @@ -148,8 +124,8 @@ void testTokenMetricsHierarchy() { } @Test - @DisplayName("Test resource type is CHAT_MODEL_CONNECTION") + @DisplayName("Test resource type is CHAT_MODEL") void testResourceType() { - assertEquals(ResourceType.CHAT_MODEL_CONNECTION, connection.getResourceType()); + assertEquals(ResourceType.CHAT_MODEL, setup.getResourceType()); } } diff --git a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java index 248b464ea..93691d3fe 100644 --- a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java +++ b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java @@ -133,7 +133,7 @@ public ChatMessage chat( Message response = client.messages().create(params); ChatMessage result = convertResponse(response, jsonPrefillApplied); - // Record token metrics + // Stash token usage String modelName = null; if (arguments != null && arguments.get("model") != null) { modelName = arguments.get("model").toString(); @@ -142,8 +142,9 @@ public ChatMessage chat( modelName = this.defaultModel; } if (modelName != null && !modelName.isBlank()) { - recordTokenMetrics( - modelName, response.usage().inputTokens(), response.usage().outputTokens()); + result.getExtraArgs().put("model_name", modelName); + result.getExtraArgs().put("promptTokens", response.usage().inputTokens()); + result.getExtraArgs().put("completionTokens", response.usage().outputTokens()); } return result; diff --git a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java index 3051ecf47..318b54576 100644 --- a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java +++ b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java @@ -187,12 +187,15 @@ public ChatMessage chat( chatMessage.setToolCalls(convertedToolCalls); } - // Record token metrics if model name is available + // Stash token usage if model name is available if (modelName != null && !modelName.isBlank()) { CompletionsUsage usage = completions.getUsage(); if (usage != null) { - recordTokenMetrics( - modelName, usage.getPromptTokens(), usage.getCompletionTokens()); + chatMessage.getExtraArgs().put("model_name", modelName); + chatMessage.getExtraArgs().put("promptTokens", (long) usage.getPromptTokens()); + chatMessage + .getExtraArgs() + .put("completionTokens", (long) usage.getCompletionTokens()); } } diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java index 8327795a3..58d235083 100644 --- a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java +++ b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java @@ -178,12 +178,14 @@ public ChatMessage chat( ConverseResponse response = retryExecutor.execute(() -> client.converse(request), "BedrockConverse"); + ChatMessage result = convertResponse(response); if (response.usage() != null) { - recordTokenMetrics( - modelId, response.usage().inputTokens(), response.usage().outputTokens()); + result.getExtraArgs().put("model_name", modelId); + result.getExtraArgs().put("promptTokens", response.usage().inputTokens().longValue()); + result.getExtraArgs() + .put("completionTokens", response.usage().outputTokens().longValue()); } - - return convertResponse(response); + return result; } private static boolean isRetryable(Exception e) { diff --git a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java index 2cda1ea40..4c617455c 100644 --- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java +++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java @@ -224,13 +224,14 @@ public ChatMessage chat( chatMessage.setToolCalls(toolCalls); } - // Record token metrics if model name is available + // Stash token usage if model name is available if (modelName != null && !modelName.isBlank()) { Integer promptTokens = ollamaChatResponse.getPromptEvalCount(); Integer completionTokens = ollamaChatResponse.getEvalCount(); if (promptTokens != null && completionTokens != null) { - recordTokenMetrics( - modelName, promptTokens.longValue(), completionTokens.longValue()); + extraArgs.put("model_name", modelName); + extraArgs.put("promptTokens", promptTokens.longValue()); + extraArgs.put("completionTokens", completionTokens.longValue()); } } diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java index 7d6b5c2c1..6567bd2be 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java @@ -216,10 +216,11 @@ public ChatMessage chat( if (modelOfAzureDeployment != null && !modelOfAzureDeployment.isBlank() && completion.usage().isPresent()) { - recordTokenMetrics( - modelOfAzureDeployment, - completion.usage().get().promptTokens(), - completion.usage().get().completionTokens()); + response.getExtraArgs().put("model_name", modelOfAzureDeployment); + response.getExtraArgs() + .put("promptTokens", completion.usage().get().promptTokens()); + response.getExtraArgs() + .put("completionTokens", completion.usage().get().completionTokens()); } return response; diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java index e4947e8f3..2a0b78fea 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java @@ -129,17 +129,18 @@ public ChatMessage chat( OpenAIChatCompletionsUtils.convertFromOpenAIMessage( completion.choices().get(0).message()); - // Record token metrics + // Stash token usage if (completion.usage().isPresent()) { String modelName = arguments != null ? (String) arguments.get("model") : null; if (modelName == null || modelName.isBlank()) { modelName = this.defaultModel; } if (modelName != null && !modelName.isBlank()) { - recordTokenMetrics( - modelName, - completion.usage().get().promptTokens(), - completion.usage().get().completionTokens()); + response.getExtraArgs().put("model_name", modelName); + response.getExtraArgs() + .put("promptTokens", completion.usage().get().promptTokens()); + response.getExtraArgs() + .put("completionTokens", completion.usage().get().completionTokens()); } } diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java index 9b0d143eb..00b5f9b68 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java @@ -140,10 +140,10 @@ public ChatMessage chat( modelName = this.defaultModel; } if (modelName != null && !modelName.isBlank()) { - recordTokenMetrics( - modelName, - response.usage().get().inputTokens(), - response.usage().get().outputTokens()); + result.getExtraArgs().put("model_name", modelName); + result.getExtraArgs().put("promptTokens", response.usage().get().inputTokens()); + result.getExtraArgs() + .put("completionTokens", response.usage().get().outputTokens()); } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index 997fb28b9..203ddd7e8 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -179,6 +179,23 @@ private static void recordRetryMetrics( } } + static void recordChatTokenMetrics(BaseChatModelSetup chatModel, ChatMessage response) { + Map extraArgs = response.getExtraArgs(); + Object modelName = extraArgs.get("model_name"); + Object promptTokens = extraArgs.get("promptTokens"); + Object completionTokens = extraArgs.get("completionTokens"); + if (modelName != null + && !modelName.toString().isEmpty() + && promptTokens instanceof Number + && completionTokens instanceof Number) { + long prompt = ((Number) promptTokens).longValue(); + long completion = ((Number) completionTokens).longValue(); + if (prompt > 0 && completion > 0) { + chatModel.recordTokenMetrics(modelName.toString(), prompt, completion); + } + } + } + private static void handleToolCalls( ChatMessage response, UUID initialRequestId, @@ -349,6 +366,7 @@ public ChatMessage call() throws Exception { chatAsync ? ctx.durableExecuteAsync(callable) : ctx.durableExecute(callable); + recordChatTokenMetrics(chatModel, response); // only generate structured output for final response. if (outputSchema != null && response.getToolCalls().isEmpty()) { response = generateStructuredOutput(response, outputSchema); diff --git a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java index d7f117850..85c263a66 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java @@ -17,13 +17,98 @@ */ package org.apache.flink.agents.plan.actions; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; import org.junit.jupiter.api.Test; +import java.util.HashMap; +import java.util.Map; + import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; /** Tests for {@link ChatModelAction}. */ class ChatModelActionTest { + private static ChatMessage responseWith(Map extraArgs) { + return new ChatMessage(MessageRole.ASSISTANT, "response", extraArgs); + } + + @Test + void testRecordChatTokenMetricsRecordsWhenAllKeysPresent() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + Map extraArgs = new HashMap<>(); + extraArgs.put("model_name", "m"); + extraArgs.put("promptTokens", 100L); + extraArgs.put("completionTokens", 50L); + + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + + verify(setup).recordTokenMetrics("m", 100L, 50L); + } + + @Test + void testRecordChatTokenMetricsHandlesIntegerTokenValues() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + Map extraArgs = new HashMap<>(); + extraArgs.put("model_name", "m"); + extraArgs.put("promptTokens", 100); + extraArgs.put("completionTokens", 50); + + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + + verify(setup).recordTokenMetrics("m", 100L, 50L); + } + + @Test + void testRecordChatTokenMetricsSkipsWhenTokenValueNonNumeric() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + Map extraArgs = new HashMap<>(); + extraArgs.put("model_name", "m"); + extraArgs.put("promptTokens", "100"); + extraArgs.put("completionTokens", 50L); + + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + + verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong()); + } + + @Test + void testRecordChatTokenMetricsSkipsWhenKeyMissing() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + Map extraArgs = new HashMap<>(); + extraArgs.put("model_name", "m"); + extraArgs.put("completionTokens", 50L); + + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + + verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong()); + } + + @Test + void testRecordChatTokenMetricsSkipsZeroTokensOrEmptyModel() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + + Map zeroPrompt = new HashMap<>(); + zeroPrompt.put("model_name", "m"); + zeroPrompt.put("promptTokens", 0L); + zeroPrompt.put("completionTokens", 50L); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(zeroPrompt)); + + Map emptyModel = new HashMap<>(); + emptyModel.put("model_name", ""); + emptyModel.put("promptTokens", 100L); + emptyModel.put("completionTokens", 50L); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(emptyModel)); + + verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong()); + } + @Test void testCleanLlmResponseWithJsonBlock() { String input = "```json\n{\"key\": \"value\"}\n```";