Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package org.apache.flink.agents.api.chat.model;

import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
Expand Down Expand Up @@ -56,22 +55,4 @@ public ResourceType getResourceType() {
*/
public abstract ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object> arguments);

/**
 * Adds the given token counts to the per-model counters of the current metric group.
 *
 * <p>The counters live in a sub-group named after {@code modelName} and are called {@code
 * promptTokens} and {@code completionTokens}. When {@link #getMetricGroup()} returns {@code
 * null} the call is a no-op.
 *
 * @param modelName the name of the model used; also the name of the metric sub-group
 * @param promptTokens the number of prompt tokens to add
 * @param completionTokens the number of completion tokens to add
 */
protected void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) {
    final FlinkAgentsMetricGroup boundGroup = getMetricGroup();
    if (boundGroup != null) {
        // One sub-group per model keeps the usage of different models apart.
        final FlinkAgentsMetricGroup perModelGroup = boundGroup.getSubGroup(modelName);
        perModelGroup.getCounter("promptTokens").inc(promptTokens);
        perModelGroup.getCounter("completionTokens").inc(completionTokens);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.prompt.Prompt;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceContext;
Expand Down Expand Up @@ -107,6 +108,23 @@ public void open() throws Exception {

public abstract Map<String, Object> getParameters();

/**
 * Record token usage for a model against the metric group bound to this setup.
 *
 * <p>Counts are added to the {@code promptTokens} and {@code completionTokens} counters of the
 * sub-group named {@code modelName}. Nothing is recorded when no metric group is bound, i.e.
 * when {@link #getMetricGroup()} returns {@code null}.
 *
 * @param modelName the name of the model used; selects the metric sub-group
 * @param promptTokens the number of prompt tokens to add
 * @param completionTokens the number of completion tokens to add
 */
public void recordTokenMetrics(String modelName, long promptTokens, long completionTokens) {
    final FlinkAgentsMetricGroup root = getMetricGroup();
    if (root == null) {
        // No metric group bound: silently skip recording.
        return;
    }

    final FlinkAgentsMetricGroup modelScoped = root.getSubGroup(modelName);
    modelScoped.getCounter("promptTokens").inc(promptTokens);
    modelScoped.getCounter("completionTokens").inc(completionTokens);
}

/**
 * Convenience overload of {@code chat(messages, parameters)} that passes an empty parameter
 * map.
 *
 * @param messages the input messages
 * @return the message returned by the two-argument overload
 */
public ChatMessage chat(List<ChatMessage> messages) {
    final Map<String, Object> noParameters = Collections.emptyMap();
    return chat(messages, noParameters);
}
Expand All @@ -115,8 +133,6 @@ public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> paramete
Preconditions.checkNotNull(
connection,
"Connection is not initialized. Ensure open() is called before chat().");
// Pass metric group to connection for token usage tracking
connection.setMetricGroup(getMetricGroup());

// Format input messages if set prompt.
if (this.prompt != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,21 @@ public interface RunnerContext {
/**
* Gets the metric group for Flink Agents.
*
* <p>The returned group must only be accessed from the operator/mailbox (action) thread, not
* from inside a {@link #durableExecute} or {@link #durableExecuteAsync} callable, which runs on
* a separate thread pool.
*
* @return the metric group shared across all actions.
*/
FlinkAgentsMetricGroup getAgentMetricGroup();

/**
* Gets the individual metric group dedicated for each action.
*
* <p>The returned group must only be accessed from the operator/mailbox (action) thread, not
* from inside a {@link #durableExecute} or {@link #durableExecuteAsync} callable, which runs on
* a separate thread pool.
*
* @return the individual metric group specific to the current action.
*/
FlinkAgentsMetricGroup getActionMetricGroup();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,71 +18,60 @@

package org.apache.flink.agents.api.chat.model;

import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.Tool;
import org.apache.flink.metrics.Counter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

/** Test cases for BaseChatModelConnection token metrics functionality. */
class BaseChatModelConnectionTokenMetricsTest {
/** Test cases for BaseChatModelSetup token metrics functionality. */
class BaseChatModelSetupTokenMetricsTest {

private TestChatModelConnection connection;
private TestChatModelSetup setup;
private FlinkAgentsMetricGroup mockMetricGroup;
private FlinkAgentsMetricGroup mockModelGroup;
private Counter mockPromptTokensCounter;
private Counter mockCompletionTokensCounter;

/** Test implementation of BaseChatModelConnection for testing purposes. */
private static class TestChatModelConnection extends BaseChatModelConnection {
/** Test implementation of BaseChatModelSetup for testing purposes. */
private static class TestChatModelSetup extends BaseChatModelSetup {

public TestChatModelConnection(
ResourceDescriptor descriptor, ResourceContext resourceContext) {
public TestChatModelSetup(ResourceDescriptor descriptor, ResourceContext resourceContext) {
super(descriptor, resourceContext);
}

@Override
public ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object> arguments) {
// Simple test implementation
return new ChatMessage(MessageRole.ASSISTANT, "Test response");
}

// Expose protected method for testing
public void testRecordTokenMetrics(
String modelName, long promptTokens, long completionTokens) {
recordTokenMetrics(modelName, promptTokens, completionTokens);
public Map<String, Object> getParameters() {
return Collections.emptyMap();
}
}

@BeforeEach
void setUp() {
connection =
new TestChatModelConnection(
setup =
new TestChatModelSetup(
new ResourceDescriptor(
TestChatModelConnection.class.getName(), Collections.emptyMap()),
TestChatModelSetup.class.getName(), Collections.emptyMap()),
null);

// Create mock objects
mockMetricGroup = mock(FlinkAgentsMetricGroup.class);
mockModelGroup = mock(FlinkAgentsMetricGroup.class);
mockPromptTokensCounter = mock(Counter.class);
mockCompletionTokensCounter = mock(Counter.class);

// Set up mock behavior
when(mockMetricGroup.getSubGroup("gpt-4")).thenReturn(mockModelGroup);
when(mockModelGroup.getCounter("promptTokens")).thenReturn(mockPromptTokensCounter);
when(mockModelGroup.getCounter("completionTokens")).thenReturn(mockCompletionTokensCounter);
Expand All @@ -91,13 +80,10 @@ void setUp() {
@Test
@DisplayName("Test token metrics are recorded when metric group is set")
void testRecordTokenMetricsWithMetricGroup() {
// Set the metric group
connection.setMetricGroup(mockMetricGroup);
setup.setMetricGroup(mockMetricGroup);

// Record token metrics
connection.testRecordTokenMetrics("gpt-4", 100, 50);
setup.recordTokenMetrics("gpt-4", 100, 50);

// Verify the metrics were recorded
verify(mockMetricGroup).getSubGroup("gpt-4");
verify(mockModelGroup).getCounter("promptTokens");
verify(mockModelGroup).getCounter("completionTokens");
Expand All @@ -108,22 +94,16 @@ void testRecordTokenMetricsWithMetricGroup() {
@Test
@DisplayName("Test token metrics are not recorded when metric group is null")
void testRecordTokenMetricsWithoutMetricGroup() {
// Do not set metric group (should be null by default)
assertDoesNotThrow(() -> setup.recordTokenMetrics("gpt-4", 100, 50));

// Record token metrics - should not throw
assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4", 100, 50));

// No metrics should be recorded
verifyNoInteractions(mockMetricGroup);
}

@Test
@DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName -> counters")
@DisplayName("Test token metrics hierarchy: metricGroup -> modelName -> counters")
void testTokenMetricsHierarchy() {
// Set the metric group
connection.setMetricGroup(mockMetricGroup);
setup.setMetricGroup(mockMetricGroup);

// Record token metrics for different models
FlinkAgentsMetricGroup mockGpt35Group = mock(FlinkAgentsMetricGroup.class);
Counter mockGpt35PromptCounter = mock(Counter.class);
Counter mockGpt35CompletionCounter = mock(Counter.class);
Expand All @@ -132,13 +112,9 @@ void testTokenMetricsHierarchy() {
when(mockGpt35Group.getCounter("promptTokens")).thenReturn(mockGpt35PromptCounter);
when(mockGpt35Group.getCounter("completionTokens")).thenReturn(mockGpt35CompletionCounter);

// Record for gpt-4
connection.testRecordTokenMetrics("gpt-4", 100, 50);

// Record for gpt-3.5-turbo
connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100);
setup.recordTokenMetrics("gpt-4", 100, 50);
setup.recordTokenMetrics("gpt-3.5-turbo", 200, 100);

// Verify each model has its own counters
verify(mockMetricGroup).getSubGroup("gpt-4");
verify(mockMetricGroup).getSubGroup("gpt-3.5-turbo");
verify(mockPromptTokensCounter).inc(100);
Expand All @@ -148,8 +124,8 @@ void testTokenMetricsHierarchy() {
}

@Test
@DisplayName("Test resource type is CHAT_MODEL_CONNECTION")
@DisplayName("Test resource type is CHAT_MODEL")
void testResourceType() {
assertEquals(ResourceType.CHAT_MODEL_CONNECTION, connection.getResourceType());
assertEquals(ResourceType.CHAT_MODEL, setup.getResourceType());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public ChatMessage chat(
Message response = client.messages().create(params);
ChatMessage result = convertResponse(response, jsonPrefillApplied);

// Record token metrics
// Stash token usage
String modelName = null;
if (arguments != null && arguments.get("model") != null) {
modelName = arguments.get("model").toString();
Expand All @@ -142,8 +142,9 @@ public ChatMessage chat(
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
recordTokenMetrics(
modelName, response.usage().inputTokens(), response.usage().outputTokens());
result.getExtraArgs().put("model_name", modelName);
result.getExtraArgs().put("promptTokens", response.usage().inputTokens());
result.getExtraArgs().put("completionTokens", response.usage().outputTokens());
}

return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,15 @@ public ChatMessage chat(
chatMessage.setToolCalls(convertedToolCalls);
}

// Record token metrics if model name is available
// Stash token usage if model name is available
if (modelName != null && !modelName.isBlank()) {
CompletionsUsage usage = completions.getUsage();
if (usage != null) {
recordTokenMetrics(
modelName, usage.getPromptTokens(), usage.getCompletionTokens());
chatMessage.getExtraArgs().put("model_name", modelName);
chatMessage.getExtraArgs().put("promptTokens", (long) usage.getPromptTokens());
chatMessage
.getExtraArgs()
.put("completionTokens", (long) usage.getCompletionTokens());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,14 @@ public ChatMessage chat(
ConverseResponse response =
retryExecutor.execute(() -> client.converse(request), "BedrockConverse");

ChatMessage result = convertResponse(response);
if (response.usage() != null) {
recordTokenMetrics(
modelId, response.usage().inputTokens(), response.usage().outputTokens());
result.getExtraArgs().put("model_name", modelId);
result.getExtraArgs().put("promptTokens", response.usage().inputTokens().longValue());
result.getExtraArgs()
.put("completionTokens", response.usage().outputTokens().longValue());
}

return convertResponse(response);
return result;
}

private static boolean isRetryable(Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,14 @@ public ChatMessage chat(
chatMessage.setToolCalls(toolCalls);
}

// Record token metrics if model name is available
// Stash token usage if model name is available
if (modelName != null && !modelName.isBlank()) {
Integer promptTokens = ollamaChatResponse.getPromptEvalCount();
Integer completionTokens = ollamaChatResponse.getEvalCount();
if (promptTokens != null && completionTokens != null) {
recordTokenMetrics(
modelName, promptTokens.longValue(), completionTokens.longValue());
extraArgs.put("model_name", modelName);
extraArgs.put("promptTokens", promptTokens.longValue());
extraArgs.put("completionTokens", completionTokens.longValue());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,11 @@ public ChatMessage chat(
if (modelOfAzureDeployment != null
&& !modelOfAzureDeployment.isBlank()
&& completion.usage().isPresent()) {
recordTokenMetrics(
modelOfAzureDeployment,
completion.usage().get().promptTokens(),
completion.usage().get().completionTokens());
response.getExtraArgs().put("model_name", modelOfAzureDeployment);
response.getExtraArgs()
.put("promptTokens", completion.usage().get().promptTokens());
response.getExtraArgs()
.put("completionTokens", completion.usage().get().completionTokens());
}

return response;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,17 +129,18 @@ public ChatMessage chat(
OpenAIChatCompletionsUtils.convertFromOpenAIMessage(
completion.choices().get(0).message());

// Record token metrics
// Stash token usage
if (completion.usage().isPresent()) {
String modelName = arguments != null ? (String) arguments.get("model") : null;
if (modelName == null || modelName.isBlank()) {
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
recordTokenMetrics(
modelName,
completion.usage().get().promptTokens(),
completion.usage().get().completionTokens());
response.getExtraArgs().put("model_name", modelName);
response.getExtraArgs()
.put("promptTokens", completion.usage().get().promptTokens());
response.getExtraArgs()
.put("completionTokens", completion.usage().get().completionTokens());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ public ChatMessage chat(
modelName = this.defaultModel;
}
if (modelName != null && !modelName.isBlank()) {
recordTokenMetrics(
modelName,
response.usage().get().inputTokens(),
response.usage().get().outputTokens());
result.getExtraArgs().put("model_name", modelName);
result.getExtraArgs().put("promptTokens", response.usage().get().inputTokens());
result.getExtraArgs()
.put("completionTokens", response.usage().get().outputTokens());
}
}

Expand Down
Loading
Loading