From c9717aa16ef74fc940c67e170b642bea6ecfa601 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Thu, 21 May 2026 19:07:27 -0700 Subject: [PATCH 1/3] [api][plan][runtime] Separate prompt arguments from message extra_args in BaseChatModelSetup.chat() `BaseChatModelSetup.chat()` previously filled prompt templates by flattening every input message's `extra_args` into a single map. This conflated chat metadata with template variables. Introduce an explicit `arguments` parameter on `chat()` (Java + Python) and carry the same field on `ChatRequestEvent`, then thread it through `ChatModelAction` to the setup on both round 1 and tool-response continuations so multi-turn flows keep re-filling the template correctly. `ChatMessage.extra_args` is unchanged and still carries `externalId`, `STRUCTURED_OUTPUT`, OpenAI `refusal`, Ollama `reasoning`, and other provider-specific metadata used by chat-model connections. Closes #220 --- .../api/chat/model/BaseChatModelSetup.java | 23 +++-- .../model/python/PythonChatModelSetup.java | 6 +- .../agents/api/event/ChatRequestEvent.java | 21 ++++- .../api/chat/model/BaseChatModelTest.java | 93 ++++++++++++++++++- .../python/PythonChatModelSetupTest.java | 9 +- .../agents/ProductSuggestionAgent.java | 6 +- .../examples/agents/ReviewAnalysisAgent.java | 6 +- .../agents/TableReviewAnalysisAgent.java | 6 +- .../agents/plan/actions/ChatModelAction.java | 30 +++++- .../actions/ChatModelActionRetryTest.java | 59 +++++++++++- .../api/chat_models/chat_model.py | 27 +++--- .../chat_models/tests/test_chat_model_base.py | 86 ++++++++++++++++- python/flink_agents/api/events/chat_event.py | 14 ++- .../built_in_action_async_execution_test.py | 7 +- .../e2e_tests_mcp/mcp_test.py | 13 ++- .../agents/product_suggestion_agent.py | 5 +- .../agents/review_analysis_agent.py | 10 +- .../agents/table_review_analysis_agent.py | 10 +- .../plan/actions/chat_model_action.py | 17 +++- .../actions/test_chat_model_action_retry.py | 73 ++++++++++++++- .../runtime/java/java_chat_model.py | 17 +++- .../runtime/tests/test_built_in_actions.py | 24 ++--- 22 files changed, 484 insertions(+), 78 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java index af7ed10b2..efd9294d6 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java @@ -108,10 +108,13 @@ public void open() throws Exception { public abstract Map getParameters(); public ChatMessage chat(List messages) { - return this.chat(messages, Collections.emptyMap()); + return this.chat(messages, Collections.emptyMap(), Collections.emptyMap()); } - public ChatMessage chat(List messages, Map parameters) { + public ChatMessage chat( + List messages, + Map arguments, + Map parameters) { Preconditions.checkNotNull( connection, "Connection is not initialized. Ensure open() is called before chat()."); @@ -124,15 +127,17 @@ public ChatMessage chat(List messages, Map paramete prompt instanceof Prompt, "Prompt is not initialized. Ensure open() is called before chat()."); Prompt prompt = (Prompt) this.prompt; - Map arguments = new HashMap<>(); - for (ChatMessage message : messages) { - for (Map.Entry entry : message.getExtraArgs().entrySet()) { - arguments.put(entry.getKey(), entry.getValue().toString()); + Map stringified = new HashMap<>(); + if (arguments != null) { + for (Map.Entry entry : arguments.entrySet()) { + stringified.put( + entry.getKey(), + entry.getValue() != null ? entry.getValue().toString() : ""); } } // append meaningful messages - List promptMessages = prompt.formatMessages(MessageRole.USER, arguments); + List promptMessages = prompt.formatMessages(MessageRole.USER, stringified); for (ChatMessage message : messages) { if ((message.getContent() != null && !message.getContent().isEmpty()) || message.getRole() == MessageRole.ASSISTANT) { @@ -150,7 +155,9 @@ public ChatMessage chat(List messages, Map paramete } Map params = this.getParameters(); - params.putAll(parameters); + if (parameters != null) { + params.putAll(parameters); + } return connection.chat(messages, tools, params); } diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java index a15cede18..0c70354f4 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetup.java @@ -62,7 +62,10 @@ public void open() { } @Override - public ChatMessage chat(List messages, Map parameters) { + public ChatMessage chat( + List messages, + Map arguments, + Map parameters) { checkState( chatModelSetup != null, "ChatModelSetup is not initialized. Cannot perform chat operation."); @@ -75,6 +78,7 @@ public ChatMessage chat(List messages, Map paramete } kwargs.put("messages", pythonMessages); + kwargs.put("arguments", arguments != null ? arguments : Collections.emptyMap()); Object pythonMessageResponse = adapter.callMethod(chatModelSetup, "chat", kwargs); return adapter.fromPythonChatMessage(pythonMessageResponse); diff --git a/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java index b1cdcdbc2..7e82f1c65 100644 --- a/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java +++ b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java @@ -26,6 +26,7 @@ import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -39,17 +40,26 @@ public class ChatRequestEvent extends Event { private static final ObjectMapper MAPPER = new ObjectMapper(); public ChatRequestEvent( - String model, List messages, @Nullable Object outputSchema) { + String model, + List messages, + @Nullable Map arguments, + @Nullable Object outputSchema) { super(EVENT_TYPE); setAttr("model", model); setAttr("messages", new ArrayList<>(messages)); + setAttr("arguments", arguments != null ? arguments : Collections.emptyMap()); if (outputSchema != null) { setAttr("output_schema", outputSchema); } } + public ChatRequestEvent( + String model, List messages, @Nullable Object outputSchema) { + this(model, messages, null, outputSchema); + } + public ChatRequestEvent(String model, List messages) { - this(model, messages, null); + this(model, messages, null, null); } public ChatRequestEvent(UUID id, Map attributes) { @@ -100,4 +110,11 @@ public List getMessages() { public Object getOutputSchema() { return getAttr("output_schema"); } + + @JsonIgnore + @SuppressWarnings("unchecked") + public Map getArguments() { + Map args = (Map) getAttr("arguments"); + return args != null ? args : Collections.emptyMap(); + } } diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java index 61c9f823a..2f84b36b1 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java @@ -24,10 +24,12 @@ import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.Tool; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -60,7 +62,10 @@ public Map getParameters() { } @Override - public ChatMessage chat(List messages, Map parameters) { + public ChatMessage chat( + List messages, + Map arguments, + Map parameters) { // Simple test implementation that echoes the last user message String lastUserContent = ""; @@ -229,6 +234,92 @@ void testChatResponseFormat() { assertTrue(response.getContent().length() > 0); } + /** Connection that captures the messages passed to it for assertions. */ + private static class RecordingConnection extends BaseChatModelConnection { + List capturedMessages; + + RecordingConnection() { + super( + new ResourceDescriptor( + RecordingConnection.class.getName(), Collections.emptyMap()), + null); + } + + @Override + public ChatMessage chat( + List messages, List tools, Map arguments) { + this.capturedMessages = new ArrayList<>(messages); + return new ChatMessage(MessageRole.ASSISTANT, "ok"); + } + } + + /** Subclass that exposes setters so we can inject the connection and prompt directly. */ + private static class RecordingChatModelSetup extends BaseChatModelSetup { + RecordingChatModelSetup(BaseChatModelConnection connection, Prompt prompt) { + super( + new ResourceDescriptor( + RecordingChatModelSetup.class.getName(), Collections.emptyMap()), + null); + this.connection = connection; + this.prompt = prompt; + } + + @Override + public Map getParameters() { + return new HashMap<>(); + } + } + + @Test + @DisplayName("chat() fills prompt template from arguments parameter") + void testChatFillsTemplateFromArgumentsParameter() { + RecordingConnection connection = new RecordingConnection(); + Prompt prompt = Prompt.fromText("Task: {key}"); + RecordingChatModelSetup setup = new RecordingChatModelSetup(connection, prompt); + + setup.chat(Collections.emptyList(), Map.of("key", "value"), Map.of()); + + assertNotNull(connection.capturedMessages); + assertEquals(1, connection.capturedMessages.size()); + assertEquals("Task: value", connection.capturedMessages.get(0).getContent()); + } + + @Test + @DisplayName("chat() does not read template vars from ChatMessage.extraArgs") + void testChatDoesNotReadTemplateVarsFromExtraArgs() { + RecordingConnection connection = new RecordingConnection(); + Prompt prompt = Prompt.fromText("Task: {key}"); + RecordingChatModelSetup setup = new RecordingChatModelSetup(connection, prompt); + + ChatMessage userMessage = + new ChatMessage(MessageRole.USER, "hello", Map.of("key", "value")); + setup.chat(List.of(userMessage), Map.of(), Map.of()); + + assertNotNull(connection.capturedMessages); + assertEquals(2, connection.capturedMessages.size()); + assertEquals("Task: {key}", connection.capturedMessages.get(0).getContent()); + assertEquals("hello", connection.capturedMessages.get(1).getContent()); + } + + @Test + @DisplayName("chat() re-fills prompt template on subsequent invocations when args supplied") + void testChatRefillsTemplateOnSubsequentInvocations() { + RecordingConnection connection = new RecordingConnection(); + Prompt prompt = Prompt.fromText("Task: {key}"); + RecordingChatModelSetup setup = new RecordingChatModelSetup(connection, prompt); + + setup.chat(Collections.emptyList(), Map.of("key", "v1"), Map.of()); + assertNotNull(connection.capturedMessages); + assertEquals(1, connection.capturedMessages.size()); + assertEquals("Task: v1", connection.capturedMessages.get(0).getContent()); + + ChatMessage toolResponse = new ChatMessage(MessageRole.TOOL, "tool result"); + setup.chat(List.of(toolResponse), Map.of("key", "v1"), Map.of()); + assertEquals(2, connection.capturedMessages.size()); + assertEquals("Task: v1", connection.capturedMessages.get(0).getContent()); + assertEquals("tool result", connection.capturedMessages.get(1).getContent()); + } + @Test @DisplayName("Test chat with long input") void testChatWithLongInput() { diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java index 3e8b1ebb5..5fe9a58d6 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/python/PythonChatModelSetupTest.java @@ -93,6 +93,8 @@ void testChat() { ChatMessage inputMessage = mock(ChatMessage.class); ChatMessage outputMessage = mock(ChatMessage.class); List messages = Collections.singletonList(inputMessage); + Map arguments = new HashMap<>(); + arguments.put("input", "value"); Map parameters = new HashMap<>(); parameters.put("temperature", 0.7); parameters.put("max_tokens", 100); @@ -105,7 +107,7 @@ void testChat() { .thenReturn(pythonOutputMessage); when(mockAdapter.fromPythonChatMessage(pythonOutputMessage)).thenReturn(outputMessage); - ChatMessage result = pythonChatModelSetup.chat(messages, parameters); + ChatMessage result = pythonChatModelSetup.chat(messages, arguments, parameters); assertThat(result).isEqualTo(outputMessage); @@ -117,8 +119,10 @@ void testChat() { argThat( kwargs -> { assertThat(kwargs).containsKey("messages"); + assertThat(kwargs).containsKey("arguments"); assertThat(kwargs).containsKey("temperature"); assertThat(kwargs).containsKey("max_tokens"); + assertThat(kwargs.get("arguments")).isEqualTo(arguments); assertThat(kwargs.get("temperature")).isEqualTo(0.7); assertThat(kwargs.get("max_tokens")).isEqualTo(100); List pythonMessages = (List) kwargs.get("messages"); @@ -136,9 +140,10 @@ void testChatWithNullChatModelSetupThrowsException() { ChatMessage inputMessage = mock(ChatMessage.class); List messages = Collections.singletonList(inputMessage); + Map arguments = new HashMap<>(); Map parameters = new HashMap<>(); - assertThatThrownBy(() -> setupWithNullModel.chat(messages, parameters)) + assertThatThrownBy(() -> setupWithNullModel.chat(messages, arguments, parameters)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining("ChatModelSetup is not initialized") .hasMessageContaining("Cannot perform chat operation"); diff --git a/examples/src/main/java/org/apache/flink/agents/examples/agents/ProductSuggestionAgent.java b/examples/src/main/java/org/apache/flink/agents/examples/agents/ProductSuggestionAgent.java index bd816f2a1..303ef4209 100644 --- a/examples/src/main/java/org/apache/flink/agents/examples/agents/ProductSuggestionAgent.java +++ b/examples/src/main/java/org/apache/flink/agents/examples/agents/ProductSuggestionAgent.java @@ -87,9 +87,11 @@ public static void processInput(Event event, RunnerContext ctx) throws Exception "{\n\"id\": %s,\n\"score_histogram\": %s,\n\"unsatisfied_reasons\": %s\n}", summary.getId(), summary.getScoreHist(), summary.getUnsatisfiedReasons()); - ChatMessage msg = new ChatMessage(MessageRole.USER, "", Map.of("input", content)); + ChatMessage msg = new ChatMessage(MessageRole.USER, ""); - ctx.sendEvent(new ChatRequestEvent("generateSuggestionModel", List.of(msg))); + ctx.sendEvent( + new ChatRequestEvent( + "generateSuggestionModel", List.of(msg), Map.of("input", content), null)); } /** Process chat response event. */ diff --git a/examples/src/main/java/org/apache/flink/agents/examples/agents/ReviewAnalysisAgent.java b/examples/src/main/java/org/apache/flink/agents/examples/agents/ReviewAnalysisAgent.java index bbcba5434..05dff401b 100644 --- a/examples/src/main/java/org/apache/flink/agents/examples/agents/ReviewAnalysisAgent.java +++ b/examples/src/main/java/org/apache/flink/agents/examples/agents/ReviewAnalysisAgent.java @@ -101,9 +101,11 @@ public static void processInput(Event event, RunnerContext ctx) throws Exception String.format( "{\n" + "\"id\": %s,\n" + "\"review\": \"%s\"\n" + "}", inputObj.getId(), inputObj.getReview()); - ChatMessage msg = new ChatMessage(MessageRole.USER, "", Map.of("input", content)); + ChatMessage msg = new ChatMessage(MessageRole.USER, ""); - ctx.sendEvent(new ChatRequestEvent("reviewAnalysisModel", List.of(msg))); + ctx.sendEvent( + new ChatRequestEvent( + "reviewAnalysisModel", List.of(msg), Map.of("input", content), null)); } @Action(listenEventTypes = {ChatResponseEvent.EVENT_TYPE}) diff --git a/examples/src/main/java/org/apache/flink/agents/examples/agents/TableReviewAnalysisAgent.java b/examples/src/main/java/org/apache/flink/agents/examples/agents/TableReviewAnalysisAgent.java index 4f49f2f75..77b314f4e 100644 --- a/examples/src/main/java/org/apache/flink/agents/examples/agents/TableReviewAnalysisAgent.java +++ b/examples/src/main/java/org/apache/flink/agents/examples/agents/TableReviewAnalysisAgent.java @@ -121,9 +121,11 @@ public static void processInput(Event event, RunnerContext ctx) throws Exception String.format( "{\n" + "\"id\": \"%s\",\n" + "\"review\": \"%s\"\n" + "}", productId, reviewText); - ChatMessage msg = new ChatMessage(MessageRole.USER, "", Map.of("input", content)); + ChatMessage msg = new ChatMessage(MessageRole.USER, ""); - ctx.sendEvent(new ChatRequestEvent("reviewAnalysisModel", List.of(msg))); + ctx.sendEvent( + new ChatRequestEvent( + "reviewAnalysisModel", List.of(msg), Map.of("input", content), null)); } @Action(listenEventTypes = {ChatResponseEvent.EVENT_TYPE}) diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index 997fb28b9..b2a7b6bbe 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -60,6 +60,7 @@ public class ChatModelAction { private static final String INITIAL_REQUEST_ID = "initialRequestId"; private static final String MODEL = "model"; private static final String OUTPUT_SCHEMA = "outputSchema"; + private static final String ARGUMENTS = "arguments"; private static final String RETRY_STATS_CONTEXT = "_RETRY_STATS_CONTEXT"; private static final String TOTAL_RETRY_COUNT = "totalRetryCount"; private static final String TOTAL_RETRY_WAIT_SEC = "totalRetryWaitSec"; @@ -108,6 +109,7 @@ private static void saveToolRequestEventContext( UUID toolRequestEventId, UUID initialRequestId, String model, + Map arguments, Object outputSchema) throws Exception { Map toolRequestEventContext; @@ -120,6 +122,7 @@ private static void saveToolRequestEventContext( Map context = new HashMap<>(); context.put(INITIAL_REQUEST_ID, initialRequestId); context.put(MODEL, model); + context.put(ARGUMENTS, arguments != null ? arguments : Collections.emptyMap()); if (outputSchema != null) { context.put(OUTPUT_SCHEMA, outputSchema); } @@ -185,6 +188,7 @@ private static void handleToolCalls( String model, BaseChatModelSetup chatModel, List messages, + Map arguments, Object outputSchema, RunnerContext ctx) throws Exception { @@ -203,6 +207,7 @@ private static void handleToolCalls( toolRequestEvent.getId(), initialRequestId, model, + arguments, outputSchema); ctx.sendEvent(toolRequestEvent); @@ -294,6 +299,7 @@ public static void chat( UUID initialRequestId, String model, List messages, + Map arguments, @Nullable Object outputSchema, RunnerContext ctx) throws Exception { @@ -339,7 +345,7 @@ public Class getResultClass() { @Override public ChatMessage call() throws Exception { - return chatModel.chat(messages, Map.of()); + return chatModel.chat(messages, arguments, Map.of()); } }; @@ -393,7 +399,14 @@ public ChatMessage call() throws Exception { if (!Objects.requireNonNull(response).getToolCalls().isEmpty()) { handleToolCalls( - response, initialRequestId, model, chatModel, messages, outputSchema, ctx); + response, + initialRequestId, + model, + chatModel, + messages, + arguments, + outputSchema, + ctx); } else { Map retryStats = getRetryStats(ctx.getSensoryMemory(), initialRequestId); int totalRetryCount = retryStats.get(TOTAL_RETRY_COUNT).intValue(); @@ -410,9 +423,16 @@ public ChatMessage call() throws Exception { private static void processChatRequest(ChatRequestEvent event, RunnerContext ctx) throws Exception { - chat(event.getId(), event.getModel(), event.getMessages(), event.getOutputSchema(), ctx); + chat( + event.getId(), + event.getModel(), + event.getMessages(), + event.getArguments(), + event.getOutputSchema(), + ctx); } + @SuppressWarnings("unchecked") private static void processToolResponse(ToolResponseEvent event, RunnerContext ctx) throws Exception { MemoryObject sensoryMem = ctx.getSensoryMemory(); @@ -422,6 +442,8 @@ private static void processToolResponse(ToolResponseEvent event, RunnerContext c UUID initialRequestId = (UUID) context.get(INITIAL_REQUEST_ID); String model = (String) context.get(MODEL); + Map arguments = + (Map) context.getOrDefault(ARGUMENTS, Map.of()); Object outputSchema = context.get(OUTPUT_SCHEMA); Map responses = event.getResponses(); @@ -455,7 +477,7 @@ private static void processToolResponse(ToolResponseEvent event, RunnerContext c Collections.emptyList(), toolResponseMessages); - chat(initialRequestId, model, messages, outputSchema, ctx); + chat(initialRequestId, model, messages, arguments, outputSchema, ctx); } /** diff --git a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java index 706bec282..a94698e5a 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionRetryTest.java @@ -27,12 +27,15 @@ import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.event.ChatResponseEvent; +import org.apache.flink.agents.api.event.ToolResponseEvent; import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.ToolResponse; import org.apache.flink.metrics.Counter; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -102,7 +105,7 @@ void tearDown() throws Exception { @Test void chatSucceedsWithoutRetry_retryCountIsZero() throws Exception { configureRetryStrategy(3, 1); - when(mockChatModel.chat(any(), any())) + when(mockChatModel.chat(any(), any(), any())) .thenReturn(new ChatMessage(MessageRole.ASSISTANT, "hello")); UUID requestId = UUID.randomUUID(); @@ -110,6 +113,7 @@ void chatSucceedsWithoutRetry_retryCountIsZero() throws Exception { requestId, "test-model", List.of(new ChatMessage(MessageRole.USER, "hi")), + Map.of(), null, mockCtx); @@ -128,7 +132,7 @@ void chatRetriesWithExponentialBackoff() throws Exception { configureRetryStrategy(3, 1); AtomicInteger callCount = new AtomicInteger(0); - when(mockChatModel.chat(any(), any())) + when(mockChatModel.chat(any(), any(), any())) .thenAnswer( inv -> { int count = callCount.incrementAndGet(); @@ -145,6 +149,7 @@ void chatRetriesWithExponentialBackoff() throws Exception { requestId, "test-model", List.of(new ChatMessage(MessageRole.USER, "hi")), + Map.of(), null, mockCtx); long elapsed = System.currentTimeMillis() - startTime; @@ -167,7 +172,8 @@ void chatRetriesWithExponentialBackoff() throws Exception { void chatExhaustsRetriesAndThrows() { configureRetryStrategy(2, 0); - when(mockChatModel.chat(any(), any())).thenThrow(new RuntimeException("persistent error")); + when(mockChatModel.chat(any(), any(), any())) + .thenThrow(new RuntimeException("persistent error")); UUID requestId = UUID.randomUUID(); @@ -177,6 +183,7 @@ void chatExhaustsRetriesAndThrows() { requestId, "test-model", List.of(new ChatMessage(MessageRole.USER, "hi")), + Map.of(), null, mockCtx)) .isInstanceOf(RuntimeException.class) @@ -211,6 +218,52 @@ void retryWaitIntervalDefaultValue() { assertThat(AgentExecutionOptions.RETRY_WAIT_INTERVAL.getDefaultValue()).isEqualTo(1); } + @Test + void processToolResponseForwardsSavedArgumentsToChat() throws Exception { + configureRetryStrategy(0, 0); + + UUID initialRequestId = UUID.randomUUID(); + UUID toolRequestEventId = UUID.randomUUID(); + String toolCallId = "call-1"; + Map savedArguments = Map.of("k", "v"); + + // Pre-seed sensory memory with the tool-request-event context that + // processToolResponse will look up. This simulates a prior chat round + // that produced a tool call. + Map toolRequestEventContext = new HashMap<>(); + Map contextEntry = new HashMap<>(); + contextEntry.put("initialRequestId", initialRequestId); + contextEntry.put("model", "test-model"); + contextEntry.put("arguments", savedArguments); + toolRequestEventContext.put(toolRequestEventId, contextEntry); + sensoryMemory.set("_TOOL_REQUEST_EVENT_CONTEXT", toolRequestEventContext); + + // Pre-seed the tool-call context with the initial messages so + // updateToolCallContext can extend them with the tool response. + Map toolCallContext = new HashMap<>(); + toolCallContext.put( + initialRequestId, + new ArrayList<>(List.of(new ChatMessage(MessageRole.USER, "hi")))); + sensoryMemory.set("_TOOL_CALL_CONTEXT", toolCallContext); + + when(mockChatModel.chat(any(), any(), any())) + .thenReturn(new ChatMessage(MessageRole.ASSISTANT, "done")); + + ToolResponseEvent toolResponseEvent = + new ToolResponseEvent( + toolRequestEventId, + Map.of(toolCallId, ToolResponse.success("42")), + Map.of(toolCallId, true), + Map.of()); + + ChatModelAction.processChatRequestOrToolResponse(toolResponseEvent, mockCtx); + + @SuppressWarnings("unchecked") + ArgumentCaptor> argumentsCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockChatModel).chat(any(), argumentsCaptor.capture(), any()); + assertThat(argumentsCaptor.getValue()).isEqualTo(savedArguments); + } + // --- Helper methods --- private void configureRetryStrategy(int maxRetries, int waitIntervalSec) { diff --git a/python/flink_agents/api/chat_models/chat_model.py b/python/flink_agents/api/chat_models/chat_model.py index 960559793..82da6ae86 100644 --- a/python/flink_agents/api/chat_models/chat_model.py +++ b/python/flink_agents/api/chat_models/chat_model.py @@ -17,7 +17,7 @@ ################################################################################# import re from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, List, Sequence, Tuple, cast +from typing import Any, ClassVar, Dict, List, Mapping, Sequence, Tuple, cast from pydantic import Field, PrivateAttr from typing_extensions import override @@ -197,10 +197,15 @@ def open(self) -> None: for tool_name in self.tools ] - def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: + def chat( + self, + messages: Sequence[ChatMessage], + arguments: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> ChatMessage: """Execute chat conversation. - 1. Apply prompt template (if any) + 1. Apply prompt template (if any), filled from ``arguments`` 2. Bind tools (if any) 3. Call ChatModelConnection to perform actual communication 4. Process response @@ -209,6 +214,10 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: ---------- messages : Sequence[ChatMessage] Input message sequence + arguments : Mapping[str, Any] | None + Variables used to fill the prompt template, if a prompt resource is + configured. Values are stringified via ``str()`` to match the + ``Prompt.format_messages`` contract. **kwargs : Any Additional parameters passed to the model service @@ -219,14 +228,10 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: """ # Apply prompt template if self.prompt is not None: - input_variable = {} - - # fill the prompt template - for msg in messages: - # Convert Any values to str to match format_messages signature - str_extra_args = {k: str(v) for k, v in msg.extra_args.items()} - input_variable.update(str_extra_args) - prompt_messages = self._get_prompt().format_messages(**input_variable) + str_arguments: Dict[str, str] = ( + {k: str(v) for k, v in arguments.items()} if arguments else {} + ) + prompt_messages = self._get_prompt().format_messages(**str_arguments) # append meaningful messages for msg in messages: diff --git a/python/flink_agents/api/chat_models/tests/test_chat_model_base.py b/python/flink_agents/api/chat_models/tests/test_chat_model_base.py index 651061f49..fae5c644e 100644 --- a/python/flink_agents/api/chat_models/tests/test_chat_model_base.py +++ b/python/flink_agents/api/chat_models/tests/test_chat_model_base.py @@ -15,12 +15,18 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Any, Dict +from typing import Any, Dict, List, Sequence import pytest -from pydantic import ValidationError +from pydantic import Field, ValidationError -from flink_agents.api.chat_models.chat_model import BaseChatModelSetup +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.chat_models.chat_model import ( + BaseChatModelConnection, + BaseChatModelSetup, +) +from flink_agents.api.prompts.prompt import Prompt +from flink_agents.api.tools.tool import Tool class _MinimalChatModelSetup(BaseChatModelSetup): @@ -35,6 +41,38 @@ def model_kwargs(self) -> Dict[str, Any]: return {"model": self.model} +class _RecordingConnection(BaseChatModelConnection): + """Connection that captures the messages it receives for inspection.""" + + captured_messages: List[ChatMessage] = Field(default_factory=list) + + def chat( + self, + messages: Sequence[ChatMessage], + tools: List[Tool] | None = None, + **kwargs: Any, + ) -> ChatMessage: + self.captured_messages = list(messages) + return ChatMessage(role=MessageRole.ASSISTANT, content="ok") + + +class _RecordingChatModelSetup(BaseChatModelSetup): + """Subclass that lets tests inject a connection without calling open().""" + + @property + def model_kwargs(self) -> Dict[str, Any]: + return {} + + +def _build_setup( + prompt: Prompt, +) -> tuple[_RecordingChatModelSetup, _RecordingConnection]: + setup = _RecordingChatModelSetup(connection="c", model="m", prompt=prompt) + connection = _RecordingConnection() + setup._resolved_connection = connection + return setup, connection + + def test_inherits_model_field_from_base() -> None: """A subclass that omits `model` still exposes it via inheritance.""" setup = _MinimalChatModelSetup(connection="c", model="m1") @@ -45,3 +83,45 @@ def test_missing_model_raises_validation_error() -> None: """Constructing without `model` must raise a Pydantic ValidationError.""" with pytest.raises(ValidationError): _MinimalChatModelSetup(connection="c") + + +def test_chat_fills_template_from_arguments_parameter() -> None: + """chat() fills the prompt template from the `arguments` parameter.""" + prompt = Prompt.from_text(text="Task: {key}") + setup, connection = _build_setup(prompt) + + setup.chat([], arguments={"key": "value"}) + + assert len(connection.captured_messages) == 1 + assert connection.captured_messages[0].content == "Task: value" + + +def test_chat_does_not_read_template_vars_from_extra_args() -> None: + """chat() must not read template variables from ChatMessage.extra_args.""" + prompt = Prompt.from_text(text="Task: {key}") + setup, connection = _build_setup(prompt) + + user_message = ChatMessage( + role=MessageRole.USER, content="hello", extra_args={"key": "value"} + ) + setup.chat([user_message], arguments={}) + + assert len(connection.captured_messages) == 2 + assert connection.captured_messages[0].content == "Task: {key}" + assert connection.captured_messages[1].content == "hello" + + +def test_chat_refills_template_on_subsequent_invocations() -> None: + """Each chat() invocation must re-fill the prompt template from the args.""" + prompt = Prompt.from_text(text="Task: {key}") + setup, connection = _build_setup(prompt) + + setup.chat([], arguments={"key": "v1"}) + assert len(connection.captured_messages) == 1 + assert connection.captured_messages[0].content == "Task: v1" + + tool_response = ChatMessage(role=MessageRole.TOOL, content="tool result") + setup.chat([tool_response], arguments={"key": "v1"}) + assert len(connection.captured_messages) == 2 + assert connection.captured_messages[0].content == "Task: v1" + assert connection.captured_messages[1].content == "tool result" diff --git a/python/flink_agents/api/events/chat_event.py b/python/flink_agents/api/events/chat_event.py index 3423f70b4..182df6e0b 100644 --- a/python/flink_agents/api/events/chat_event.py +++ b/python/flink_agents/api/events/chat_event.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import ClassVar, List +from typing import Any, ClassVar, Dict, List try: from typing import override @@ -37,6 +37,9 @@ class ChatRequestEvent(Event): The name of the chat model to be chatted with. messages : List[ChatMessage] The input to the chat model. + arguments : Dict[str, Any] + Variables used to fill the chat model's prompt template, if a prompt + resource is configured on the chat model setup. Empty by default. output_schema: OutputSchema | None The expected output schema of the chat model final response. Optional. """ @@ -47,6 +50,7 @@ def __init__( self, model: str, messages: List[ChatMessage], + arguments: Dict[str, Any] | None = None, output_schema: OutputSchema | None = None, ) -> None: """Create a ChatRequestEvent.""" @@ -55,6 +59,7 @@ def __init__( attributes={ "model": model, "messages": messages, + "arguments": arguments if arguments is not None else {}, "output_schema": output_schema, }, ) @@ -75,6 +80,7 @@ def from_event(cls, event: Event) -> "ChatRequestEvent": return ChatRequestEvent( model=event.attributes["model"], messages=messages, + arguments=event.attributes.get("arguments"), output_schema=output_schema_raw, ) @@ -88,6 +94,12 @@ def messages(self) -> List[ChatMessage]: """Return the chat messages.""" return self.get_attr("messages") + @property + def arguments(self) -> Dict[str, Any]: + """Return the prompt-template arguments, empty if not set.""" + args = self.get_attr("arguments") + return args if args is not None else {} + @property def output_schema(self) -> OutputSchema | None: """Return the expected output schema, if any.""" diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py index f5fefd725..f81625aa6 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py @@ -91,11 +91,8 @@ def process_input(event: Event, ctx: RunnerContext) -> None: ctx.send_event( ChatRequestEvent( model="slow_chat_model", - messages=[ - ChatMessage( - role=MessageRole.USER, content=input, extra_args={"task": input} - ) - ], + messages=[ChatMessage(role=MessageRole.USER, content=input)], + arguments={"task": input}, ) ) diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_test.py index ea74ae00f..0b2f1fb56 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/e2e_tests_mcp/mcp_test.py @@ -120,9 +120,13 @@ def process_input(event: Event, ctx: RunnerContext) -> None: if mcp_mode == "with_prompts": # Send chat request with MCP prompt variables # The prompt template will be filled with a and b values - msg = ChatMessage( - role=MessageRole.USER, - extra_args={"a": str(input_data.a), "b": str(input_data.b)}, + msg = ChatMessage(role=MessageRole.USER) + ctx.send_event( + ChatRequestEvent( + model="math_chat_model", + messages=[msg], + arguments={"a": str(input_data.a), "b": str(input_data.b)}, + ) ) else: # Send chat request asking to use the add tool @@ -130,8 +134,7 @@ def process_input(event: Event, ctx: RunnerContext) -> None: role=MessageRole.USER, content=f"Please use the add tool to calculate the sum of {input_data.a} and {input_data.b}.", ) - - ctx.send_event(ChatRequestEvent(model="math_chat_model", messages=[msg])) + ctx.send_event(ChatRequestEvent(model="math_chat_model", messages=[msg])) @action(ChatResponseEvent.EVENT_TYPE) @staticmethod diff --git a/python/flink_agents/examples/quickstart/agents/product_suggestion_agent.py b/python/flink_agents/examples/quickstart/agents/product_suggestion_agent.py index 80ddcc7bf..2ace17dd1 100644 --- a/python/flink_agents/examples/quickstart/agents/product_suggestion_agent.py +++ b/python/flink_agents/examples/quickstart/agents/product_suggestion_agent.py @@ -83,9 +83,8 @@ def process_input(event: Event, ctx: RunnerContext) -> None: ctx.send_event( ChatRequestEvent( model="generate_suggestion_model", - messages=[ - ChatMessage(role=MessageRole.USER, extra_args={"input": content}) - ], + messages=[ChatMessage(role=MessageRole.USER)], + arguments={"input": content}, ) ) diff --git a/python/flink_agents/examples/quickstart/agents/review_analysis_agent.py b/python/flink_agents/examples/quickstart/agents/review_analysis_agent.py index 831b2a720..1a941fe72 100644 --- a/python/flink_agents/examples/quickstart/agents/review_analysis_agent.py +++ b/python/flink_agents/examples/quickstart/agents/review_analysis_agent.py @@ -95,8 +95,14 @@ def process_input(event: Event, ctx: RunnerContext) -> None: "id": {input.id}, "review": {input.review} """ - msg = ChatMessage(role=MessageRole.USER, extra_args={"input": content}) - ctx.send_event(ChatRequestEvent(model="review_analysis_model", messages=[msg])) + msg = ChatMessage(role=MessageRole.USER) + ctx.send_event( + ChatRequestEvent( + model="review_analysis_model", + messages=[msg], + arguments={"input": content}, + ) + ) @action(ChatResponseEvent.EVENT_TYPE) @staticmethod diff --git a/python/flink_agents/examples/quickstart/agents/table_review_analysis_agent.py b/python/flink_agents/examples/quickstart/agents/table_review_analysis_agent.py index 4c81cda92..cdda3d7ae 100644 --- a/python/flink_agents/examples/quickstart/agents/table_review_analysis_agent.py +++ b/python/flink_agents/examples/quickstart/agents/table_review_analysis_agent.py @@ -127,8 +127,14 @@ def process_input(event: Event, ctx: RunnerContext) -> None: "id": {product_id}, "review": {review_text} """ - msg = ChatMessage(role=MessageRole.USER, extra_args={"input": content}) - ctx.send_event(ChatRequestEvent(model="review_analysis_model", messages=[msg])) + msg = ChatMessage(role=MessageRole.USER) + ctx.send_event( + ChatRequestEvent( + model="review_analysis_model", + messages=[msg], + arguments={"input": content}, + ) + ) @action(ChatResponseEvent.EVENT_TYPE) @staticmethod diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index daff073dc..1251a4dbe 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -51,6 +51,7 @@ _TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT" _TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT" _RETRY_STATS_CONTEXT = "_RETRY_STATS_CONTEXT" +_ARGUMENTS = "arguments" _logger = logging.getLogger(__name__) @@ -93,6 +94,7 @@ def _save_tool_request_event_context( tool_request_event_id: UUID, initial_request_id: UUID, model: str, + arguments: Dict | None, output_schema: OutputSchema | None, ) -> None: """Save the context for a specific tool request event.""" @@ -100,6 +102,7 @@ def _save_tool_request_event_context( context[str(tool_request_event_id)] = { "initial_request_id": initial_request_id, "model": model, + _ARGUMENTS: arguments if arguments is not None else {}, "output_schema": output_schema, } sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context) @@ -193,6 +196,7 @@ def _handle_tool_calls( model: str, chat_model: "BaseChatModelSetup", messages: List[ChatMessage], + arguments: Dict | None, output_schema: OutputSchema | None, ctx: RunnerContext, ) -> None: @@ -214,6 +218,7 @@ def _handle_tool_calls( tool_request_event.id, initial_request_id, model, + arguments, output_schema, ) @@ -251,6 +256,7 @@ async def chat( initial_request_id: UUID, model: str, messages: List[ChatMessage], + arguments: Dict | None, output_schema: OutputSchema | None, ctx: RunnerContext, ) -> None: @@ -290,9 +296,13 @@ async def chat( for attempt in range(num_retries + 1): try: if chat_async: - response = await ctx.durable_execute_async(chat_model.chat, messages) + response = await ctx.durable_execute_async( + chat_model.chat, messages, arguments=arguments + ) else: - response = ctx.durable_execute(chat_model.chat, messages) + response = ctx.durable_execute( + chat_model.chat, messages, arguments=arguments + ) if ( response.extra_args.get("model_name") @@ -351,6 +361,7 @@ async def chat( model, chat_model, messages, + arguments, output_schema, ctx, ) @@ -379,6 +390,7 @@ async def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> initial_request_id=event.id, model=event.model, messages=event.messages, + arguments=event.arguments, output_schema=event.output_schema, ctx=ctx, ) @@ -416,6 +428,7 @@ async def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) - initial_request_id=initial_request_id, model=tool_request_event_context["model"], messages=messages, + arguments=tool_request_event_context.get(_ARGUMENTS, {}), output_schema=tool_request_event_context["output_schema"], ctx=ctx, ) diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py index 922a98b3a..95de9c2ee 100644 --- a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py +++ b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py @@ -31,8 +31,12 @@ ErrorHandlingStrategy, ) from flink_agents.api.events.chat_event import ChatResponseEvent +from flink_agents.api.events.tool_event import ToolResponseEvent from flink_agents.api.metric_group import Counter, MetricGroup -from flink_agents.plan.actions.chat_model_action import chat +from flink_agents.plan.actions.chat_model_action import ( + chat, + process_chat_request_or_tool_response, +) # ============================================================================ # Mock infrastructure @@ -157,6 +161,7 @@ def test_chat_succeeds_without_retry(self) -> None: request_id, chat_model.connection, [ChatMessage(role=MessageRole.USER, content="hi")], + {}, None, ctx, ) @@ -197,6 +202,7 @@ def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: request_id, "test-model", [ChatMessage(role=MessageRole.USER, content="hi")], + {}, None, ctx, ) @@ -232,6 +238,7 @@ def test_chat_exhausts_retries_and_raises(self) -> None: request_id, "test-model", [ChatMessage(role=MessageRole.USER, content="hi")], + {}, None, ctx, ) @@ -270,3 +277,67 @@ class TestRetryWaitIntervalConfig: def test_default_value(self) -> None: """Default value is 1 second.""" assert AgentExecutionOptions.RETRY_WAIT_INTERVAL.get_default_value() == 1 + + +class TestProcessToolResponseArgumentsForwarding: + """Locks the contract that `_process_tool_response` forwards the saved + `arguments` from the tool-request-event context into the round-2 call + to `chat_model.chat(...)`. + """ + + def test_forwards_saved_arguments_to_chat(self) -> None: + initial_request_id = uuid4() + tool_request_event_id = uuid4() + tool_call_id = "call-1" + saved_arguments = {"k": "v"} + + captured_arguments: list[dict] = [] + + def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: + captured_arguments.append(kwargs.get("arguments")) + return ChatMessage(role=MessageRole.ASSISTANT, content="done") + + chat_model = MagicMock() + chat_model.chat = mock_chat + + ctx, sent_events, _, sensory_memory = _create_mock_runner_context( + chat_model, max_retries=0, retry_wait_interval_sec=0 + ) + + # Pre-seed the tool-request-event context with saved arguments so + # _process_tool_response can look them up. + sensory_memory.set( + "_TOOL_REQUEST_EVENT_CONTEXT", + { + str(tool_request_event_id): { + "initial_request_id": initial_request_id, + "model": "test-model", + "arguments": saved_arguments, + "output_schema": None, + } + }, + ) + + # Pre-seed the tool-call context with prior messages so + # _update_tool_call_context can extend them with the tool response. + sensory_memory.set( + "_TOOL_CALL_CONTEXT", + { + str(initial_request_id): [ + ChatMessage(role=MessageRole.USER, content="hi") + ] + }, + ) + + tool_response_event = ToolResponseEvent( + request_id=tool_request_event_id, + responses={tool_call_id: "42"}, + external_ids={}, + ) + + asyncio.run(process_chat_request_or_tool_response(tool_response_event, ctx)) + + assert len(captured_arguments) == 1 + assert captured_arguments[0] == saved_arguments + assert len(sent_events) == 1 + assert isinstance(sent_events[0], ChatResponseEvent) diff --git a/python/flink_agents/runtime/java/java_chat_model.py b/python/flink_agents/runtime/java/java_chat_model.py index 7160d5315..4609d16c1 100644 --- a/python/flink_agents/runtime/java/java_chat_model.py +++ b/python/flink_agents/runtime/java/java_chat_model.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Any, Dict, List, Sequence +from typing import Any, Dict, List, Mapping, Sequence from typing_extensions import override @@ -130,17 +130,24 @@ def open(self) -> None: self._j_resource.open() @override - def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: + def chat( + self, + messages: Sequence[ChatMessage], + arguments: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> ChatMessage: """Execute chat conversation by delegating to Java implementation. 1. Convert Python messages to Java format - 2. Call Java chat method + 2. Call Java chat method with prompt-template arguments 3. Convert Java response back to Python format Parameters ---------- messages : Sequence[ChatMessage] Input message sequence + arguments : Mapping[str, Any] | None + Prompt-template variables forwarded to the Java setup. **kwargs : Any Additional parameters passed to the model service @@ -154,7 +161,9 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: self._j_resource_adapter.fromPythonChatMessage(message) for message in messages ] - j_response_message = self._j_resource.chat(java_messages, kwargs) + j_response_message = self._j_resource.chat( + java_messages, arguments or {}, kwargs + ) # Convert Java response back to Python format from flink_agents.runtime.python_java_utils import ( diff --git a/python/flink_agents/runtime/tests/test_built_in_actions.py b/python/flink_agents/runtime/tests/test_built_in_actions.py index 05a918813..72d0ebc01 100644 --- a/python/flink_agents/runtime/tests/test_built_in_actions.py +++ b/python/flink_agents/runtime/tests/test_built_in_actions.py @@ -81,7 +81,12 @@ def model_kwargs(self) -> Dict[str, Any]: """Return model kwargs.""" return {} - def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: + def chat( + self, + messages: Sequence[ChatMessage], + arguments: Dict[str, Any] | None = None, + **kwargs: Any, + ) -> ChatMessage: """Execute chat conversation.""" # Get model connection server = self.resource_context.get_resource( @@ -99,12 +104,10 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: prompt = self.prompt if "sum" in messages[-1].content: - input_variable = {} - for msg in messages: - # Convert Any values to str to match format_messages signature - str_extra_args = {k: str(v) for k, v in msg.extra_args.items()} - input_variable.update(str_extra_args) - messages = prompt.format_messages(**input_variable) + str_arguments = ( + {k: str(v) for k, v in arguments.items()} if arguments else {} + ) + messages = prompt.format_messages(**str_arguments) # Bind tools tools = None @@ -179,11 +182,8 @@ def process_input(event: Event, ctx: RunnerContext) -> None: ctx.send_event( ChatRequestEvent( model="mock_chat_model", - messages=[ - ChatMessage( - role=MessageRole.USER, content=input, extra_args={"task": input} - ) - ], + messages=[ChatMessage(role=MessageRole.USER, content=input)], + arguments={"task": input}, ) ) From 9d7e438eec47c00ec0030e2604ef5add24ec05ef Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Thu, 21 May 2026 23:28:35 -0700 Subject: [PATCH 2/3] [ci] Re-trigger CI (flaky it-python on python-3.12+flink-2.1) From d64ff36469ad7f837bbc9afc4652ebc0cdda1790 Mon Sep 17 00:00:00 2001 From: Weiqing Yang Date: Sat, 23 May 2026 22:48:16 -0700 Subject: [PATCH 3/3] [ci] Re-trigger CI (flaky cross-language ollama-cpu timeout)