Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,13 @@ public void open() throws Exception {
public abstract Map<String, Object> getParameters();

public ChatMessage chat(List<ChatMessage> messages) {
return this.chat(messages, Collections.emptyMap());
return this.chat(messages, Collections.emptyMap(), Collections.emptyMap());
}

public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) {
public ChatMessage chat(
List<ChatMessage> messages,
Map<String, Object> arguments,
Map<String, Object> parameters) {
Preconditions.checkNotNull(
connection,
"Connection is not initialized. Ensure open() is called before chat().");
Expand All @@ -124,15 +127,17 @@ public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> paramete
prompt instanceof Prompt,
"Prompt is not initialized. Ensure open() is called before chat().");
Prompt prompt = (Prompt) this.prompt;
Map<String, String> arguments = new HashMap<>();
for (ChatMessage message : messages) {
for (Map.Entry<String, Object> entry : message.getExtraArgs().entrySet()) {
arguments.put(entry.getKey(), entry.getValue().toString());
Map<String, String> stringified = new HashMap<>();
if (arguments != null) {
for (Map.Entry<String, Object> entry : arguments.entrySet()) {
stringified.put(
entry.getKey(),
entry.getValue() != null ? entry.getValue().toString() : "");
}
}

// append meaningful messages
List<ChatMessage> promptMessages = prompt.formatMessages(MessageRole.USER, arguments);
List<ChatMessage> promptMessages = prompt.formatMessages(MessageRole.USER, stringified);
for (ChatMessage message : messages) {
if ((message.getContent() != null && !message.getContent().isEmpty())
|| message.getRole() == MessageRole.ASSISTANT) {
Expand All @@ -150,7 +155,9 @@ public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> paramete
}

Map<String, Object> params = this.getParameters();
params.putAll(parameters);
if (parameters != null) {
params.putAll(parameters);
}
return connection.chat(messages, tools, params);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ public void open() {
}

@Override
public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) {
public ChatMessage chat(
List<ChatMessage> messages,
Map<String, Object> arguments,
Map<String, Object> parameters) {
checkState(
chatModelSetup != null,
"ChatModelSetup is not initialized. Cannot perform chat operation.");
Expand All @@ -75,6 +78,7 @@ public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> paramete
}

kwargs.put("messages", pythonMessages);
kwargs.put("arguments", arguments != null ? arguments : Collections.emptyMap());

Object pythonMessageResponse = adapter.callMethod(chatModelSetup, "chat", kwargs);
return adapter.fromPythonChatMessage(pythonMessageResponse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -39,17 +40,26 @@ public class ChatRequestEvent extends Event {
private static final ObjectMapper MAPPER = new ObjectMapper();

public ChatRequestEvent(
String model, List<ChatMessage> messages, @Nullable Object outputSchema) {
String model,
List<ChatMessage> messages,
@Nullable Map<String, Object> arguments,
@Nullable Object outputSchema) {
super(EVENT_TYPE);
setAttr("model", model);
setAttr("messages", new ArrayList<>(messages));
setAttr("arguments", arguments != null ? arguments : Collections.emptyMap());
if (outputSchema != null) {
setAttr("output_schema", outputSchema);
}
}

public ChatRequestEvent(
String model, List<ChatMessage> messages, @Nullable Object outputSchema) {
this(model, messages, null, outputSchema);
}

public ChatRequestEvent(String model, List<ChatMessage> messages) {
this(model, messages, null);
this(model, messages, null, null);
}

public ChatRequestEvent(UUID id, Map<String, Object> attributes) {
Expand Down Expand Up @@ -100,4 +110,11 @@ public List<ChatMessage> getMessages() {
public Object getOutputSchema() {
return getAttr("output_schema");
}

@JsonIgnore
@SuppressWarnings("unchecked")
public Map<String, Object> getArguments() {
Map<String, Object> args = (Map<String, Object>) getAttr("arguments");
return args != null ? args : Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.Tool;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -60,7 +62,10 @@ public Map<String, Object> getParameters() {
}

@Override
public ChatMessage chat(List<ChatMessage> messages, Map<String, Object> parameters) {
public ChatMessage chat(
List<ChatMessage> messages,
Map<String, Object> arguments,
Map<String, Object> parameters) {
// Simple test implementation that echoes the last user message

String lastUserContent = "";
Expand Down Expand Up @@ -229,6 +234,92 @@ void testChatResponseFormat() {
assertTrue(response.getContent().length() > 0);
}

/** Connection that captures the messages passed to it for assertions. */
private static class RecordingConnection extends BaseChatModelConnection {
List<ChatMessage> capturedMessages;

RecordingConnection() {
super(
new ResourceDescriptor(
RecordingConnection.class.getName(), Collections.emptyMap()),
null);
}

@Override
public ChatMessage chat(
List<ChatMessage> messages, List<Tool> tools, Map<String, Object> arguments) {
this.capturedMessages = new ArrayList<>(messages);
return new ChatMessage(MessageRole.ASSISTANT, "ok");
}
}

/** Subclass that exposes setters so we can inject the connection and prompt directly. */
private static class RecordingChatModelSetup extends BaseChatModelSetup {
RecordingChatModelSetup(BaseChatModelConnection connection, Prompt prompt) {
super(
new ResourceDescriptor(
RecordingChatModelSetup.class.getName(), Collections.emptyMap()),
null);
this.connection = connection;
this.prompt = prompt;
}

@Override
public Map<String, Object> getParameters() {
return new HashMap<>();
}
}

@Test
@DisplayName("chat() fills prompt template from arguments parameter")
void testChatFillsTemplateFromArgumentsParameter() {
RecordingConnection connection = new RecordingConnection();
Prompt prompt = Prompt.fromText("Task: {key}");
RecordingChatModelSetup setup = new RecordingChatModelSetup(connection, prompt);

setup.chat(Collections.emptyList(), Map.of("key", "value"), Map.of());

assertNotNull(connection.capturedMessages);
assertEquals(1, connection.capturedMessages.size());
assertEquals("Task: value", connection.capturedMessages.get(0).getContent());
}

@Test
@DisplayName("chat() does not read template vars from ChatMessage.extraArgs")
void testChatDoesNotReadTemplateVarsFromExtraArgs() {
RecordingConnection connection = new RecordingConnection();
Prompt prompt = Prompt.fromText("Task: {key}");
RecordingChatModelSetup setup = new RecordingChatModelSetup(connection, prompt);

ChatMessage userMessage =
new ChatMessage(MessageRole.USER, "hello", Map.of("key", "value"));
setup.chat(List.of(userMessage), Map.of(), Map.of());

assertNotNull(connection.capturedMessages);
assertEquals(2, connection.capturedMessages.size());
assertEquals("Task: {key}", connection.capturedMessages.get(0).getContent());
assertEquals("hello", connection.capturedMessages.get(1).getContent());
}

@Test
@DisplayName("chat() re-fills prompt template on subsequent invocations when args supplied")
void testChatRefillsTemplateOnSubsequentInvocations() {
RecordingConnection connection = new RecordingConnection();
Prompt prompt = Prompt.fromText("Task: {key}");
RecordingChatModelSetup setup = new RecordingChatModelSetup(connection, prompt);

setup.chat(Collections.emptyList(), Map.of("key", "v1"), Map.of());
assertNotNull(connection.capturedMessages);
assertEquals(1, connection.capturedMessages.size());
assertEquals("Task: v1", connection.capturedMessages.get(0).getContent());

ChatMessage toolResponse = new ChatMessage(MessageRole.TOOL, "tool result");
setup.chat(List.of(toolResponse), Map.of("key", "v1"), Map.of());
assertEquals(2, connection.capturedMessages.size());
assertEquals("Task: v1", connection.capturedMessages.get(0).getContent());
assertEquals("tool result", connection.capturedMessages.get(1).getContent());
}

@Test
@DisplayName("Test chat with long input")
void testChatWithLongInput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ void testChat() {
ChatMessage inputMessage = mock(ChatMessage.class);
ChatMessage outputMessage = mock(ChatMessage.class);
List<ChatMessage> messages = Collections.singletonList(inputMessage);
Map<String, Object> arguments = new HashMap<>();
arguments.put("input", "value");
Map<String, Object> parameters = new HashMap<>();
parameters.put("temperature", 0.7);
parameters.put("max_tokens", 100);
Expand All @@ -105,7 +107,7 @@ void testChat() {
.thenReturn(pythonOutputMessage);
when(mockAdapter.fromPythonChatMessage(pythonOutputMessage)).thenReturn(outputMessage);

ChatMessage result = pythonChatModelSetup.chat(messages, parameters);
ChatMessage result = pythonChatModelSetup.chat(messages, arguments, parameters);

assertThat(result).isEqualTo(outputMessage);

Expand All @@ -117,8 +119,10 @@ void testChat() {
argThat(
kwargs -> {
assertThat(kwargs).containsKey("messages");
assertThat(kwargs).containsKey("arguments");
assertThat(kwargs).containsKey("temperature");
assertThat(kwargs).containsKey("max_tokens");
assertThat(kwargs.get("arguments")).isEqualTo(arguments);
assertThat(kwargs.get("temperature")).isEqualTo(0.7);
assertThat(kwargs.get("max_tokens")).isEqualTo(100);
List<?> pythonMessages = (List<?>) kwargs.get("messages");
Expand All @@ -136,9 +140,10 @@ void testChatWithNullChatModelSetupThrowsException() {

ChatMessage inputMessage = mock(ChatMessage.class);
List<ChatMessage> messages = Collections.singletonList(inputMessage);
Map<String, Object> arguments = new HashMap<>();
Map<String, Object> parameters = new HashMap<>();

assertThatThrownBy(() -> setupWithNullModel.chat(messages, parameters))
assertThatThrownBy(() -> setupWithNullModel.chat(messages, arguments, parameters))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("ChatModelSetup is not initialized")
.hasMessageContaining("Cannot perform chat operation");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,11 @@ public static void processInput(Event event, RunnerContext ctx) throws Exception
"{\n\"id\": %s,\n\"score_histogram\": %s,\n\"unsatisfied_reasons\": %s\n}",
summary.getId(), summary.getScoreHist(), summary.getUnsatisfiedReasons());

ChatMessage msg = new ChatMessage(MessageRole.USER, "", Map.of("input", content));
ChatMessage msg = new ChatMessage(MessageRole.USER, "");

ctx.sendEvent(new ChatRequestEvent("generateSuggestionModel", List.of(msg)));
ctx.sendEvent(
new ChatRequestEvent(
"generateSuggestionModel", List.of(msg), Map.of("input", content), null));
}

/** Process chat response event. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,11 @@ public static void processInput(Event event, RunnerContext ctx) throws Exception
String.format(
"{\n" + "\"id\": %s,\n" + "\"review\": \"%s\"\n" + "}",
inputObj.getId(), inputObj.getReview());
ChatMessage msg = new ChatMessage(MessageRole.USER, "", Map.of("input", content));
ChatMessage msg = new ChatMessage(MessageRole.USER, "");

ctx.sendEvent(new ChatRequestEvent("reviewAnalysisModel", List.of(msg)));
ctx.sendEvent(
new ChatRequestEvent(
"reviewAnalysisModel", List.of(msg), Map.of("input", content), null));
}

@Action(listenEventTypes = {ChatResponseEvent.EVENT_TYPE})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ public static void processInput(Event event, RunnerContext ctx) throws Exception
String.format(
"{\n" + "\"id\": \"%s\",\n" + "\"review\": \"%s\"\n" + "}",
productId, reviewText);
ChatMessage msg = new ChatMessage(MessageRole.USER, "", Map.of("input", content));
ChatMessage msg = new ChatMessage(MessageRole.USER, "");

ctx.sendEvent(new ChatRequestEvent("reviewAnalysisModel", List.of(msg)));
ctx.sendEvent(
new ChatRequestEvent(
"reviewAnalysisModel", List.of(msg), Map.of("input", content), null));
}

@Action(listenEventTypes = {ChatResponseEvent.EVENT_TYPE})
Expand Down
Loading
Loading