From e5034cea30363445b709c38dfe8fd504924072c4 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Mon, 1 Jun 2026 19:22:25 +0800 Subject: [PATCH] [api][plan][integrations] Record built-in chat token metrics outside the async call boundary (#712) Backport of #712 to release-0.2 with scope narrowed to the chat-model connections that exist on this branch. The Bedrock, AzureOpenAI and OpenAIResponses connection variants are not present on release-0.2 and are intentionally excluded. Move token-metric recording from the durable async callable (where it crossed the operator/mailbox thread boundary) to the action thread: - BaseChatModelSetup gains public recordTokenMetrics(String, long, long). - BaseChatModelConnection.recordTokenMetrics(...) and the connection.setMetricGroup(...) forwarding in BaseChatModelSetup are removed. - Each connection's chat() stashes model_name / promptTokens / completionTokens into ChatMessage.extraArgs (Ollama, Anthropic, AzureAI, OpenAI on release-0.2). - ChatModelAction records via the setup after durableExecute(Async) returns, before structured-output reassignment. - RunnerContext.getAgentMetricGroup/getActionMetricGroup javadoc notes that the returned group must only be accessed from the operator thread, not inside a durable callable. Emitted metric paths and counter names are unchanged. Records are gated identically to Python: non-empty model name and both token counts greater than zero; Integer/Long token values are accepted via Number#longValue(). Tests: - BaseChatModelConnectionTokenMetricsTest renamed and rewritten to BaseChatModelSetupTokenMetricsTest (target moved from connection to setup). - New ChatModelActionTest covers recordChatTokenMetrics: records when all keys present and positive; Integer-typed values still recorded; skips on missing key, non-numeric value, zero token, or empty model name. 
--- .../chat/model/BaseChatModelConnection.java | 19 --- .../api/chat/model/BaseChatModelSetup.java | 21 +++- .../agents/api/context/RunnerContext.java | 8 ++ ...> BaseChatModelSetupTokenMetricsTest.java} | 75 +++++------- .../AnthropicChatModelConnection.java | 7 +- .../azureai/AzureAIChatModelConnection.java | 9 +- .../ollama/OllamaChatModelConnection.java | 7 +- .../openai/OpenAIChatModelConnection.java | 11 +- .../agents/plan/actions/ChatModelAction.java | 18 +++ .../plan/actions/ChatModelActionTest.java | 110 ++++++++++++++++++ 10 files changed, 200 insertions(+), 85 deletions(-) rename api/src/test/java/org/apache/flink/agents/api/chat/model/{BaseChatModelConnectionTokenMetricsTest.java => BaseChatModelSetupTokenMetricsTest.java} (61%) create mode 100644 plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java index 7b70af0c3..7254e37f3 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java @@ -19,7 +19,6 @@ package org.apache.flink.agents.api.chat.model; import org.apache.flink.agents.api.chat.messages.ChatMessage; -import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; @@ -57,22 +56,4 @@ public ResourceType getResourceType() { */ public abstract ChatMessage chat( List messages, List tools, Map arguments); - - /** - * Record token usage metrics for the given model. 
- * - * @param modelName the name of the model used - * @param promptTokens the number of prompt tokens - * @param completionTokens the number of completion tokens - */ - protected void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) { - FlinkAgentsMetricGroup metricGroup = getMetricGroup(); - if (metricGroup == null) { - return; - } - - FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName); - modelGroup.getCounter("promptTokens").inc(promptTokens); - modelGroup.getCounter("completionTokens").inc(completionTokens); - } } diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java index 15aa953e3..343dd5a05 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java @@ -20,6 +20,7 @@ import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; @@ -51,6 +52,23 @@ public BaseChatModelSetup( public abstract Map getParameters(); + /** + * Record token usage metrics for the given model on this setup's bound metric group. 
+ * + * @param modelName the name of the model used + * @param promptTokens the number of prompt tokens + * @param completionTokens the number of completion tokens + */ + public void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) { + FlinkAgentsMetricGroup metricGroup = getMetricGroup(); + if (metricGroup == null) { + return; + } + FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName); + modelGroup.getCounter("promptTokens").inc(promptTokens); + modelGroup.getCounter("completionTokens").inc(completionTokens); + } + public ChatMessage chat(List messages) { return this.chat(messages, Collections.emptyMap()); } @@ -60,9 +78,6 @@ public ChatMessage chat(List messages, Map paramete (BaseChatModelConnection) this.getResource.apply(this.connection, ResourceType.CHAT_MODEL_CONNECTION); - // Pass metric group to connection for token usage tracking - connection.setMetricGroup(getMetricGroup()); - // Format input messages if set prompt. if (this.prompt != null) { if (this.prompt instanceof String) { diff --git a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java index 0810752a7..65f84b7aa 100644 --- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java +++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java @@ -67,6 +67,10 @@ public interface RunnerContext { /** * Gets the metric group for Flink Agents. * + *

<p>The returned group must only be accessed from the operator/mailbox (action) thread, not + * from inside a {@link #durableExecute} or {@link #durableExecuteAsync} callable, which runs on + * a separate thread pool. + * * @return the metric group shared across all actions. */ FlinkAgentsMetricGroup getAgentMetricGroup(); @@ -74,6 +78,10 @@ public interface RunnerContext { /** * Gets the individual metric group dedicated for each action. * + *

<p>The returned group must only be accessed from the operator/mailbox (action) thread, not + * from inside a {@link #durableExecute} or {@link #durableExecuteAsync} callable, which runs on + * a separate thread pool. + * * @return the individual metric group specific to the current action. */ FlinkAgentsMetricGroup getActionMetricGroup(); diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java similarity index 61% rename from api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java rename to api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java index 53c9bc6cc..8d98febff 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java @@ -18,73 +18,63 @@ package org.apache.flink.agents.api.chat.model; -import org.apache.flink.agents.api.chat.messages.ChatMessage; -import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; -import org.apache.flink.agents.api.tools.Tool; import org.apache.flink.metrics.Counter; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.function.BiFunction; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.Mockito.*; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static
org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; -/** Test cases for BaseChatModelConnection token metrics functionality. */ -class BaseChatModelConnectionTokenMetricsTest { +/** Test cases for BaseChatModelSetup token metrics functionality. */ +class BaseChatModelSetupTokenMetricsTest { - private TestChatModelConnection connection; + private TestChatModelSetup setup; private FlinkAgentsMetricGroup mockMetricGroup; private FlinkAgentsMetricGroup mockModelGroup; private Counter mockPromptTokensCounter; private Counter mockCompletionTokensCounter; - /** Test implementation of BaseChatModelConnection for testing purposes. */ - private static class TestChatModelConnection extends BaseChatModelConnection { + /** Test implementation of BaseChatModelSetup for testing purposes. */ + private static class TestChatModelSetup extends BaseChatModelSetup { - public TestChatModelConnection( + public TestChatModelSetup( ResourceDescriptor descriptor, BiFunction getResource) { super(descriptor, getResource); } @Override - public ChatMessage chat( - List messages, List tools, Map arguments) { - // Simple test implementation - return new ChatMessage(MessageRole.ASSISTANT, "Test response"); - } - - // Expose protected method for testing - public void testRecordTokenMetrics( - String modelName, long promptTokens, long completionTokens) { - recordTokenMetrics(modelName, promptTokens, completionTokens); + public Map getParameters() { + return Collections.emptyMap(); } } @BeforeEach void setUp() { - connection = - new TestChatModelConnection( + setup = + new TestChatModelSetup( new ResourceDescriptor( - TestChatModelConnection.class.getName(), Collections.emptyMap()), + TestChatModelSetup.class.getName(), Collections.emptyMap()), null); - // Create mock objects mockMetricGroup = 
mock(FlinkAgentsMetricGroup.class); mockModelGroup = mock(FlinkAgentsMetricGroup.class); mockPromptTokensCounter = mock(Counter.class); mockCompletionTokensCounter = mock(Counter.class); - // Set up mock behavior when(mockMetricGroup.getSubGroup("gpt-4")).thenReturn(mockModelGroup); when(mockModelGroup.getCounter("promptTokens")).thenReturn(mockPromptTokensCounter); when(mockModelGroup.getCounter("completionTokens")).thenReturn(mockCompletionTokensCounter); @@ -93,13 +83,10 @@ void setUp() { @Test @DisplayName("Test token metrics are recorded when metric group is set") void testRecordTokenMetricsWithMetricGroup() { - // Set the metric group - connection.setMetricGroup(mockMetricGroup); + setup.setMetricGroup(mockMetricGroup); - // Record token metrics - connection.testRecordTokenMetrics("gpt-4", 100, 50); + setup.recordTokenMetrics("gpt-4", 100, 50); - // Verify the metrics were recorded verify(mockMetricGroup).getSubGroup("gpt-4"); verify(mockModelGroup).getCounter("promptTokens"); verify(mockModelGroup).getCounter("completionTokens"); @@ -110,22 +97,16 @@ void testRecordTokenMetricsWithMetricGroup() { @Test @DisplayName("Test token metrics are not recorded when metric group is null") void testRecordTokenMetricsWithoutMetricGroup() { - // Do not set metric group (should be null by default) + assertDoesNotThrow(() -> setup.recordTokenMetrics("gpt-4", 100, 50)); - // Record token metrics - should not throw - assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4", 100, 50)); - - // No metrics should be recorded verifyNoInteractions(mockMetricGroup); } @Test - @DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName -> counters") + @DisplayName("Test token metrics hierarchy: metricGroup -> modelName -> counters") void testTokenMetricsHierarchy() { - // Set the metric group - connection.setMetricGroup(mockMetricGroup); + setup.setMetricGroup(mockMetricGroup); - // Record token metrics for different models FlinkAgentsMetricGroup 
mockGpt35Group = mock(FlinkAgentsMetricGroup.class); Counter mockGpt35PromptCounter = mock(Counter.class); Counter mockGpt35CompletionCounter = mock(Counter.class); @@ -134,13 +115,9 @@ void testTokenMetricsHierarchy() { when(mockGpt35Group.getCounter("promptTokens")).thenReturn(mockGpt35PromptCounter); when(mockGpt35Group.getCounter("completionTokens")).thenReturn(mockGpt35CompletionCounter); - // Record for gpt-4 - connection.testRecordTokenMetrics("gpt-4", 100, 50); - - // Record for gpt-3.5-turbo - connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100); + setup.recordTokenMetrics("gpt-4", 100, 50); + setup.recordTokenMetrics("gpt-3.5-turbo", 200, 100); - // Verify each model has its own counters verify(mockMetricGroup).getSubGroup("gpt-4"); verify(mockMetricGroup).getSubGroup("gpt-3.5-turbo"); verify(mockPromptTokensCounter).inc(100); @@ -150,8 +127,8 @@ void testTokenMetricsHierarchy() { } @Test - @DisplayName("Test resource type is CHAT_MODEL_CONNECTION") + @DisplayName("Test resource type is CHAT_MODEL") void testResourceType() { - assertEquals(ResourceType.CHAT_MODEL_CONNECTION, connection.getResourceType()); + assertEquals(ResourceType.CHAT_MODEL, setup.getResourceType()); } } diff --git a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java index 6dded957a..f713654ed 100644 --- a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java +++ b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java @@ -135,7 +135,7 @@ public ChatMessage chat( Message response = client.messages().create(params); ChatMessage result = convertResponse(response, 
jsonPrefillApplied); - // Record token metrics + // Stash token usage String modelName = null; if (arguments != null && arguments.get("model") != null) { modelName = arguments.get("model").toString(); @@ -144,8 +144,9 @@ public ChatMessage chat( modelName = this.defaultModel; } if (modelName != null && !modelName.isBlank()) { - recordTokenMetrics( - modelName, response.usage().inputTokens(), response.usage().outputTokens()); + result.getExtraArgs().put("model_name", modelName); + result.getExtraArgs().put("promptTokens", response.usage().inputTokens()); + result.getExtraArgs().put("completionTokens", response.usage().outputTokens()); } return result; diff --git a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java index f223ee553..7afa68054 100644 --- a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java +++ b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java @@ -189,12 +189,15 @@ public ChatMessage chat( chatMessage.setToolCalls(convertedToolCalls); } - // Record token metrics if model name is available + // Stash token usage if model name is available if (modelName != null && !modelName.isBlank()) { CompletionsUsage usage = completions.getUsage(); if (usage != null) { - recordTokenMetrics( - modelName, usage.getPromptTokens(), usage.getCompletionTokens()); + chatMessage.getExtraArgs().put("model_name", modelName); + chatMessage.getExtraArgs().put("promptTokens", (long) usage.getPromptTokens()); + chatMessage + .getExtraArgs() + .put("completionTokens", (long) usage.getCompletionTokens()); } } diff --git 
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java index 773069f4c..4071b51c0 100644 --- a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java +++ b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java @@ -227,13 +227,14 @@ public ChatMessage chat( chatMessage.setToolCalls(toolCalls); } - // Record token metrics if model name is available + // Stash token usage if model name is available if (modelName != null && !modelName.isBlank()) { Integer promptTokens = ollamaChatResponse.getPromptEvalCount(); Integer completionTokens = ollamaChatResponse.getEvalCount(); if (promptTokens != null && completionTokens != null) { - recordTokenMetrics( - modelName, promptTokens.longValue(), completionTokens.longValue()); + extraArgs.put("model_name", modelName); + extraArgs.put("promptTokens", promptTokens.longValue()); + extraArgs.put("completionTokens", completionTokens.longValue()); } } diff --git a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java index b04cd2b26..039963d3a 100644 --- a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java +++ b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatModelConnection.java @@ -144,17 +144,18 @@ public ChatMessage chat( ChatCompletion completion = client.chat().completions().create(params); ChatMessage response = 
convertResponse(completion); - // Record token metrics + // Stash token usage if (completion.usage().isPresent()) { String modelName = arguments != null ? (String) arguments.get("model") : null; if (modelName == null || modelName.isBlank()) { modelName = this.defaultModel; } if (modelName != null && !modelName.isBlank()) { - recordTokenMetrics( - modelName, - completion.usage().get().promptTokens(), - completion.usage().get().completionTokens()); + response.getExtraArgs().put("model_name", modelName); + response.getExtraArgs() + .put("promptTokens", completion.usage().get().promptTokens()); + response.getExtraArgs() + .put("completionTokens", completion.usage().get().completionTokens()); } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index cc73f2c62..276c232f9 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -129,6 +129,23 @@ private static Map getToolRequestEventContext( return (Map) toolRequestEventContext.remove(requestId); } + static void recordChatTokenMetrics(BaseChatModelSetup chatModel, ChatMessage response) { + Map extraArgs = response.getExtraArgs(); + Object modelName = extraArgs.get("model_name"); + Object promptTokens = extraArgs.get("promptTokens"); + Object completionTokens = extraArgs.get("completionTokens"); + if (modelName != null + && !modelName.toString().isEmpty() + && promptTokens instanceof Number + && completionTokens instanceof Number) { + long prompt = ((Number) promptTokens).longValue(); + long completion = ((Number) completionTokens).longValue(); + if (prompt > 0 && completion > 0) { + chatModel.recordTokenMetrics(modelName.toString(), prompt, completion); + } + } + } + private static void handleToolCalls( ChatMessage response, UUID initialRequestId, @@ -239,6 +256,7 @@ public ChatMessage 
call() throws Exception { chatAsync ? ctx.durableExecuteAsync(callable) : ctx.durableExecute(callable); + recordChatTokenMetrics(chatModel, response); // only generate structured output for final response. if (outputSchema != null && response.getToolCalls().isEmpty()) { response = generateStructuredOutput(response, outputSchema); diff --git a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java new file mode 100644 index 000000000..b4c7c23e2 --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.plan.actions; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +/** Tests for {@link ChatModelAction}. */ +class ChatModelActionTest { + + private static ChatMessage responseWith(Map extraArgs) { + return new ChatMessage(MessageRole.ASSISTANT, "response", extraArgs); + } + + @Test + void testRecordChatTokenMetricsRecordsWhenAllKeysPresent() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + Map extraArgs = new HashMap<>(); + extraArgs.put("model_name", "m"); + extraArgs.put("promptTokens", 100L); + extraArgs.put("completionTokens", 50L); + + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + + verify(setup).recordTokenMetrics("m", 100L, 50L); + } + + @Test + void testRecordChatTokenMetricsHandlesIntegerTokenValues() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + Map extraArgs = new HashMap<>(); + extraArgs.put("model_name", "m"); + extraArgs.put("promptTokens", 100); + extraArgs.put("completionTokens", 50); + + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + + verify(setup).recordTokenMetrics("m", 100L, 50L); + } + + @Test + void testRecordChatTokenMetricsSkipsWhenTokenValueNonNumeric() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + Map extraArgs = new HashMap<>(); + extraArgs.put("model_name", "m"); + extraArgs.put("promptTokens", "100"); + extraArgs.put("completionTokens", 50L); + + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + + 
verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong()); + } + + @Test + void testRecordChatTokenMetricsSkipsWhenKeyMissing() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + Map extraArgs = new HashMap<>(); + extraArgs.put("model_name", "m"); + extraArgs.put("completionTokens", 50L); + + ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs)); + + verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong()); + } + + @Test + void testRecordChatTokenMetricsSkipsZeroTokensOrEmptyModel() { + BaseChatModelSetup setup = mock(BaseChatModelSetup.class); + + Map zeroPrompt = new HashMap<>(); + zeroPrompt.put("model_name", "m"); + zeroPrompt.put("promptTokens", 0L); + zeroPrompt.put("completionTokens", 50L); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(zeroPrompt)); + + Map emptyModel = new HashMap<>(); + emptyModel.put("model_name", ""); + emptyModel.put("promptTokens", 100L); + emptyModel.put("completionTokens", 50L); + ChatModelAction.recordChatTokenMetrics(setup, responseWith(emptyModel)); + + verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), anyLong()); + } +}