diff --git a/api/src/main/java/org/apache/flink/agents/api/Event.java b/api/src/main/java/org/apache/flink/agents/api/Event.java index 7bc006257..e7fbde464 100644 --- a/api/src/main/java/org/apache/flink/agents/api/Event.java +++ b/api/src/main/java/org/apache/flink/agents/api/Event.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.api; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; @@ -88,14 +89,17 @@ public void setAttr(String name, Object value) { attributes.put(name, value); } + @JsonIgnore public boolean hasSourceTimestamp() { return sourceTimestamp != null; } + @JsonIgnore public Long getSourceTimestamp() { return sourceTimestamp; } + @JsonIgnore public void setSourceTimestamp(long timestamp) { this.sourceTimestamp = timestamp; } diff --git a/api/src/main/java/org/apache/flink/agents/api/annotation/Action.java b/api/src/main/java/org/apache/flink/agents/api/annotation/Action.java index 99218739d..fd7f62c92 100644 --- a/api/src/main/java/org/apache/flink/agents/api/annotation/Action.java +++ b/api/src/main/java/org/apache/flink/agents/api/annotation/Action.java @@ -42,6 +42,19 @@ * @Action(listenEventTypes = {InputEvent.EVENT_TYPE, "MyCustomEvent"}) * public void handleMultiple(Event event, RunnerContext ctx) { ... } * } + * + *

For a cross-language action, set {@link #target()} to a {@link PythonFunction} with a + * non-empty {@code module}. The annotated Java body is never invoked — throw {@link + * UnsupportedOperationException} so direct calls outside the framework fail loud: + * + *

{@code
+ * @Action(
+ *     listenEventTypes = {InputEvent.EVENT_TYPE},
+ *     target = @PythonFunction(module = "my_pkg.handlers", qualname = "handle_input"))
+ * public void handleInput(Event event, RunnerContext ctx) {
+ *     throw new UnsupportedOperationException("cross-language stub");
+ * }
+ * }
*/ @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) @@ -52,4 +65,11 @@ * @return Array of event type strings */ String[] listenEventTypes(); + + /** + * Cross-language target. When {@link PythonFunction#module()} is non-empty, dispatch routes to + * the Python target and the annotated Java body is unused. Default (empty {@code module}) keeps + * the action native Java. + */ + PythonFunction target() default @PythonFunction; } diff --git a/api/src/main/java/org/apache/flink/agents/api/annotation/PythonFunction.java b/api/src/main/java/org/apache/flink/agents/api/annotation/PythonFunction.java new file mode 100644 index 000000000..424b1077f --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/annotation/PythonFunction.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.annotation; + +/** + * Python target descriptor used inside {@link Action#target()}. Empty {@link #module()} = no + * cross-language target (action stays native Java). When non-empty, the Java method body is never + * invoked — throw {@link UnsupportedOperationException} from the stub so direct calls outside the + * framework fail loud. + */ +public @interface PythonFunction { + String module() default ""; + + String qualname() default ""; +} diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java index c3be5ef84..a15f220ce 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/messages/ChatMessage.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.api.chat.messages; import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; import java.util.HashMap; @@ -34,7 +35,11 @@ public class ChatMessage { private MessageRole role; private String content; + + @JsonProperty("tool_calls") private List> toolCalls; + + @JsonProperty("extra_args") private Map extraArgs; /** Default constructor with SYSTEM role */ @@ -83,18 +88,22 @@ public void setContent(String content) { this.content = content; } + @JsonProperty("tool_calls") public List> getToolCalls() { return toolCalls; } + @JsonProperty("tool_calls") public void setToolCalls(List> toolCalls) { this.toolCalls = toolCalls; } + @JsonProperty("extra_args") public Map getExtraArgs() { return extraArgs; } + @JsonProperty("extra_args") public void setExtraArgs(Map extraArgs) { this.extraArgs = extraArgs != null ? extraArgs : new HashMap<>(); } @@ -104,6 +113,7 @@ public String getText() { return this.content; } + @JsonIgnore public Map getMetadata() { return this.extraArgs; } diff --git a/api/src/main/java/org/apache/flink/agents/api/chat/messages/MessageRole.java b/api/src/main/java/org/apache/flink/agents/api/chat/messages/MessageRole.java index b992b9acc..9dc31002a 100644 --- a/api/src/main/java/org/apache/flink/agents/api/chat/messages/MessageRole.java +++ b/api/src/main/java/org/apache/flink/agents/api/chat/messages/MessageRole.java @@ -18,6 +18,9 @@ package org.apache.flink.agents.api.chat.messages; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + /** * Enumeration of message roles in a chat conversation. Each role represents a different participant * type in the conversation. @@ -42,6 +45,7 @@ public enum MessageRole { this.value = value; } + @JsonCreator public static MessageRole fromValue(String value) { for (MessageRole messageRole : MessageRole.values()) { if (messageRole.getValue().equals(value)) { @@ -51,6 +55,7 @@ public static MessageRole fromValue(String value) { throw new IllegalArgumentException("Invalid MessageRole value: " + value); } + @JsonValue public String getValue() { return this.value; } diff --git a/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java b/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java index 3e2c91aef..498ea3f07 100644 --- a/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java +++ b/api/src/main/java/org/apache/flink/agents/api/event/ToolResponseEvent.java @@ -83,6 +83,8 @@ public static ToolResponseEvent fromEvent(Event event) { responses.put(entry.getKey(), (ToolResponse) v); } else if (v instanceof Map) { responses.put(entry.getKey(), MAPPER.convertValue(v, ToolResponse.class)); + } else { + responses.put(entry.getKey(), ToolResponse.success(v)); } } attrs.put("responses", responses); diff --git a/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java b/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java index 0eab339c6..f4f2d2dd2 100644 --- a/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java +++ b/api/src/main/java/org/apache/flink/agents/api/tools/ToolResponse.java @@ -18,6 +18,10 @@ package org.apache.flink.agents.api.tools; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; + import java.util.Objects; /** @@ -29,11 +33,20 @@ public class ToolResponse { private final Object result; private final boolean success; private final String error; + + @JsonProperty("execution_time_ms") private final long executionTimeMs; + + @JsonProperty("tool_name") private final String toolName; + @JsonCreator private ToolResponse( - Object result, boolean success, String error, long executionTimeMs, String toolName) { + @JsonProperty("result") Object result, + @JsonProperty("success") boolean success, + @JsonProperty("error") String error, + @JsonProperty("execution_time_ms") long executionTimeMs, + @JsonProperty("tool_name") String toolName) { this.result = result; this.success = success; this.error = error; @@ -148,6 +161,7 @@ public String getToolName() { } /** Get the result as a string representation. */ + @JsonIgnore public String getResultAsString() { if (result == null) { return null; diff --git a/api/src/main/java/org/apache/flink/agents/api/vectorstores/Document.java b/api/src/main/java/org/apache/flink/agents/api/vectorstores/Document.java index fc1d1d407..689df530d 100644 --- a/api/src/main/java/org/apache/flink/agents/api/vectorstores/Document.java +++ b/api/src/main/java/org/apache/flink/agents/api/vectorstores/Document.java @@ -18,6 +18,9 @@ package org.apache.flink.agents.api.vectorstores; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + import javax.annotation.Nullable; import java.util.Arrays; @@ -61,12 +64,13 @@ public Document( this(content, metadata, id, embedding, null); } + @JsonCreator public Document( - String content, - Map metadata, - String id, - @Nullable float[] embedding, - @Nullable Float score) { + @JsonProperty("content") String content, + @JsonProperty("metadata") Map metadata, + @JsonProperty("id") String id, + @JsonProperty("embedding") @Nullable float[] embedding, + @JsonProperty("score") @Nullable Float score) { this.content = content; this.metadata = metadata; this.id = id; diff --git a/api/src/test/java/org/apache/flink/agents/api/CrossLanguageEventSnapshotTest.java b/api/src/test/java/org/apache/flink/agents/api/CrossLanguageEventSnapshotTest.java new file mode 100644 index 000000000..0398040fd --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/CrossLanguageEventSnapshotTest.java @@ -0,0 +1,520 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.agents.OutputSchema; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.event.ChatRequestEvent; +import org.apache.flink.agents.api.event.ChatResponseEvent; +import org.apache.flink.agents.api.event.ContextRetrievalRequestEvent; +import org.apache.flink.agents.api.event.ContextRetrievalResponseEvent; +import org.apache.flink.agents.api.event.ToolRequestEvent; +import org.apache.flink.agents.api.event.ToolResponseEvent; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** Cross-language event SerDe snapshot tests. */ +class CrossLanguageEventSnapshotTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private static final UUID FIXED_EVENT_ID = + UUID.fromString("00000000-0000-0000-0000-000000000001"); + private static final UUID FIXED_REQUEST_ID = + UUID.fromString("00000000-0000-0000-0000-000000000002"); + private static final String FIXED_TOOL_CALL_ID = "call_aaaa"; + private static final String FIXED_TOOL_CALL_ID_NUMERIC = "call_bbbb"; + private static final String FIXED_TOOL_CALL_ID_BOOL = "call_cccc"; + private static final long FIXED_TIMESTAMP = 1_700_000_000_000L; + + private static Path snapshotDir; + + @BeforeAll + static void resolveSnapshotDir() { + Path repoRoot = Paths.get(System.getProperty("user.dir")).getParent(); + snapshotDir = repoRoot.resolve("e2e-test/cross-language-event-snapshots"); + } + + // ── Helpers ──────────────────────────────────────────────────────────── + + private static boolean regenerateRequested() { + return Boolean.parseBoolean(System.getProperty("regenerate.snapshots", "false")); + } + + private static void writeJavaSnapshot(String fileName, Event event) throws Exception { + String json = MAPPER.writerWithDefaultPrettyPrinter().writeValueAsString(event); + Path target = snapshotDir.resolve("java/" + fileName); + Files.createDirectories(target.getParent()); + Files.writeString(target, json + "\n"); + } + + private static void assertJavaSnapshotStable(String fileName, Event event) throws Exception { + String actualJson = MAPPER.writeValueAsString(event); + JsonNode actual = MAPPER.readTree(actualJson); + + Path committed = snapshotDir.resolve("java/" + fileName); + assertTrue( + Files.exists(committed), + "Java snapshot " + + fileName + + " missing from " + + committed + + ". If you added a new event, regenerate with -Dregenerate.snapshots=true and commit alongside the test."); + JsonNode expected = MAPPER.readTree(Files.readString(committed)); + + assertEquals( + expected, + actual, + "Java serialization of " + + fileName + + " drifted from committed snapshot; if intentional, regenerate."); + } + + private static Event readPythonSnapshot(String fileName) throws Exception { + Path pythonSnapshot = snapshotDir.resolve("python/" + fileName); + assertTrue( + Files.exists(pythonSnapshot), + "Python snapshot " + + fileName + + " missing from " + + pythonSnapshot + + ". Regenerate the Python side with REGENERATE_SNAPSHOTS=1 and commit alongside this test."); + return Event.fromJson(Files.readString(pythonSnapshot)); + } + + // ── InputEvent ───────────────────────────────────────────────────────── + + private static InputEvent buildInputEvent() { + Map attrs = new HashMap<>(); + attrs.put("input", "hello"); + return new InputEvent(FIXED_EVENT_ID, attrs); + } + + @Test + void regenerateInputEventJavaSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + writeJavaSnapshot("input_event.json", buildInputEvent()); + } + + @Test + void inputEventJavaSnapshotIsStable() throws Exception { + assertJavaSnapshotStable("input_event.json", buildInputEvent()); + } + + @Test + void javaCanDeserializeInputEventFromPythonSnapshot() throws Exception { + Event base = readPythonSnapshot("input_event.json"); + InputEvent typed = InputEvent.fromEvent(base); + + assertEquals( + FIXED_EVENT_ID, typed.getId(), "ID lost when deserializing Python InputEvent."); + assertEquals(InputEvent.EVENT_TYPE, typed.getType()); + assertEquals("hello", typed.getInput(), "InputEvent.input mismatch."); + } + + // ── OutputEvent ──────────────────────────────────────────────────────── + + private static OutputEvent buildOutputEvent() { + Map attrs = new HashMap<>(); + attrs.put("output", "world"); + return new OutputEvent(FIXED_EVENT_ID, attrs); + } + + @Test + void regenerateOutputEventJavaSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + writeJavaSnapshot("output_event.json", buildOutputEvent()); + } + + @Test + void outputEventJavaSnapshotIsStable() throws Exception { + assertJavaSnapshotStable("output_event.json", buildOutputEvent()); + } + + @Test + void javaCanDeserializeOutputEventFromPythonSnapshot() throws Exception { + Event base = readPythonSnapshot("output_event.json"); + OutputEvent typed = OutputEvent.fromEvent(base); + + assertEquals( + FIXED_EVENT_ID, typed.getId(), "ID lost when deserializing Python OutputEvent."); + assertEquals(OutputEvent.EVENT_TYPE, typed.getType()); + assertEquals("world", typed.getOutput(), "OutputEvent.output mismatch."); + } + + // ── ChatRequestEvent ─────────────────────────────────────────────────── + + private static ChatRequestEvent buildChatRequestEvent() { + Map attrs = new LinkedHashMap<>(); + attrs.put("model", "test-model"); + attrs.put("messages", List.of(new ChatMessage(MessageRole.USER, "hello world"))); + return new ChatRequestEvent(FIXED_EVENT_ID, attrs); + } + + @Test + void regenerateChatRequestEventJavaSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + writeJavaSnapshot("chat_request_event.json", buildChatRequestEvent()); + } + + @Test + void chatRequestEventJavaSnapshotIsStable() throws Exception { + assertJavaSnapshotStable("chat_request_event.json", buildChatRequestEvent()); + } + + @Test + void javaCanDeserializeChatRequestEventFromPythonSnapshot() throws Exception { + Event base = readPythonSnapshot("chat_request_event.json"); + ChatRequestEvent typed = ChatRequestEvent.fromEvent(base); + + assertEquals(FIXED_EVENT_ID, typed.getId()); + assertEquals(ChatRequestEvent.EVENT_TYPE, typed.getType()); + assertEquals("test-model", typed.getModel()); + assertNotNull(typed.getMessages()); + assertEquals(1, typed.getMessages().size(), "Expected one message."); + ChatMessage msg = typed.getMessages().get(0); + assertEquals(MessageRole.USER, msg.getRole(), "Role mismatch on Python-produced message."); + assertEquals("hello world", msg.getContent()); + } + + /** + * Known 0.3 gap — {@link RowTypeInfo}-typed {@code output_schema} does not round-trip across + * the language boundary. Java emits {@code {"fieldNames": [...], "types": []}} while + * Python emits {@code {"names": [...], "types": []}}, so a {@link + * ChatRequestEvent} carrying a {@code RowTypeInfo} schema cannot be deserialized on the other + * side. The {@code BaseModel} (Pydantic class) branch is symmetric and works. Reconciling the + * {@code RowTypeInfo} wire format requires a canonical shape + bilateral {@code OutputSchema} + * serdes shims; tracked as a follow-up. + */ + @Test + void chatRequestRowTypeInfoOutputSchemaIsNotPortableAcrossLanguages_knownGap() + throws Exception { + OutputSchema schema = + new OutputSchema( + new RowTypeInfo( + new TypeInformation[] {BasicTypeInfo.STRING_TYPE_INFO}, + new String[] {"name"})); + ChatRequestEvent event = + new ChatRequestEvent( + "test-model", List.of(new ChatMessage(MessageRole.USER, "hi")), schema); + String json = MAPPER.writeValueAsString(event); + + // Pin Java's local shape so a future regression can't silently change it. The gap with + // Python's `{"names": ...}` shape is the documented limitation, not the assertion. + assertTrue(json.contains("\"fieldNames\""), "Java wire format uses `fieldNames`."); + assertFalse(json.contains("\"names\""), "Java wire format does not use Python's `names`."); + } + + // ── ChatResponseEvent ────────────────────────────────────────────────── + + private static ChatResponseEvent buildChatResponseEvent() { + Map attrs = new LinkedHashMap<>(); + attrs.put("request_id", FIXED_REQUEST_ID); + attrs.put("response", new ChatMessage(MessageRole.ASSISTANT, "hi there")); + attrs.put("retry_count", 0); + attrs.put("total_retry_wait_sec", 0); + return new ChatResponseEvent(FIXED_EVENT_ID, attrs); + } + + @Test + void regenerateChatResponseEventJavaSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + writeJavaSnapshot("chat_response_event.json", buildChatResponseEvent()); + } + + @Test + void chatResponseEventJavaSnapshotIsStable() throws Exception { + assertJavaSnapshotStable("chat_response_event.json", buildChatResponseEvent()); + } + + @Test + void javaCanDeserializeChatResponseEventFromPythonSnapshot() throws Exception { + Event base = readPythonSnapshot("chat_response_event.json"); + ChatResponseEvent typed = ChatResponseEvent.fromEvent(base); + + assertEquals(FIXED_EVENT_ID, typed.getId()); + assertEquals(ChatResponseEvent.EVENT_TYPE, typed.getType()); + assertEquals(FIXED_REQUEST_ID, typed.getRequestId(), "request_id mismatch."); + ChatMessage response = typed.getResponse(); + assertNotNull(response, "response field is null."); + assertEquals(MessageRole.ASSISTANT, response.getRole(), "Role mismatch on response."); + assertEquals("hi there", response.getContent()); + } + + // ── ToolRequestEvent ─────────────────────────────────────────────────── + + private static ToolRequestEvent buildToolRequestEvent() { + Map toolCall = new LinkedHashMap<>(); + toolCall.put("id", FIXED_TOOL_CALL_ID); + toolCall.put("name", "echo"); + toolCall.put("arguments", Map.of("value", "ping")); + + Map attrs = new LinkedHashMap<>(); + attrs.put("model", "test-model"); + attrs.put("tool_calls", List.of(toolCall)); + return new ToolRequestEvent(FIXED_EVENT_ID, attrs); + } + + @Test + void regenerateToolRequestEventJavaSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + writeJavaSnapshot("tool_request_event.json", buildToolRequestEvent()); + } + + @Test + void toolRequestEventJavaSnapshotIsStable() throws Exception { + assertJavaSnapshotStable("tool_request_event.json", buildToolRequestEvent()); + } + + @Test + void javaCanDeserializeToolRequestEventFromPythonSnapshot() throws Exception { + Event base = readPythonSnapshot("tool_request_event.json"); + ToolRequestEvent typed = ToolRequestEvent.fromEvent(base); + + assertEquals(FIXED_EVENT_ID, typed.getId()); + assertEquals(ToolRequestEvent.EVENT_TYPE, typed.getType()); + assertEquals("test-model", typed.getModel()); + List> toolCalls = typed.getToolCalls(); + assertNotNull(toolCalls); + assertEquals(1, toolCalls.size()); + assertEquals(FIXED_TOOL_CALL_ID, toolCalls.get(0).get("id")); + } + + // ── ToolResponseEvent ────────────────────────────────────────────────── + + private static ToolResponseEvent buildToolResponseEvent() { + Map attrs = new LinkedHashMap<>(); + attrs.put("request_id", FIXED_REQUEST_ID); + attrs.put("responses", Map.of(FIXED_TOOL_CALL_ID, ToolResponse.success("pong"))); + attrs.put("success", Map.of(FIXED_TOOL_CALL_ID, true)); + attrs.put("error", new HashMap()); + attrs.put("external_ids", new HashMap()); + attrs.put("timestamp", FIXED_TIMESTAMP); + return new ToolResponseEvent(FIXED_EVENT_ID, attrs); + } + + @Test + void regenerateToolResponseEventJavaSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + writeJavaSnapshot("tool_response_event.json", buildToolResponseEvent()); + } + + @Test + void toolResponseEventJavaSnapshotIsStable() throws Exception { + assertJavaSnapshotStable("tool_response_event.json", buildToolResponseEvent()); + } + + @Test + void pythonToolResponseEventRoundTripsScalarResponses() throws Exception { + Event base = readPythonSnapshot("tool_response_event.json"); + ToolResponseEvent typed = ToolResponseEvent.fromEvent(base); + + assertEquals(FIXED_REQUEST_ID, typed.getRequestId()); + + ToolResponse stringResp = typed.getResponses().get(FIXED_TOOL_CALL_ID); + assertNotNull(stringResp, "String response should be wrapped as ToolResponse.success."); + assertEquals("pong", stringResp.getResult()); + assertTrue(stringResp.isSuccess()); + + ToolResponse numericResp = typed.getResponses().get(FIXED_TOOL_CALL_ID_NUMERIC); + assertNotNull(numericResp, "Number response should be wrapped as ToolResponse.success."); + assertEquals(42, ((Number) numericResp.getResult()).intValue()); + assertTrue(numericResp.isSuccess()); + + ToolResponse boolResp = typed.getResponses().get(FIXED_TOOL_CALL_ID_BOOL); + assertNotNull(boolResp, "Boolean response should be wrapped as ToolResponse.success."); + assertEquals(Boolean.TRUE, boolResp.getResult()); + assertTrue(boolResp.isSuccess()); + + // Remaining shape gap: Python's ToolResponseEvent model has no top-level success/error/ + // timestamp maps (those live inside each ToolResponse on the Java side). Pin it so the + // divergence stays visible. + Map attrs = typed.getAttributes(); + assertFalse(attrs.containsKey("success")); + assertFalse(attrs.containsKey("error")); + assertFalse(attrs.containsKey("timestamp")); + } + + // ── ContextRetrievalRequestEvent ─────────────────────────────────────── + + private static ContextRetrievalRequestEvent buildContextRetrievalRequestEvent() { + Map attrs = new LinkedHashMap<>(); + attrs.put("query", "what is flink"); + attrs.put("vector_store", "test-store"); + attrs.put("max_results", 5); + return new ContextRetrievalRequestEvent(FIXED_EVENT_ID, attrs); + } + + @Test + void regenerateContextRetrievalRequestEventJavaSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + writeJavaSnapshot( + "context_retrieval_request_event.json", buildContextRetrievalRequestEvent()); + } + + @Test + void contextRetrievalRequestEventJavaSnapshotIsStable() throws Exception { + assertJavaSnapshotStable( + "context_retrieval_request_event.json", buildContextRetrievalRequestEvent()); + } + + @Test + void javaCanDeserializeContextRetrievalRequestEventFromPythonSnapshot() throws Exception { + Event base = readPythonSnapshot("context_retrieval_request_event.json"); + ContextRetrievalRequestEvent typed = ContextRetrievalRequestEvent.fromEvent(base); + + assertEquals(FIXED_EVENT_ID, typed.getId()); + assertEquals(ContextRetrievalRequestEvent.EVENT_TYPE, typed.getType()); + assertEquals("what is flink", typed.getQuery()); + assertEquals("test-store", typed.getVectorStore()); + assertEquals(5, typed.getMaxResults()); + } + + // ── ContextRetrievalResponseEvent ────────────────────────────────────── + + private static ContextRetrievalResponseEvent buildContextRetrievalResponseEvent() { + Document doc = new Document("doc content", new LinkedHashMap<>(Map.of("k", "v")), "doc-1"); + Map attrs = new LinkedHashMap<>(); + attrs.put("request_id", FIXED_REQUEST_ID); + attrs.put("query", "what is flink"); + attrs.put("documents", new ArrayList<>(List.of(doc))); + return new ContextRetrievalResponseEvent(FIXED_EVENT_ID, attrs); + } + + @Test + void regenerateContextRetrievalResponseEventJavaSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + writeJavaSnapshot( + "context_retrieval_response_event.json", buildContextRetrievalResponseEvent()); + } + + @Test + void contextRetrievalResponseEventJavaSnapshotIsStable() throws Exception { + assertJavaSnapshotStable( + "context_retrieval_response_event.json", buildContextRetrievalResponseEvent()); + } + + @Test + void javaCanDeserializeContextRetrievalResponseEventFromPythonSnapshot() throws Exception { + Event base = readPythonSnapshot("context_retrieval_response_event.json"); + ContextRetrievalResponseEvent typed = ContextRetrievalResponseEvent.fromEvent(base); + + assertEquals(FIXED_EVENT_ID, typed.getId()); + assertEquals(ContextRetrievalResponseEvent.EVENT_TYPE, typed.getType()); + assertEquals(FIXED_REQUEST_ID, typed.getRequestId()); + assertEquals("what is flink", typed.getQuery()); + List docs = typed.getDocuments(); + assertNotNull(docs); + assertEquals(1, docs.size()); + assertEquals("doc content", docs.get(0).getContent()); + assertEquals("doc-1", docs.get(0).getId()); + } + + // ── Generic Event with primitive attributes (user-authored axis) ─────── + + private static final String GENERIC_EVENT_TYPE = "_my_custom_event"; + + private static Event buildGenericEvent() { + Map attrs = new LinkedHashMap<>(); + attrs.put("k_int", 42); + attrs.put("k_float", 1.5); + attrs.put("k_bool", true); + attrs.put("k_str", "hello"); + attrs.put("k_null", null); + attrs.put("k_list", List.of(1, 2, 3)); + attrs.put("k_dict", Map.of("nested", "value")); + return new Event(FIXED_EVENT_ID, GENERIC_EVENT_TYPE, attrs); + } + + @Test + void regenerateGenericEventJavaSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + writeJavaSnapshot("generic_event_with_attrs.json", buildGenericEvent()); + } + + @Test + void genericEventJavaSnapshotIsStable() throws Exception { + assertJavaSnapshotStable("generic_event_with_attrs.json", buildGenericEvent()); + } + + @Test + void javaCanDeserializeGenericEventFromPythonSnapshot() throws Exception { + Event base = readPythonSnapshot("generic_event_with_attrs.json"); + + assertEquals(GENERIC_EVENT_TYPE, base.getType()); + Map attrs = base.getAttributes(); + assertEquals(42, attrs.get("k_int")); + assertTrue(attrs.get("k_int") instanceof Integer); + assertEquals(1.5, attrs.get("k_float")); + assertTrue(attrs.get("k_float") instanceof Double); + assertEquals(true, attrs.get("k_bool")); + assertEquals("hello", attrs.get("k_str")); + assertTrue(attrs.containsKey("k_null")); + assertEquals(null, attrs.get("k_null")); + assertEquals(List.of(1, 2, 3), attrs.get("k_list")); + assertEquals(Map.of("nested", "value"), attrs.get("k_dict")); + } + + // ── Python-only subclass with no Java counterpart (graceful fallback) ── + + @Test + void javaCanDeserializePythonOnlySubclassEventAsBaseEvent() throws Exception { + Event base = readPythonSnapshot("python_only_subclass_event.json"); + + assertEquals(Event.class, base.getClass()); + assertEquals("_my_python_only_event", base.getType()); + assertEquals(FIXED_EVENT_ID, base.getId()); + + Map attrs = base.getAttributes(); + assertEquals("ping", attrs.get("value")); + assertEquals(7, attrs.get("count")); + } + + // ── Smoke ────────────────────────────────────────────────────────────── + + @Test + void snapshotDirectoryExists() { + assertNotNull(snapshotDir); + assertTrue(Files.isDirectory(snapshotDir), "Expected snapshot directory at " + snapshotDir); + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/agents/AgentAddActionTest.java b/api/src/test/java/org/apache/flink/agents/api/agents/AgentAddActionTest.java index fd906bd51..46d7ea0f9 100644 --- a/api/src/test/java/org/apache/flink/agents/api/agents/AgentAddActionTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/agents/AgentAddActionTest.java @@ -76,4 +76,43 @@ void duplicateNameRejected() { .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("act"); } + + @Test + void javaFunctionDescriptorStoredAsIs() { + Agent agent = new Agent(); + JavaFunction jf = + new JavaFunction( + "com.example.Handlers", + "handle", + java.util.List.of( + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext")); + + agent.addAction("act", new String[] {"_input_event"}, jf, null); + + Tuple3> entry = agent.getActions().get("act"); + assertThat(entry).isNotNull(); + assertThat(entry.f1).isSameAs(jf); + } + + @Test + void duplicateNameRejectedForJavaFunctionDescriptor() { + Agent agent = new Agent(); + JavaFunction jf = + new JavaFunction("com.example.X", "m", java.util.List.of("java.lang.String")); + agent.addAction("act", new String[] {"_input_event"}, jf, null); + + assertThatThrownBy(() -> agent.addAction("act", new String[] {"_input_event"}, jf, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("act"); + } + + @Test + void addActionReturnsSelfForChaining() { + Agent agent = new Agent(); + Agent returned = + agent.addAction( + "act", new String[] {"_input_event"}, new PythonFunction("p", "q"), null); + assertThat(returned).isSameAs(agent); + } } diff --git a/e2e-test/cross-language-agent-plan-snapshots/java/agent_plan_with_python_action.json b/e2e-test/cross-language-agent-plan-snapshots/java/agent_plan_with_python_action.json new file mode 100644 index 000000000..fd662469f --- /dev/null +++ b/e2e-test/cross-language-agent-plan-snapshots/java/agent_plan_with_python_action.json @@ -0,0 +1,58 @@ +{ + "actions" : { + "chat_model_action" : { + "name" : "chat_model_action", + "exec" : { + "func_type" : "JavaFunction", + "qualname" : "org.apache.flink.agents.plan.actions.ChatModelAction", + "method_name" : "processChatRequestOrToolResponse", + "parameter_types" : [ "org.apache.flink.agents.api.Event", "org.apache.flink.agents.api.context.RunnerContext" ] + }, + "listen_event_types" : [ "_chat_request_event", "_tool_response_event" ], + "config" : null + }, + "context_retrieval_action" : { + "name" : "context_retrieval_action", + "exec" : { + "func_type" : "JavaFunction", + "qualname" : "org.apache.flink.agents.plan.actions.ContextRetrievalAction", + "method_name" : "processContextRetrievalRequest", + "parameter_types" : [ "org.apache.flink.agents.api.Event", "org.apache.flink.agents.api.context.RunnerContext" ] + }, + "listen_event_types" : [ "_context_retrieval_request_event" ], + "config" : null + }, + "handle" : { + "name" : "handle", + "exec" : { + "func_type" : "PythonFunction", + "module" : "flink_agents.plan.tests.test_agent_plan_cross_language", + "qualname" : "_dummy_action" + }, + "listen_event_types" : [ "_input_event" ], + "config" : null + }, + "tool_call_action" : { + "name" : "tool_call_action", + "exec" : { + "func_type" : "JavaFunction", + "qualname" : "org.apache.flink.agents.plan.actions.ToolCallAction", + "method_name" : "processToolRequest", + "parameter_types" : [ "org.apache.flink.agents.api.Event", "org.apache.flink.agents.api.context.RunnerContext" ] + }, + "listen_event_types" : [ "_tool_request_event" ], + "config" : null + } + }, + "actions_by_event" : { + "_context_retrieval_request_event" : [ "context_retrieval_action" ], + "_tool_response_event" : [ "chat_model_action" ], + "_chat_request_event" : [ "chat_model_action" ], + "_tool_request_event" : [ "tool_call_action" ], + "_input_event" : [ "handle" ] + }, + "resource_providers" : { }, + "config" : { + "conf_data" : { } + } +} diff --git a/e2e-test/cross-language-agent-plan-snapshots/python/agent_plan_with_java_action.json b/e2e-test/cross-language-agent-plan-snapshots/python/agent_plan_with_java_action.json new file mode 100644 index 000000000..dda6405d1 --- /dev/null +++ b/e2e-test/cross-language-agent-plan-snapshots/python/agent_plan_with_java_action.json @@ -0,0 +1,78 @@ +{ + "actions": { + "handle": { + "name": "handle", + "exec": { + "func_type": "JavaFunction", + "qualname": "com.example.Handlers", + "method_name": "handle", + "parameter_types": [ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext" + ] + }, + "listen_event_types": [ + "_input_event" + ], + "config": null + }, + "chat_model_action": { + "name": "chat_model_action", + "exec": { + "func_type": "PythonFunction", + "module": "flink_agents.plan.actions.chat_model_action", + "qualname": "process_chat_request_or_tool_response" + }, + "listen_event_types": [ + "_chat_request_event", + "_tool_response_event" + ], + "config": null + }, + "tool_call_action": { + "name": "tool_call_action", + "exec": { + "func_type": "PythonFunction", + "module": "flink_agents.plan.actions.tool_call_action", + "qualname": "process_tool_request" + }, + "listen_event_types": [ + "_tool_request_event" + ], + "config": null + }, + "context_retrieval_action": { + "name": "context_retrieval_action", + "exec": { + "func_type": "PythonFunction", + "module": "flink_agents.plan.actions.context_retrieval_action", + "qualname": "process_context_retrieval_request" + }, + "listen_event_types": [ + "_context_retrieval_request_event" + ], + "config": null + } + }, + "actions_by_event": { + "_input_event": [ + "handle" + ], + "_chat_request_event": [ + "chat_model_action" + ], + "_tool_response_event": [ + "chat_model_action" + ], + "_tool_request_event": [ + "tool_call_action" + ], + "_context_retrieval_request_event": [ + "context_retrieval_action" + ] + }, + "resource_providers": {}, + "config": { + "conf_data": {} + } +} diff --git a/e2e-test/cross-language-event-snapshots/java/chat_request_event.json b/e2e-test/cross-language-event-snapshots/java/chat_request_event.json new file mode 100644 index 000000000..347c47e71 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/java/chat_request_event.json @@ -0,0 +1,13 @@ +{ + "id" : "00000000-0000-0000-0000-000000000001", + "attributes" : { + "model" : "test-model", + "messages" : [ { + "role" : "user", + "content" : "hello world", + "tool_calls" : [ ], + "extra_args" : { } + } ] + }, + "type" : "_chat_request_event" +} diff --git a/e2e-test/cross-language-event-snapshots/java/chat_response_event.json b/e2e-test/cross-language-event-snapshots/java/chat_response_event.json new file mode 100644 index 000000000..3d5b4793c --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/java/chat_response_event.json @@ -0,0 +1,15 @@ +{ + "id" : "00000000-0000-0000-0000-000000000001", + "attributes" : { + "request_id" : "00000000-0000-0000-0000-000000000002", + "response" : { + "role" : "assistant", + "content" : "hi there", + "tool_calls" : [ ], + "extra_args" : { } + }, + "retry_count" : 0, + "total_retry_wait_sec" : 0 + }, + "type" : "_chat_response_event" +} diff --git a/e2e-test/cross-language-event-snapshots/java/context_retrieval_request_event.json b/e2e-test/cross-language-event-snapshots/java/context_retrieval_request_event.json new file mode 100644 index 000000000..ead03f8de --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/java/context_retrieval_request_event.json @@ -0,0 +1,9 @@ +{ + "id" : "00000000-0000-0000-0000-000000000001", + "attributes" : { + "query" : "what is flink", + "vector_store" : "test-store", + "max_results" : 5 + }, + "type" : "_context_retrieval_request_event" +} diff --git a/e2e-test/cross-language-event-snapshots/java/context_retrieval_response_event.json b/e2e-test/cross-language-event-snapshots/java/context_retrieval_response_event.json new file mode 100644 index 000000000..90592d565 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/java/context_retrieval_response_event.json @@ -0,0 +1,17 @@ +{ + "id" : "00000000-0000-0000-0000-000000000001", + "attributes" : { + "request_id" : "00000000-0000-0000-0000-000000000002", + "query" : "what is flink", + "documents" : [ { + "content" : "doc content", + "metadata" : { + "k" : "v" + }, + "id" : "doc-1", + "embedding" : null, + "score" : null + } ] + }, + "type" : "_context_retrieval_response_event" +} diff --git a/e2e-test/cross-language-event-snapshots/java/generic_event_with_attrs.json b/e2e-test/cross-language-event-snapshots/java/generic_event_with_attrs.json new file mode 100644 index 000000000..96b9fc0f5 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/java/generic_event_with_attrs.json @@ -0,0 +1,15 @@ +{ + "id" : "00000000-0000-0000-0000-000000000001", + "type" : "_my_custom_event", + "attributes" : { + "k_int" : 42, + "k_float" : 1.5, + "k_bool" : true, + "k_str" : "hello", + "k_null" : null, + "k_list" : [ 1, 2, 3 ], + "k_dict" : { + "nested" : "value" + } + } +} diff --git a/e2e-test/cross-language-event-snapshots/java/input_event.json b/e2e-test/cross-language-event-snapshots/java/input_event.json new file mode 100644 index 000000000..8150a1ce6 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/java/input_event.json @@ -0,0 +1,7 @@ +{ + "id" : "00000000-0000-0000-0000-000000000001", + "attributes" : { + "input" : "hello" + }, + "type" : "_input_event" +} diff --git a/e2e-test/cross-language-event-snapshots/java/output_event.json b/e2e-test/cross-language-event-snapshots/java/output_event.json new file mode 100644 index 000000000..3fb4269e8 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/java/output_event.json @@ -0,0 +1,7 @@ +{ + "id" : "00000000-0000-0000-0000-000000000001", + "attributes" : { + "output" : "world" + }, + "type" : "_output_event" +} diff --git a/e2e-test/cross-language-event-snapshots/java/tool_request_event.json b/e2e-test/cross-language-event-snapshots/java/tool_request_event.json new file mode 100644 index 000000000..0f8ab2f3a --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/java/tool_request_event.json @@ -0,0 +1,14 @@ +{ + "id" : "00000000-0000-0000-0000-000000000001", + "attributes" : { + "model" : "test-model", + "tool_calls" : [ { + "id" : "call_aaaa", + "name" : "echo", + "arguments" : { + "value" : "ping" + } + } ] + }, + "type" : "_tool_request_event" +} diff --git a/e2e-test/cross-language-event-snapshots/java/tool_response_event.json b/e2e-test/cross-language-event-snapshots/java/tool_response_event.json new file mode 100644 index 000000000..04698abe2 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/java/tool_response_event.json @@ -0,0 +1,22 @@ +{ + "id" : "00000000-0000-0000-0000-000000000001", + "attributes" : { + "request_id" : "00000000-0000-0000-0000-000000000002", + "responses" : { + "call_aaaa" : { + "result" : "pong", + "success" : true, + "error" : null, + "execution_time_ms" : 0, + "tool_name" : null + } + }, + "success" : { + "call_aaaa" : true + }, + "error" : { }, + "external_ids" : { }, + "timestamp" : 1700000000000 + }, + "type" : "_tool_response_event" +} diff --git a/e2e-test/cross-language-event-snapshots/python/chat_request_event.json b/e2e-test/cross-language-event-snapshots/python/chat_request_event.json new file mode 100644 index 000000000..12900389c --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/chat_request_event.json @@ -0,0 +1,16 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_chat_request_event", + "attributes": { + "model": "test-model", + "messages": [ + { + "role": "user", + "content": "hello world", + "tool_calls": [], + "extra_args": {} + } + ], + "output_schema": null + } +} diff --git a/e2e-test/cross-language-event-snapshots/python/chat_response_event.json b/e2e-test/cross-language-event-snapshots/python/chat_response_event.json new file mode 100644 index 000000000..bafb28116 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/chat_response_event.json @@ -0,0 +1,15 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_chat_response_event", + "attributes": { + "request_id": "00000000-0000-0000-0000-000000000002", + "response": { + "role": "assistant", + "content": "hi there", + "tool_calls": [], + "extra_args": {} + }, + "retry_count": 0, + "total_retry_wait_sec": 0 + } +} diff --git a/e2e-test/cross-language-event-snapshots/python/context_retrieval_request_event.json b/e2e-test/cross-language-event-snapshots/python/context_retrieval_request_event.json new file mode 100644 index 000000000..357ce8bc9 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/context_retrieval_request_event.json @@ -0,0 +1,9 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_context_retrieval_request_event", + "attributes": { + "query": "what is flink", + "vector_store": "test-store", + "max_results": 5 + } +} diff --git a/e2e-test/cross-language-event-snapshots/python/context_retrieval_response_event.json b/e2e-test/cross-language-event-snapshots/python/context_retrieval_response_event.json new file mode 100644 index 000000000..95e14f0a6 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/context_retrieval_response_event.json @@ -0,0 +1,19 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_context_retrieval_response_event", + "attributes": { + "request_id": "00000000-0000-0000-0000-000000000002", + "query": "what is flink", + "documents": [ + { + "content": "doc content", + "metadata": { + "k": "v" + }, + "id": "doc-1", + "embedding": null, + "score": null + } + ] + } +} diff --git a/e2e-test/cross-language-event-snapshots/python/generic_event_with_attrs.json b/e2e-test/cross-language-event-snapshots/python/generic_event_with_attrs.json new file mode 100644 index 000000000..cfd461b36 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/generic_event_with_attrs.json @@ -0,0 +1,19 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_my_custom_event", + "attributes": { + "k_int": 42, + "k_float": 1.5, + "k_bool": true, + "k_str": "hello", + "k_null": null, + "k_list": [ + 1, + 2, + 3 + ], + "k_dict": { + "nested": "value" + } + } +} diff --git a/e2e-test/cross-language-event-snapshots/python/input_event.json b/e2e-test/cross-language-event-snapshots/python/input_event.json new file mode 100644 index 000000000..db24e4c56 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/input_event.json @@ -0,0 +1,7 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_input_event", + "attributes": { + "input": "hello" + } +} diff --git a/e2e-test/cross-language-event-snapshots/python/output_event.json b/e2e-test/cross-language-event-snapshots/python/output_event.json new file mode 100644 index 000000000..f4b48a746 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/output_event.json @@ -0,0 +1,7 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_output_event", + "attributes": { + "output": "world" + } +} diff --git a/e2e-test/cross-language-event-snapshots/python/python_only_subclass_event.json b/e2e-test/cross-language-event-snapshots/python/python_only_subclass_event.json new file mode 100644 index 000000000..a48448c12 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/python_only_subclass_event.json @@ -0,0 +1,8 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_my_python_only_event", + "attributes": { + "value": "ping", + "count": 7 + } +} diff --git a/e2e-test/cross-language-event-snapshots/python/tool_request_event.json b/e2e-test/cross-language-event-snapshots/python/tool_request_event.json new file mode 100644 index 000000000..2ac1fc511 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/tool_request_event.json @@ -0,0 +1,16 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_tool_request_event", + "attributes": { + "model": "test-model", + "tool_calls": [ + { + "id": "call_aaaa", + "name": "echo", + "arguments": { + "value": "ping" + } + } + ] + } +} diff --git a/e2e-test/cross-language-event-snapshots/python/tool_response_event.json b/e2e-test/cross-language-event-snapshots/python/tool_response_event.json new file mode 100644 index 000000000..6db77db18 --- /dev/null +++ b/e2e-test/cross-language-event-snapshots/python/tool_response_event.json @@ -0,0 +1,17 @@ +{ + "id": "00000000-0000-0000-0000-000000000001", + "type": "_tool_response_event", + "attributes": { + "request_id": "00000000-0000-0000-0000-000000000002", + "responses": { + "call_aaaa": "pong", + "call_bbbb": 42, + "call_cccc": true + }, + "external_ids": { + "call_aaaa": null, + "call_bbbb": null, + "call_cccc": null + } + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/JavaActionHandler.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/JavaActionHandler.java new file mode 100644 index 000000000..ac844d6a5 --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/JavaActionHandler.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.resource.test; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.context.RunnerContext; + +/** + * Java action referenced by the Python {@code PythonAgentWithJavaActionAgent}. Mirror of {@code + * python_action_handler.multiply_by_two} in the Java→Python direction. + */ +public final class JavaActionHandler { + + private JavaActionHandler() {} + + public static void multiplyByTwo(Event event, RunnerContext ctx) { + long value = ((Number) InputEvent.fromEvent(event).getInput()).longValue(); + ctx.sendEvent(new OutputEvent(value * 2)); + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/JavaAgentWithPythonActionAgent.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/JavaAgentWithPythonActionAgent.java new file mode 100644 index 000000000..8be703347 --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/JavaAgentWithPythonActionAgent.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.resource.test; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.function.PythonFunction; +import org.apache.flink.api.java.functions.KeySelector; + +public class JavaAgentWithPythonActionAgent extends Agent { + + public static final String PYTHON_MODULE = + "flink_agents.e2e_tests.e2e_tests_resource_cross_language.python_action_handler"; + public static final String PYTHON_QUALNAME = "multiply_by_two"; + + public JavaAgentWithPythonActionAgent() { + addAction( + "multiply_by_two", + new String[] {InputEvent.EVENT_TYPE}, + new PythonFunction(PYTHON_MODULE, PYTHON_QUALNAME), + null); + } + + public static class SingleKeySelector implements KeySelector { + @Override + public Long getKey(Long value) { + return 0L; + } + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/JavaAgentWithPythonActionTest.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/JavaAgentWithPythonActionTest.java new file mode 100644 index 000000000..f505deffe --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/JavaAgentWithPythonActionTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.resource.test; + +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.CloseableIterator; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +public class JavaAgentWithPythonActionTest { + + @Test + public void javaAgentDispatchesPythonActionBody() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + DataStream inputStream = env.fromData(1L, 2L, 3L, 4L, 5L); + + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + + DataStream outputStream = + agentsEnv + .fromDataStream( + inputStream, new JavaAgentWithPythonActionAgent.SingleKeySelector()) + .apply(new JavaAgentWithPythonActionAgent()) + .toDataStream(); + + CloseableIterator results = outputStream.collectAsync(); + agentsEnv.execute(); + + List actual = new ArrayList<>(); + while (results.hasNext()) { + actual.add(((Number) results.next()).longValue()); + } + Collections.sort(actual); + + assertThat(actual).containsExactly(2L, 4L, 6L, 8L, 10L); + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java index 605500ca1..1bd647764 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java @@ -227,21 +227,56 @@ private void extractActionsFromAgent(Agent agent) throws Exception { // Scan the agent class for methods annotated with @Action Class agentClass = agent.getClass(); + // getDeclaredMethods() skips inherited @Action methods; reject loudly. + for (Class parent = agentClass.getSuperclass(); + parent != null && parent != Agent.class; + parent = parent.getSuperclass()) { + for (Method inherited : parent.getDeclaredMethods()) { + if (inherited.isAnnotationPresent( + org.apache.flink.agents.api.annotation.Action.class)) { + throw new IllegalStateException( + "Inherited @Action '" + + parent.getName() + + "#" + + inherited.getName() + + "' is not supported; declare on the concrete agent."); + } + } + } for (Method method : agentClass.getDeclaredMethods()) { - if (method.isAnnotationPresent(org.apache.flink.agents.api.annotation.Action.class)) { - org.apache.flink.agents.api.annotation.Action actionAnnotation = - method.getAnnotation(org.apache.flink.agents.api.annotation.Action.class); - - String[] listenEventTypeStrings = - Objects.requireNonNull(actionAnnotation).listenEventTypes(); - - org.apache.flink.agents.plan.JavaFunction javaFunction = + if (!method.isAnnotationPresent(org.apache.flink.agents.api.annotation.Action.class)) { + continue; + } + org.apache.flink.agents.api.annotation.Action actionAnnotation = + Objects.requireNonNull( + method.getAnnotation( + org.apache.flink.agents.api.annotation.Action.class)); + String[] listenEventTypeStrings = actionAnnotation.listenEventTypes(); + org.apache.flink.agents.api.annotation.PythonFunction target = + actionAnnotation.target(); + String targetModule = target.module(); + String targetQualname = target.qualname(); + boolean moduleSet = !targetModule.isEmpty(); + boolean qualnameSet = !targetQualname.isEmpty(); + + org.apache.flink.agents.plan.Function execFunction; + if (!moduleSet && !qualnameSet) { + execFunction = new org.apache.flink.agents.plan.JavaFunction( method.getDeclaringClass(), method.getName(), method.getParameterTypes()); - extractActions(method.getName(), listenEventTypeStrings, javaFunction, null); + } else if (moduleSet && qualnameSet) { + execFunction = + new org.apache.flink.agents.plan.PythonFunction( + targetModule, targetQualname); + } else { + throw new IllegalStateException( + "PythonFunction target on '" + + method.getName() + + "' must set both module and qualname"); } + extractActions(method.getName(), listenEventTypeStrings, execFunction, null); } for (Map.Entry< diff --git a/plan/src/main/java/org/apache/flink/agents/plan/PythonFunction.java b/plan/src/main/java/org/apache/flink/agents/plan/PythonFunction.java index b5beabdcc..f141d1416 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/PythonFunction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/PythonFunction.java @@ -42,15 +42,21 @@ public void setInterpreter(PythonInterpreter interpreter) { @Override public Object call(Object... args) throws Exception { if (interpreter == null) { - throw new IllegalStateException("Python interpreter is not set."); + throw new IllegalStateException( + "PythonFunction requires the Python interpreter; not set on this " + + "descriptor. The runtime injects it via setInterpreter before " + + "invocation."); } return interpreter.invoke(CALL_PYTHON_FUNCTION, module, qualName, args); } - // TODO: check Python function signature compatibility with given parameter types @Override - public void checkSignature(Class[] parameterTypes) throws Exception {} + public void checkSignature(Class[] parameterTypes) throws Exception { + // No-op: descriptor carries no parameter types, and the Python module + // cannot be inspected without an interpreter. Mismatches surface at + // dispatch time. + } public String getModule() { return module; diff --git a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializer.java b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializer.java index 2618721ff..fec2f126f 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializer.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializer.java @@ -158,7 +158,7 @@ private Map deserializeJavaConfig(JsonNode node) throws Exceptio return config; } - private Object deserializePythonConfig(JsonNode node) { + static Object deserializePythonConfig(JsonNode node) { if (node.isObject()) { Map map = new HashMap<>(); node.fields() @@ -172,8 +172,16 @@ private Object deserializePythonConfig(JsonNode node) { List list = new ArrayList<>(); node.forEach(element -> list.add(deserializePythonConfig(element))); return list; - } else if (node.isValueNode()) { - return node.asText(); + } else if (node.isNull()) { + return null; + } else if (node.isBoolean()) { + return node.booleanValue(); + } else if (node.isIntegralNumber()) { + return node.canConvertToInt() ? (Object) node.intValue() : (Object) node.longValue(); + } else if (node.isFloatingPointNumber()) { + return node.doubleValue(); + } else if (node.isTextual()) { + return node.textValue(); } else { throw new UnsupportedOperationException("Unsupported node type: " + node.getNodeType()); } diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanCrossLanguageTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanCrossLanguageTest.java new file mode 100644 index 000000000..3f55b0915 --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanCrossLanguageTest.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.plan.actions.Action; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** + * Layer B plan-compile tests for cross-language {@code Function} descriptors on the Java side. + * + *

Mirrors {@code test_agent_plan_cross_language.py}. Confirms: + * + *

    + *
  • api → plan promotion in {@link + * AgentPlan#toPlanFunction(org.apache.flink.agents.api.function.Function)} handles both + * {@code PythonFunction} (cross-language) and {@code JavaFunction} (same-language) + * descriptors. + *
  • The compiled plan serializes to the expected wire JSON (snake_case action {@code exec} + * block). + *
  • JSON round-trips back into a structurally equivalent plan. + *
  • Java can deserialize Python's plan snapshot referencing a cross-language Java action body. + *
+ * + *

Snapshots live under {@code e2e-test/cross-language-agent-plan-snapshots/}. + */ +class AgentPlanCrossLanguageTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + /** Static handler used as a JavaFunction target — must exist on the classpath for compile. */ + public static void handle(Event event, RunnerContext ctx) { + // No-op: plan-compile only needs the method to resolve via Class.forName / getMethod. + } + + private static Path snapshotDir; + + @BeforeAll + static void resolveSnapshotDir() { + // Maven sets user.dir to the module directory; repo root is its parent. + Path repoRoot = Paths.get(System.getProperty("user.dir")).getParent(); + snapshotDir = repoRoot.resolve("e2e-test/cross-language-agent-plan-snapshots"); + } + + // ── Helpers ──────────────────────────────────────────────────────────── + + private static boolean regenerateRequested() { + return Boolean.parseBoolean(System.getProperty("regenerate.snapshots", "false")); + } + + private static org.apache.flink.agents.api.function.JavaFunction javaFunctionDescriptor() { + return new org.apache.flink.agents.api.function.JavaFunction( + AgentPlanCrossLanguageTest.class.getName(), + "handle", + List.of(Event.class.getName(), RunnerContext.class.getName())); + } + + private static org.apache.flink.agents.api.function.PythonFunction pythonFunctionDescriptor() { + // Use a target that exists on the Python side too — Python's `Action.__init__` + // eagerly imports the module via `check_signature` during JSON deserialize, + // so the cross-language snapshot must point at a real importable callable. + // Mirror of `_dummy_action` in `test_agent_plan_cross_language.py`. + return new org.apache.flink.agents.api.function.PythonFunction( + "flink_agents.plan.tests.test_agent_plan_cross_language", "_dummy_action"); + } + + private static AgentPlan compileWithPythonAction() throws Exception { + Agent agent = new Agent(); + agent.addAction( + "handle", new String[] {InputEvent.EVENT_TYPE}, pythonFunctionDescriptor(), null); + return new AgentPlan(agent); + } + + private static AgentPlan compileWithJavaAction() throws Exception { + Agent agent = new Agent(); + agent.addAction( + "handle", new String[] {InputEvent.EVENT_TYPE}, javaFunctionDescriptor(), null); + return new AgentPlan(agent); + } + + // ── api → plan promotion (Java side) ─────────────────────────────────── + + @Test + void compileAgentWithJavaFunctionDescriptor() throws Exception { + AgentPlan plan = compileWithJavaAction(); + + Action action = plan.getActions().get("handle"); + assertThat(action).isNotNull(); + assertThat(action.getExec()).isInstanceOf(JavaFunction.class); + + JavaFunction planFn = (JavaFunction) action.getExec(); + assertThat(planFn.getQualName()).isEqualTo(AgentPlanCrossLanguageTest.class.getName()); + assertThat(planFn.getMethodName()).isEqualTo("handle"); + assertThat(planFn.getParameterTypes()).containsExactly(Event.class, RunnerContext.class); + } + + @Test + void compileAgentWithPythonFunctionDescriptor() throws Exception { + AgentPlan plan = compileWithPythonAction(); + + Action action = plan.getActions().get("handle"); + assertThat(action).isNotNull(); + assertThat(action.getExec()).isInstanceOf(PythonFunction.class); + + PythonFunction planFn = (PythonFunction) action.getExec(); + assertThat(planFn.getModule()) + .isEqualTo("flink_agents.plan.tests.test_agent_plan_cross_language"); + assertThat(planFn.getQualName()).isEqualTo("_dummy_action"); + } + + @Test + void compileWithJavaFunctionRequiresClassOnClasspath() { + // Java plan-compile resolves FQNs eagerly (Class.forName), so an unknown class must fail + // here, not later at dispatch. + Agent agent = new Agent(); + org.apache.flink.agents.api.function.JavaFunction fake = + new org.apache.flink.agents.api.function.JavaFunction( + "com.does.not.Exist", "ghost", List.of("java.lang.String")); + agent.addAction("act", new String[] {InputEvent.EVENT_TYPE}, fake, null); + + Throwable thrown = null; + try { + new AgentPlan(agent); + } catch (Throwable t) { + thrown = t; + } + assertThat(thrown) + .as("Java plan-compile should reject unresolvable JavaFunction class.") + .isNotNull(); + } + + // ── Plan JSON shape (Java side) ──────────────────────────────────────── + + @Test + void javaPlanWithJavaActionHasExpectedExecShape() throws Exception { + AgentPlan plan = compileWithJavaAction(); + JsonNode parsed = MAPPER.readTree(MAPPER.writeValueAsString(plan)); + JsonNode execBlock = parsed.get("actions").get("handle").get("exec"); + + assertThat(execBlock.get("func_type").asText()).isEqualTo("JavaFunction"); + assertThat(execBlock.get("qualname").asText()) + .isEqualTo(AgentPlanCrossLanguageTest.class.getName()); + assertThat(execBlock.get("method_name").asText()).isEqualTo("handle"); + JsonNode params = execBlock.get("parameter_types"); + assertThat(params.isArray()).isTrue(); + assertThat(params.get(0).asText()).isEqualTo(Event.class.getName()); + assertThat(params.get(1).asText()).isEqualTo(RunnerContext.class.getName()); + } + + @Test + void javaPlanWithPythonActionHasExpectedExecShape() throws Exception { + AgentPlan plan = compileWithPythonAction(); + JsonNode parsed = MAPPER.readTree(MAPPER.writeValueAsString(plan)); + JsonNode execBlock = parsed.get("actions").get("handle").get("exec"); + + assertThat(execBlock.get("func_type").asText()).isEqualTo("PythonFunction"); + assertThat(execBlock.get("module").asText()) + .isEqualTo("flink_agents.plan.tests.test_agent_plan_cross_language"); + assertThat(execBlock.get("qualname").asText()).isEqualTo("_dummy_action"); + } + + // ── Plan JSON round-trip (Java side) ─────────────────────────────────── + + @Test + void javaPlanWithJavaActionRoundTripsThroughJson() throws Exception { + AgentPlan plan = compileWithJavaAction(); + String json = MAPPER.writeValueAsString(plan); + AgentPlan restored = MAPPER.readValue(json, AgentPlan.class); + + Action action = restored.getActions().get("handle"); + assertThat(action.getExec()).isInstanceOf(JavaFunction.class); + JavaFunction jf = (JavaFunction) action.getExec(); + assertThat(jf.getQualName()).isEqualTo(AgentPlanCrossLanguageTest.class.getName()); + assertThat(jf.getMethodName()).isEqualTo("handle"); + } + + @Test + void javaPlanWithPythonActionRoundTripsThroughJson() throws Exception { + AgentPlan plan = compileWithPythonAction(); + String json = MAPPER.writeValueAsString(plan); + AgentPlan restored = MAPPER.readValue(json, AgentPlan.class); + + Action action = restored.getActions().get("handle"); + assertThat(action.getExec()).isInstanceOf(PythonFunction.class); + PythonFunction pf = (PythonFunction) action.getExec(); + assertThat(pf.getModule()) + .isEqualTo("flink_agents.plan.tests.test_agent_plan_cross_language"); + assertThat(pf.getQualName()).isEqualTo("_dummy_action"); + } + + // ── Cross-language snapshot (Java writes / Python reads) ─────────────── + + @Test + void regenerateJavaPlanWithPythonActionSnapshot() throws Exception { + assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to refresh."); + AgentPlan plan = compileWithPythonAction(); + String json = MAPPER.writerWithDefaultPrettyPrinter().writeValueAsString(plan); + + Path target = snapshotDir.resolve("java/agent_plan_with_python_action.json"); + Files.createDirectories(target.getParent()); + Files.writeString(target, json + "\n"); + } + + @Test + void javaPlanWithPythonActionSnapshotIsStable() throws Exception { + Path committed = snapshotDir.resolve("java/agent_plan_with_python_action.json"); + assertTrue( + Files.exists(committed), + "Java plan snapshot missing from " + + committed + + ". Regenerate with -Dregenerate.snapshots=true and commit alongside the test."); + + AgentPlan plan = compileWithPythonAction(); + JsonNode actual = MAPPER.readTree(MAPPER.writeValueAsString(plan)); + JsonNode expected = MAPPER.readTree(Files.readString(committed)); + assertThat(actual).isEqualTo(expected); + } + + @Test + void javaCanDeserializePythonPlanWithJavaAction() throws Exception { + Path snapshot = snapshotDir.resolve("python/agent_plan_with_java_action.json"); + assertTrue( + Files.exists(snapshot), + "Python plan snapshot missing from " + + snapshot + + ". Regenerate the Python side with REGENERATE_SNAPSHOTS=1 and commit alongside this test."); + + String json = Files.readString(snapshot); + AgentPlan restored = MAPPER.readValue(json, AgentPlan.class); + + Action handle = restored.getActions().get("handle"); + assertThat(handle).isNotNull(); + assertThat(handle.getExec()).isInstanceOf(JavaFunction.class); + JavaFunction jf = (JavaFunction) handle.getExec(); + assertThat(jf.getQualName()).isEqualTo("com.example.Handlers"); + assertThat(jf.getMethodName()).isEqualTo("handle"); + } +} diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java index 7e7d08705..acdf14343 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanTest.java @@ -45,6 +45,7 @@ import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Test for {@link AgentPlan} constructor that takes an Agent. */ public class AgentPlanTest { @@ -246,6 +247,126 @@ public void testConstructorWithAgentNoActions() throws Exception { assertThat(agentPlan.getActionsByEvent().size()).isEqualTo(4); } + @Test + public void testBuiltInActionsAreJavaNativeAfterCompile() throws Exception { + AgentPlan agentPlan = new AgentPlan(new Agent() {}); + + for (String name : + List.of("chat_model_action", "tool_call_action", "context_retrieval_action")) { + Action action = agentPlan.getActions().get(name); + assertThat(action).isNotNull(); + assertThat(action.getExec()).isInstanceOf(JavaFunction.class); + } + } + + /** Cross-language action via {@code @Action(target = @PythonFunction(...))}. */ + public static class AgentWithCrossLanguageAction extends Agent { + @org.apache.flink.agents.api.annotation.Action( + listenEventTypes = {InputEvent.EVENT_TYPE}, + target = + @org.apache.flink.agents.api.annotation.PythonFunction( + module = "my_pkg.handlers", + qualname = "handle_input")) + public static void handle(Event event, RunnerContext ctx) { + throw new UnsupportedOperationException("cross-language stub"); + } + } + + @Test + public void testActionWithPythonTargetCompilesToPythonFunctionExec() throws Exception { + AgentPlan plan = new AgentPlan(new AgentWithCrossLanguageAction()); + + Action action = plan.getActions().get("handle"); + assertThat(action).isNotNull(); + assertThat(action.getExec()) + .as("non-empty target.module() must compile to a plan PythonFunction exec") + .isInstanceOf(org.apache.flink.agents.plan.PythonFunction.class); + + org.apache.flink.agents.plan.PythonFunction exec = + (org.apache.flink.agents.plan.PythonFunction) action.getExec(); + assertThat(exec.getModule()).isEqualTo("my_pkg.handlers"); + assertThat(exec.getQualName()).isEqualTo("handle_input"); + assertThat(action.getListenEventTypes()).containsExactly(InputEvent.EVENT_TYPE); + } + + /** Plain {@code @Action} (no {@code target}) compiles to a native Java exec. */ + public static class AgentWithNativeJavaAction extends Agent { + @org.apache.flink.agents.api.annotation.Action(listenEventTypes = {InputEvent.EVENT_TYPE}) + public static void handle(Event event, RunnerContext ctx) { + // intentionally empty + } + } + + @Test + public void testActionWithEmptyTargetCompilesToJavaFunctionExec() throws Exception { + AgentPlan plan = new AgentPlan(new AgentWithNativeJavaAction()); + + Action action = plan.getActions().get("handle"); + assertThat(action).isNotNull(); + assertThat(action.getExec()) + .as("empty target.module() must compile to a plan JavaFunction exec") + .isInstanceOf(JavaFunction.class); + } + + /** Partially-set target (module without qualname) — must be rejected at compile. */ + public static class AgentWithHalfSetPythonTargetMissingQualname extends Agent { + @org.apache.flink.agents.api.annotation.Action( + listenEventTypes = {InputEvent.EVENT_TYPE}, + target = @org.apache.flink.agents.api.annotation.PythonFunction(module = "pkg")) + public static void handle(Event event, RunnerContext ctx) { + throw new UnsupportedOperationException("cross-language stub"); + } + } + + @Test + public void testActionWithPythonTargetMissingQualnameIsRejected() { + assertThatThrownBy(() -> new AgentPlan(new AgentWithHalfSetPythonTargetMissingQualname())) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("handle") + .hasMessageContaining("qualname"); + } + + /** Partially-set target (qualname without module) — must be rejected at compile. */ + public static class AgentWithHalfSetPythonTargetMissingModule extends Agent { + @org.apache.flink.agents.api.annotation.Action( + listenEventTypes = {InputEvent.EVENT_TYPE}, + target = + @org.apache.flink.agents.api.annotation.PythonFunction( + qualname = "handle_input")) + public static void handle(Event event, RunnerContext ctx) { + throw new UnsupportedOperationException("cross-language stub"); + } + } + + @Test + public void testActionWithPythonTargetMissingModuleIsRejected() { + assertThatThrownBy(() -> new AgentPlan(new AgentWithHalfSetPythonTargetMissingModule())) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("handle") + .hasMessageContaining("module"); + } + + /** + * @Action declared on a parent agent class — must be rejected loudly, not silently dropped. + */ + public abstract static class BaseAgentWithInheritedAction extends Agent { + @org.apache.flink.agents.api.annotation.Action(listenEventTypes = {InputEvent.EVENT_TYPE}) + public static void sharedAction(Event event, RunnerContext ctx) { + throw new UnsupportedOperationException("test stub"); + } + } + + public static class ConcreteAgentInheritingAction extends BaseAgentWithInheritedAction {} + + @Test + public void testActionInheritedFromParentAgentClassIsRejected() { + assertThatThrownBy(() -> new AgentPlan(new ConcreteAgentInheritingAction())) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("sharedAction") + .hasMessageContaining("BaseAgentWithInheritedAction") + .hasMessageContaining("Inherited @Action"); + } + @Test public void testAgentPlanResourceProviders() throws Exception { // Test that AgentPlan initializes resource providers correctly diff --git a/plan/src/test/java/org/apache/flink/agents/plan/PlanFunctionDispatchTest.java b/plan/src/test/java/org/apache/flink/agents/plan/PlanFunctionDispatchTest.java new file mode 100644 index 000000000..e11a80ac3 --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/PlanFunctionDispatchTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.context.RunnerContext; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Dispatch tests for plan-layer {@link Function} invocation. */ +class PlanFunctionDispatchTest { + + private static int invocationCount; + + public static void handle(Event event, RunnerContext ctx) { + invocationCount += 1; + } + + @Test + void javaFunctionDispatchInvokesUnderlyingMethodWithPositionalArgs() throws Exception { + invocationCount = 0; + JavaFunction fn = + new JavaFunction( + PlanFunctionDispatchTest.class, + "handle", + new Class[] {Event.class, RunnerContext.class}); + + fn.call(new InputEvent(new HashMap<>()), null); + + assertThat(invocationCount).isEqualTo(1); + } + + @Test + void pythonFunctionDispatchFailsWithoutInterpreter() { + PythonFunction fn = new PythonFunction("test.module", "test_handler"); + + assertThatThrownBy(() -> fn.call(new InputEvent(new HashMap<>()), null)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("PythonFunction requires the Python interpreter"); + } + + @Test + void pythonFunctionCheckSignatureIsLazyNoOpForAnyArity() throws Exception { + PythonFunction fn = new PythonFunction("test.module", "test_handler"); + + fn.checkSignature(new Class[] {Event.class, RunnerContext.class}); + fn.checkSignature(new Class[] {}); + fn.checkSignature(new Class[] {Event.class}); + } +} diff --git a/plan/src/test/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializerTest.java b/plan/src/test/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializerTest.java index 49da54fac..8f32d512b 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializerTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/serializer/ActionJsonDeserializerTest.java @@ -18,6 +18,7 @@ package org.apache.flink.agents.plan.serializer; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.InputEvent; @@ -28,9 +29,12 @@ import org.junit.jupiter.api.Test; import java.io.IOException; +import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; /** Test for {@link ActionJsonDeserializer}. */ @@ -106,4 +110,30 @@ public void testDeserializeInvalidEventType() throws IOException { ObjectMapper mapper = new ObjectMapper(); assertThrows(RuntimeException.class, () -> mapper.readValue(json, Action.class)); } + + @Test + public void testDeserializePythonConfigPreservesPrimitiveTypes() throws IOException { + JsonNode node = + new ObjectMapper() + .readTree( + "{\"timeout_sec\": 30," + + " \"big\": 10000000000," + + " \"enabled\": true," + + " \"rate\": 1.5," + + " \"label\": \"fast\"," + + " \"extra\": null}"); + + @SuppressWarnings("unchecked") + Map result = + (Map) ActionJsonDeserializer.deserializePythonConfig(node); + + assertThat(result) + .containsEntry("timeout_sec", 30) + .containsEntry("big", 10_000_000_000L) + .containsEntry("enabled", true) + .containsEntry("rate", 1.5) + .containsEntry("label", "fast") + .containsKey("extra"); + assertNull(result.get("extra")); + } } diff --git a/python/flink_agents/api/decorators.py b/python/flink_agents/api/decorators.py index b4f0117ba..6025ac393 100644 --- a/python/flink_agents/api/decorators.py +++ b/python/flink_agents/api/decorators.py @@ -17,8 +17,25 @@ ################################################################################# from typing import Callable, Type +from flink_agents.api.function import Function, JavaFunction, PythonFunction -def action(*listen_events: str) -> Callable: + +def _validate_target(target: Function, owner: str) -> None: + """Reject targets with empty required identifiers, attributed to ``owner``.""" + if isinstance(target, PythonFunction): + if not target.module or not target.qualname: + msg = f"PythonFunction target on '{owner}' must set both module and qualname" + raise ValueError(msg) + elif isinstance(target, JavaFunction): + if not target.qualname or not target.method_name: + msg = f"JavaFunction target on '{owner}' must set both qualname and method_name" + raise ValueError(msg) + + +def action( + *listen_events: str, + target: Function | None = None, +) -> Callable: """Decorator for marking a function as an agent action. Each argument is a type-identifier string that this action responds to. @@ -27,6 +44,10 @@ def action(*listen_events: str) -> Callable: ---------- listen_events : str Type-identifier strings that this action responds to. + target : Function, optional + Cross-language function descriptor dispatched instead of the + decorated body. The body becomes a stub — raise + ``NotImplementedError`` so direct calls fail loud. Returns: ------- @@ -37,6 +58,8 @@ def action(*listen_events: str) -> Callable: ------ AssertionError If no events are provided or if an argument is not a string. + TypeError + If ``target`` is provided but is not a :class:`Function` descriptor. """ assert len(listen_events) > 0, ( "action must have at least one event type to listen to" @@ -47,7 +70,17 @@ def action(*listen_events: str) -> Callable: f"action must listen to string type identifiers, got {evt!r}" ) + if target is not None and not isinstance(target, Function): + msg = ( + f"action(target=...) must be an api-layer Function descriptor, " + f"got {type(target).__name__}" + ) + raise TypeError(msg) + def decorator(func: Callable) -> Callable: + if target is not None: + _validate_target(target, func.__qualname__) + func._target = target func._listen_events = listen_events return func diff --git a/python/flink_agents/api/events/chat_event.py b/python/flink_agents/api/events/chat_event.py index 3423f70b4..30790b28a 100644 --- a/python/flink_agents/api/events/chat_event.py +++ b/python/flink_agents/api/events/chat_event.py @@ -72,11 +72,13 @@ def from_event(cls, event: Event) -> "ChatRequestEvent": output_schema_raw = event.attributes.get("output_schema") if isinstance(output_schema_raw, dict): output_schema_raw = OutputSchema.model_validate(output_schema_raw) - return ChatRequestEvent( + result = ChatRequestEvent( model=event.attributes["model"], messages=messages, output_schema=output_schema_raw, ) + result.id = event.id + return result @property def model(self) -> str: @@ -140,17 +142,20 @@ def from_event(cls, event: Event) -> "ChatResponseEvent": if isinstance(response_raw, dict) else response_raw ) - return ChatResponseEvent( + result = ChatResponseEvent( request_id=event.attributes["request_id"], response=response, retry_count=event.attributes.get("retry_count", 0), total_retry_wait_sec=event.attributes.get("total_retry_wait_sec", 0), ) + result.id = event.id + return result @property def request_id(self) -> UUID: """Return the request event ID.""" - return self.get_attr("request_id") + val = self.get_attr("request_id") + return UUID(val) if isinstance(val, str) else val @property def response(self) -> ChatMessage: diff --git a/python/flink_agents/api/events/context_retrieval_event.py b/python/flink_agents/api/events/context_retrieval_event.py index c67a03e97..a8245a4ec 100644 --- a/python/flink_agents/api/events/context_retrieval_event.py +++ b/python/flink_agents/api/events/context_retrieval_event.py @@ -58,11 +58,13 @@ def __init__(self, query: str, vector_store: str, max_results: int = 3) -> None: def from_event(cls, event: Event) -> "ContextRetrievalRequestEvent": assert "query" in event.attributes assert "vector_store" in event.attributes - return ContextRetrievalRequestEvent( + result = ContextRetrievalRequestEvent( query=event.attributes["query"], vector_store=event.attributes["vector_store"], max_results=event.attributes.get("max_results", 3), ) + result.id = event.id + return result @property def query(self) -> str: @@ -117,16 +119,19 @@ def from_event(cls, event: Event) -> "ContextRetrievalResponseEvent": Document.model_validate(d) if isinstance(d, dict) else d for d in documents_raw ] - return ContextRetrievalResponseEvent( + result = ContextRetrievalResponseEvent( request_id=event.attributes["request_id"], query=event.attributes["query"], documents=documents, ) + result.id = event.id + return result @property def request_id(self) -> UUID: """Return the request event ID.""" - return self.get_attr("request_id") + val = self.get_attr("request_id") + return UUID(val) if isinstance(val, str) else val @property def query(self) -> str: diff --git a/python/flink_agents/api/events/event.py b/python/flink_agents/api/events/event.py index f0dd480f1..77e7e5faf 100644 --- a/python/flink_agents/api/events/event.py +++ b/python/flink_agents/api/events/event.py @@ -199,7 +199,9 @@ def __init__(self, input: Any) -> None: @override def from_event(cls, event: Event) -> "InputEvent": assert "input" in event.attributes - return InputEvent(input=event.attributes["input"]) + result = InputEvent(input=event.attributes["input"]) + result.id = event.id + return result @property def input(self) -> Any: @@ -230,7 +232,9 @@ def __init__(self, output: Any) -> None: @override def from_event(cls, event: Event) -> "OutputEvent": assert "output" in event.attributes - return OutputEvent(output=event.attributes["output"]) + result = OutputEvent(output=event.attributes["output"]) + result.id = event.id + return result @property def output(self) -> Any: diff --git a/python/flink_agents/api/events/tool_event.py b/python/flink_agents/api/events/tool_event.py index 64454868f..9acc865a4 100644 --- a/python/flink_agents/api/events/tool_event.py +++ b/python/flink_agents/api/events/tool_event.py @@ -54,10 +54,12 @@ def __init__(self, model: str, tool_calls: List[Dict[str, Any]]) -> None: def from_event(cls, event: Event) -> "ToolRequestEvent": assert "model" in event.attributes assert "tool_calls" in event.attributes - return ToolRequestEvent( + result = ToolRequestEvent( model=event.attributes["model"], tool_calls=event.attributes["tool_calls"], ) + result.id = event.id + return result @property def model(self) -> str: @@ -108,16 +110,19 @@ def from_event(cls, event: Event) -> "ToolResponseEvent": assert "request_id" in event.attributes assert "responses" in event.attributes assert "external_ids" in event.attributes - return ToolResponseEvent( + result = ToolResponseEvent( request_id=event.attributes["request_id"], responses=event.attributes["responses"], external_ids=event.attributes["external_ids"], ) + result.id = event.id + return result @property def request_id(self) -> UUID: """Return the request event ID.""" - return self.get_attr("request_id") + val = self.get_attr("request_id") + return UUID(val) if isinstance(val, str) else val @property def responses(self) -> Dict[UUID, Any]: diff --git a/python/flink_agents/api/tests/test_agent_add_action.py b/python/flink_agents/api/tests/test_agent_add_action.py new file mode 100644 index 000000000..98b448b03 --- /dev/null +++ b/python/flink_agents/api/tests/test_agent_add_action.py @@ -0,0 +1,155 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Layer A API-surface tests for ``Agent.add_action`` (Python side). Mirrors Java's ``AgentAddActionTest``.""" + +import pytest + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.events.event import Event, InputEvent +from flink_agents.api.function import JavaFunction, PythonFunction +from flink_agents.api.runner_context import RunnerContext + + +def _dummy_action(event: Event, ctx: RunnerContext) -> None: + """Plain Python callable used as an action body.""" + + +def _make_java_function() -> JavaFunction: + return JavaFunction( + qualname="com.example.Handlers", + method_name="handle", + parameter_types=[ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ], + ) + + +# ── Descriptor pass-through ───────────────────────────────────────────── + + +def test_add_action_accepts_python_function_descriptor_and_stores_as_is() -> None: + """Pre-built PythonFunction is stored verbatim — no re-wrapping.""" + agent = Agent() + pf = PythonFunction(module="pkg.mod", qualname="MyClass.method") + + agent.add_action(name="act", events=[InputEvent.EVENT_TYPE], func=pf) + + events, stored, config = agent.actions["act"] + assert events == [InputEvent.EVENT_TYPE] + assert stored is pf, "PythonFunction descriptor should not be re-wrapped." + assert config is None + + +def test_add_action_accepts_java_function_descriptor_and_stores_as_is() -> None: + """Cross-language case: JavaFunction is stored verbatim, never wrapped.""" + agent = Agent() + jf = _make_java_function() + + agent.add_action(name="act", events=[InputEvent.EVENT_TYPE], func=jf) + + events, stored, _ = agent.actions["act"] + assert events == [InputEvent.EVENT_TYPE] + assert stored is jf, "JavaFunction descriptor should be stored as-is." + assert isinstance(stored, JavaFunction) + + +# ── Callable wrapping ─────────────────────────────────────────────────── + + +def test_add_action_wraps_raw_callable_as_python_function() -> None: + """Bare callables get auto-wrapped into a PythonFunction descriptor.""" + agent = Agent() + agent.add_action(name="act", events=[InputEvent.EVENT_TYPE], func=_dummy_action) + + _, stored, _ = agent.actions["act"] + assert isinstance(stored, PythonFunction) + assert stored.qualname == "_dummy_action" + assert stored.module == _dummy_action.__module__ + + +# ── Duplicate name rejection ──────────────────────────────────────────── + + +def test_add_action_duplicate_name_rejected_python_function() -> None: + agent = Agent() + agent.add_action(name="act", events=[InputEvent.EVENT_TYPE], func=_dummy_action) + with pytest.raises(ValueError, match="act"): + agent.add_action( + name="act", events=[InputEvent.EVENT_TYPE], func=_dummy_action + ) + + +def test_add_action_duplicate_name_rejected_java_function() -> None: + """Duplicate-name rejection applies uniformly regardless of descriptor type.""" + agent = Agent() + agent.add_action( + name="act", events=[InputEvent.EVENT_TYPE], func=_make_java_function() + ) + with pytest.raises(ValueError, match="act"): + agent.add_action( + name="act", events=[InputEvent.EVENT_TYPE], func=_make_java_function() + ) + + +# ── Config capture ────────────────────────────────────────────────────── + + +def test_add_action_captures_config_kwargs() -> None: + agent = Agent() + agent.add_action( + name="act", + events=[InputEvent.EVENT_TYPE], + func=_dummy_action, + retry=3, + timeout_sec=10, + ) + + _, _, config = agent.actions["act"] + assert config == {"retry": 3, "timeout_sec": 10} + + +def test_add_action_config_is_none_when_no_kwargs() -> None: + agent = Agent() + agent.add_action(name="act", events=[InputEvent.EVENT_TYPE], func=_dummy_action) + + _, _, config = agent.actions["act"] + assert config is None + + +def test_add_action_returns_self_for_chaining() -> None: + agent = Agent() + result = agent.add_action( + name="act", events=[InputEvent.EVENT_TYPE], func=_dummy_action + ) + assert result is agent + + +# ── Multi-event subscription ──────────────────────────────────────────── + + +def test_add_action_supports_multiple_event_subscriptions() -> None: + agent = Agent() + agent.add_action( + name="multi", + events=["evt_a", "evt_b", "evt_c"], + func=_make_java_function(), + ) + + events, _, _ = agent.actions["multi"] + assert events == ["evt_a", "evt_b", "evt_c"] diff --git a/python/flink_agents/api/tests/test_cross_language_event_snapshots.py b/python/flink_agents/api/tests/test_cross_language_event_snapshots.py new file mode 100644 index 000000000..b6589c3f1 --- /dev/null +++ b/python/flink_agents/api/tests/test_cross_language_event_snapshots.py @@ -0,0 +1,481 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Cross-language event SerDe snapshot tests.""" + +import json +import os +from pathlib import Path +from typing import ClassVar +from uuid import UUID + +import pytest + +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent +from flink_agents.api.events.context_retrieval_event import ( + ContextRetrievalRequestEvent, + ContextRetrievalResponseEvent, +) +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent +from flink_agents.api.vector_stores.vector_store import Document + +_REPO_ROOT = Path(__file__).resolve().parents[4] +_SNAPSHOT_DIR = _REPO_ROOT / "e2e-test" / "cross-language-event-snapshots" + +_FIXED_EVENT_ID = UUID("00000000-0000-0000-0000-000000000001") +_FIXED_REQUEST_ID = UUID("00000000-0000-0000-0000-000000000002") +_FIXED_TOOL_CALL_ID = "call_aaaa" +_FIXED_TOOL_CALL_ID_NUMERIC = "call_bbbb" +_FIXED_TOOL_CALL_ID_BOOL = "call_cccc" + + +def _regenerate_enabled() -> bool: + return os.environ.get("REGENERATE_SNAPSHOTS", "").lower() in {"1", "true", "yes"} + + +def _force_id(event: Event, fixed_id: UUID) -> Event: + object.__setattr__(event, "id", fixed_id) + return event + + +def _write_python_snapshot(name: str, event: Event) -> None: + target = _SNAPSHOT_DIR / "python" / name + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(event.model_dump_json(indent=2) + "\n") + + +def _assert_python_snapshot_stable(name: str, event: Event) -> None: + actual = json.loads(event.model_dump_json()) + committed_path = _SNAPSHOT_DIR / "python" / name + assert committed_path.exists(), ( + f"Python snapshot {name} missing from {committed_path}. " + f"If you added a new event, regenerate with REGENERATE_SNAPSHOTS=1 " + f"and commit alongside the test." + ) + expected = json.loads(committed_path.read_text()) + assert actual == expected, ( + f"Python serialization of {name} drifted from committed snapshot." + ) + + +def _read_java_snapshot(name: str) -> Event: + java_snapshot = _SNAPSHOT_DIR / "java" / name + assert java_snapshot.exists(), ( + f"Java snapshot {name} missing from {java_snapshot}. " + f"Regenerate the Java side with -Dregenerate.snapshots=true " + f"and commit alongside this test." + ) + return Event.from_json(java_snapshot.read_text()) + + +# ── InputEvent ────────────────────────────────────────────────────────── + + +def _build_input_event() -> InputEvent: + return _force_id(InputEvent(input="hello"), _FIXED_EVENT_ID) + + +def test_regenerate_input_event_python_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot("input_event.json", _build_input_event()) + + +def test_input_event_python_snapshot_is_stable() -> None: + _assert_python_snapshot_stable("input_event.json", _build_input_event()) + + +def test_python_can_deserialize_input_event_from_java_snapshot() -> None: + base = _read_java_snapshot("input_event.json") + typed = InputEvent.from_event(base) + assert typed.input == "hello", "InputEvent.input mismatch." + assert typed.type == InputEvent.EVENT_TYPE + + +# ── OutputEvent ───────────────────────────────────────────────────────── + + +def _build_output_event() -> OutputEvent: + return _force_id(OutputEvent(output="world"), _FIXED_EVENT_ID) + + +def test_regenerate_output_event_python_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot("output_event.json", _build_output_event()) + + +def test_output_event_python_snapshot_is_stable() -> None: + _assert_python_snapshot_stable("output_event.json", _build_output_event()) + + +def test_python_can_deserialize_output_event_from_java_snapshot() -> None: + base = _read_java_snapshot("output_event.json") + typed = OutputEvent.from_event(base) + assert typed.output == "world", "OutputEvent.output mismatch." + assert typed.type == OutputEvent.EVENT_TYPE + + +# ── ChatRequestEvent ──────────────────────────────────────────────────── + + +def _build_chat_request_event() -> ChatRequestEvent: + event = ChatRequestEvent( + model="test-model", + messages=[ChatMessage(role=MessageRole.USER, content="hello world")], + ) + return _force_id(event, _FIXED_EVENT_ID) + + +def test_regenerate_chat_request_event_python_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot("chat_request_event.json", _build_chat_request_event()) + + +def test_chat_request_event_python_snapshot_is_stable() -> None: + _assert_python_snapshot_stable( + "chat_request_event.json", _build_chat_request_event() + ) + + +def test_python_can_deserialize_chat_request_event_from_java_snapshot() -> None: + base = _read_java_snapshot("chat_request_event.json") + typed = ChatRequestEvent.from_event(base) + assert typed.model == "test-model" + assert len(typed.messages) == 1 + msg = typed.messages[0] + assert msg.role == MessageRole.USER, f"Role mismatch: got {msg.role!r}" + assert msg.content == "hello world" + + +def test_chat_request_row_type_info_output_schema_is_not_portable_across_languages_known_gap() -> None: + """Known 0.3 gap — RowTypeInfo-typed output_schema does not round-trip across the language + boundary. Python emits ``{"names": [...], "types": []}`` while Java + emits ``{"fieldNames": [...], "types": []}``, so a ChatRequestEvent carrying a + RowTypeInfo schema cannot be deserialized on the other side. The BaseModel (Pydantic class) + branch is symmetric and works. Reconciling the RowTypeInfo wire format requires a canonical + shape + bilateral OutputSchema serdes shims; tracked as a follow-up. + """ + from pyflink.common.typeinfo import BasicTypeInfo, RowTypeInfo + + from flink_agents.api.agents.types import OutputSchema + + schema = OutputSchema( + output_schema=RowTypeInfo( + field_types=[BasicTypeInfo.STRING_TYPE_INFO()], + field_names=["name"], + ), + ) + event = ChatRequestEvent( + model="test-model", + messages=[ChatMessage(role=MessageRole.USER, content="hi")], + output_schema=schema, + ) + payload = event.model_dump_json() + # Pin Python's local shape so a future regression can't silently change it. The gap with + # Java's `{"fieldNames": ...}` shape is the documented limitation, not the assertion. + assert "\"names\"" in payload + assert "\"fieldNames\"" not in payload + + +# ── ChatResponseEvent ─────────────────────────────────────────────────── + + +def _build_chat_response_event() -> ChatResponseEvent: + event = ChatResponseEvent( + request_id=_FIXED_REQUEST_ID, + response=ChatMessage(role=MessageRole.ASSISTANT, content="hi there"), + ) + return _force_id(event, _FIXED_EVENT_ID) + + +def test_regenerate_chat_response_event_python_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot("chat_response_event.json", _build_chat_response_event()) + + +def test_chat_response_event_python_snapshot_is_stable() -> None: + _assert_python_snapshot_stable( + "chat_response_event.json", _build_chat_response_event() + ) + + +def test_python_can_deserialize_chat_response_event_from_java_snapshot() -> None: + base = _read_java_snapshot("chat_response_event.json") + typed = ChatResponseEvent.from_event(base) + expected_request_id = str(_FIXED_REQUEST_ID) + actual_request_id = ( + str(typed.request_id) if not isinstance(typed.request_id, str) else typed.request_id + ) + assert actual_request_id == expected_request_id, "request_id mismatch." + assert typed.response is not None, "response is None." + assert typed.response.role == MessageRole.ASSISTANT, ( + f"Response role mismatch: got {typed.response.role!r}" + ) + assert typed.response.content == "hi there" + + +# ── ToolRequestEvent ──────────────────────────────────────────────────── + + +def _build_tool_request_event() -> ToolRequestEvent: + tool_call = {"id": _FIXED_TOOL_CALL_ID, "name": "echo", "arguments": {"value": "ping"}} + event = ToolRequestEvent(model="test-model", tool_calls=[tool_call]) + return _force_id(event, _FIXED_EVENT_ID) + + +def test_regenerate_tool_request_event_python_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot("tool_request_event.json", _build_tool_request_event()) + + +def test_tool_request_event_python_snapshot_is_stable() -> None: + _assert_python_snapshot_stable( + "tool_request_event.json", _build_tool_request_event() + ) + + +def test_python_can_deserialize_tool_request_event_from_java_snapshot() -> None: + base = _read_java_snapshot("tool_request_event.json") + typed = ToolRequestEvent.from_event(base) + assert typed.model == "test-model" + assert len(typed.tool_calls) == 1 + assert typed.tool_calls[0]["id"] == _FIXED_TOOL_CALL_ID + + +# ── ToolResponseEvent ─────────────────────────────────────────────────── + + +def _build_tool_response_event() -> ToolResponseEvent: + # Mixed scalar value types pin the Python -> Java round-trip on the Java + # ToolResponseEvent.fromEvent fall-through that wraps non-ToolResponse/Map + # values via ToolResponse.success(v). + event = ToolResponseEvent( + request_id=_FIXED_REQUEST_ID, + responses={ + _FIXED_TOOL_CALL_ID: "pong", + _FIXED_TOOL_CALL_ID_NUMERIC: 42, + _FIXED_TOOL_CALL_ID_BOOL: True, + }, + external_ids={ + _FIXED_TOOL_CALL_ID: None, + _FIXED_TOOL_CALL_ID_NUMERIC: None, + _FIXED_TOOL_CALL_ID_BOOL: None, + }, + ) + return _force_id(event, _FIXED_EVENT_ID) + + +def test_regenerate_tool_response_event_python_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot("tool_response_event.json", _build_tool_response_event()) + + +def test_tool_response_event_python_snapshot_is_stable() -> None: + _assert_python_snapshot_stable( + "tool_response_event.json", _build_tool_response_event() + ) + + +def test_java_tool_response_event_is_shape_mismatched_when_consumed_by_python() -> None: + base = _read_java_snapshot("tool_response_event.json") + typed = ToolResponseEvent.from_event(base) + + assert typed.request_id == _FIXED_REQUEST_ID + + response_value = typed.responses[_FIXED_TOOL_CALL_ID] + assert isinstance(response_value, dict) + assert "result" in response_value + + assert "success" not in typed.attributes + assert "error" not in typed.attributes + assert "timestamp" not in typed.attributes + + +# ── ContextRetrievalRequestEvent ──────────────────────────────────────── + + +def _build_context_retrieval_request_event() -> ContextRetrievalRequestEvent: + event = ContextRetrievalRequestEvent( + query="what is flink", + vector_store="test-store", + max_results=5, + ) + return _force_id(event, _FIXED_EVENT_ID) + + +def test_regenerate_context_retrieval_request_event_python_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot( + "context_retrieval_request_event.json", + _build_context_retrieval_request_event(), + ) + + +def test_context_retrieval_request_event_python_snapshot_is_stable() -> None: + _assert_python_snapshot_stable( + "context_retrieval_request_event.json", + _build_context_retrieval_request_event(), + ) + + +def test_python_can_deserialize_context_retrieval_request_event_from_java_snapshot() -> None: + base = _read_java_snapshot("context_retrieval_request_event.json") + typed = ContextRetrievalRequestEvent.from_event(base) + assert typed.query == "what is flink" + assert typed.vector_store == "test-store" + assert typed.max_results == 5 + + +# ── ContextRetrievalResponseEvent ─────────────────────────────────────── + + +def _build_context_retrieval_response_event() -> ContextRetrievalResponseEvent: + doc = Document(content="doc content", metadata={"k": "v"}, id="doc-1") + event = ContextRetrievalResponseEvent( + request_id=_FIXED_REQUEST_ID, + query="what is flink", + documents=[doc], + ) + return _force_id(event, _FIXED_EVENT_ID) + + +def test_regenerate_context_retrieval_response_event_python_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot( + "context_retrieval_response_event.json", + _build_context_retrieval_response_event(), + ) + + +def test_context_retrieval_response_event_python_snapshot_is_stable() -> None: + _assert_python_snapshot_stable( + "context_retrieval_response_event.json", + _build_context_retrieval_response_event(), + ) + + +def test_python_can_deserialize_context_retrieval_response_event_from_java_snapshot() -> None: + base = _read_java_snapshot("context_retrieval_response_event.json") + typed = ContextRetrievalResponseEvent.from_event(base) + expected_request_id = str(_FIXED_REQUEST_ID) + actual_request_id = ( + str(typed.request_id) if not isinstance(typed.request_id, str) else typed.request_id + ) + assert actual_request_id == expected_request_id + assert typed.query == "what is flink" + assert len(typed.documents) == 1 + assert typed.documents[0].content == "doc content" + assert typed.documents[0].id == "doc-1" + + +# ── Generic Event with primitive attributes (user-authored axis) ─────── + + +_GENERIC_EVENT_TYPE = "_my_custom_event" +_GENERIC_EVENT_ATTRS = { + "k_int": 42, + "k_float": 1.5, + "k_bool": True, + "k_str": "hello", + "k_null": None, + "k_list": [1, 2, 3], + "k_dict": {"nested": "value"}, +} + + +def _build_generic_event() -> Event: + return _force_id( + Event(type=_GENERIC_EVENT_TYPE, attributes=dict(_GENERIC_EVENT_ATTRS)), + _FIXED_EVENT_ID, + ) + + +def test_regenerate_generic_event_python_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot("generic_event_with_attrs.json", _build_generic_event()) + + +def test_generic_event_python_snapshot_is_stable() -> None: + _assert_python_snapshot_stable( + "generic_event_with_attrs.json", _build_generic_event() + ) + + +def test_python_can_deserialize_generic_event_from_java_snapshot() -> None: + base = _read_java_snapshot("generic_event_with_attrs.json") + + assert base.type == _GENERIC_EVENT_TYPE + assert base.attributes["k_int"] == 42 + assert isinstance(base.attributes["k_int"], int) + assert base.attributes["k_float"] == 1.5 + assert isinstance(base.attributes["k_float"], float) + assert base.attributes["k_bool"] is True + assert base.attributes["k_str"] == "hello" + assert base.attributes["k_null"] is None + assert base.attributes["k_list"] == [1, 2, 3] + assert base.attributes["k_dict"] == {"nested": "value"} + + +# ── Python-only subclass with no Java counterpart (graceful fallback) ── + + +class _MyPythonOnlyEvent(Event): + EVENT_TYPE: ClassVar[str] = "_my_python_only_event" + + def __init__(self, value: str, count: int) -> None: + super().__init__( + type=_MyPythonOnlyEvent.EVENT_TYPE, + attributes={"value": value, "count": count}, + ) + + +def _build_python_only_subclass_event() -> _MyPythonOnlyEvent: + return _force_id( + _MyPythonOnlyEvent(value="ping", count=7), + _FIXED_EVENT_ID, + ) + + +def test_regenerate_python_only_subclass_event_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + _write_python_snapshot( + "python_only_subclass_event.json", _build_python_only_subclass_event() + ) + + +def test_python_only_subclass_event_snapshot_is_stable() -> None: + _assert_python_snapshot_stable( + "python_only_subclass_event.json", _build_python_only_subclass_event() + ) + + +# ── Smoke ─────────────────────────────────────────────────────────────── + + +def test_snapshot_directory_exists() -> None: + assert _SNAPSHOT_DIR.is_dir(), f"Expected snapshot directory at {_SNAPSHOT_DIR}" diff --git a/python/flink_agents/api/tests/test_decorators.py b/python/flink_agents/api/tests/test_decorators.py index 4f4df0cf7..6a4b131fc 100644 --- a/python/flink_agents/api/tests/test_decorators.py +++ b/python/flink_agents/api/tests/test_decorators.py @@ -19,6 +19,7 @@ from flink_agents.api.decorators import action from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.function import JavaFunction, PythonFunction from flink_agents.api.runner_context import RunnerContext @@ -90,3 +91,87 @@ def test_action_decorator_rejects_invalid_types() -> None: @action(42) # type: ignore[arg-type] def bad_handler(event: Event, ctx: RunnerContext) -> None: pass + + +def _java_target() -> JavaFunction: + return JavaFunction( + qualname="com.example.Handlers", + method_name="handle", + parameter_types=[ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ], + ) + + +def test_action_decorator_with_cross_language_target() -> None: + target = _java_target() + + @action(InputEvent.EVENT_TYPE, target=target) + def stub(event: Event, ctx: RunnerContext) -> None: + msg = "cross-language stub" + raise NotImplementedError(msg) + + assert stub._listen_events == (InputEvent.EVENT_TYPE,) + assert stub._target is target + + +def test_action_decorator_rejects_non_function_target() -> None: + with pytest.raises(TypeError, match="api-layer Function descriptor"): + + @action(InputEvent.EVENT_TYPE, target="not a function") # type: ignore[arg-type] + def stub(event: Event, ctx: RunnerContext) -> None: + pass + + +def test_action_decorator_without_target_does_not_set_attribute() -> None: + @action(InputEvent.EVENT_TYPE) + def regular(event: Event, ctx: RunnerContext) -> None: + pass + + assert not hasattr(regular, "_target") + + +def test_action_decorator_rejects_java_target_with_empty_qualname() -> None: + bad = JavaFunction(qualname="", method_name="handle", parameter_types=[]) + with pytest.raises(ValueError, match="qualname"): + + @action(InputEvent.EVENT_TYPE, target=bad) + def stub(event: Event, ctx: RunnerContext) -> None: + pass + + +def test_action_decorator_rejects_java_target_with_empty_method_name() -> None: + bad = JavaFunction(qualname="com.example.X", method_name="", parameter_types=[]) + with pytest.raises(ValueError, match="method_name"): + + @action(InputEvent.EVENT_TYPE, target=bad) + def stub(event: Event, ctx: RunnerContext) -> None: + pass + + +def test_action_decorator_rejects_python_target_with_empty_module() -> None: + bad = PythonFunction(module="", qualname="handle") + with pytest.raises(ValueError, match="module"): + + @action(InputEvent.EVENT_TYPE, target=bad) + def stub(event: Event, ctx: RunnerContext) -> None: + pass + + +def test_action_decorator_rejects_python_target_with_empty_qualname() -> None: + bad = PythonFunction(module="pkg.mod", qualname="") + with pytest.raises(ValueError, match="qualname"): + + @action(InputEvent.EVENT_TYPE, target=bad) + def stub(event: Event, ctx: RunnerContext) -> None: + pass + + +def test_action_decorator_target_error_names_decorated_function() -> None: + bad = PythonFunction(module="pkg.mod", qualname="") + with pytest.raises(ValueError, match="my_named_stub"): + + @action(InputEvent.EVENT_TYPE, target=bad) + def my_named_stub(event: Event, ctx: RunnerContext) -> None: + pass diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/python_action_handler.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/python_action_handler.py new file mode 100644 index 000000000..e613e9d48 --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/python_action_handler.py @@ -0,0 +1,24 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.runner_context import RunnerContext + + +def multiply_by_two(event: Event, ctx: RunnerContext) -> None: + value = InputEvent.from_event(event).input + ctx.send_event(OutputEvent(output=value * 2)) diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/python_agent_with_java_action.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/python_agent_with_java_action.py new file mode 100644 index 000000000..a59b46fed --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/python_agent_with_java_action.py @@ -0,0 +1,56 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Python agent whose action body is a Java static method (Java→Python mirror).""" + +from pyflink.datastream import KeySelector + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.events.event import InputEvent +from flink_agents.api.function import JavaFunction + +JAVA_HANDLER_QUALNAME = "org.apache.flink.agents.resource.test.JavaActionHandler" +JAVA_HANDLER_METHOD = "multiplyByTwo" +JAVA_HANDLER_PARAMETER_TYPES = [ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", +] + + +class PythonAgentWithJavaActionAgent(Agent): + """Python agent that dispatches into ``JavaActionHandler.multiplyByTwo``.""" + + def __init__(self) -> None: + """Create a PythonAgentWithJavaActionAgent.""" + super().__init__() + self.add_action( + name="multiply_by_two", + events=[InputEvent.EVENT_TYPE], + func=JavaFunction( + qualname=JAVA_HANDLER_QUALNAME, + method_name=JAVA_HANDLER_METHOD, + parameter_types=JAVA_HANDLER_PARAMETER_TYPES, + ), + ) + + +class SingleKeySelector(KeySelector): + """Mirror of Java ``JavaAgentWithPythonActionAgent.SingleKeySelector``.""" + + def get_key(self, value: int) -> int: + """Force all records onto a single key.""" + return 0 diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/python_agent_with_java_action_test.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/python_agent_with_java_action_test.py new file mode 100644 index 000000000..0070675bc --- /dev/null +++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/python_agent_with_java_action_test.py @@ -0,0 +1,99 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""E2E mirror of ``JavaAgentWithPythonActionTest``: Python agent + Java action body.""" + +import os +import sysconfig +from pathlib import Path + +import pytest +from pyflink.common import Configuration, Encoder +from pyflink.common.typeinfo import Types +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.datastream.connectors.file_system import StreamingFileSink + +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.e2e_tests.e2e_tests_resource_cross_language.python_agent_with_java_action import ( + PythonAgentWithJavaActionAgent, + SingleKeySelector, +) + +current_dir = Path(__file__).parent +_REPO_ROOT = current_dir.parent.parent.parent.parent +_TEST_JAR = ( + _REPO_ROOT + / "e2e-test" + / "flink-agents-end-to-end-tests-resource-cross-language" + / "target" + / "flink-agents-end-to-end-tests-resource-cross-language-0.3-SNAPSHOT-tests.jar" +) + +os.environ["PYTHONPATH"] = sysconfig.get_paths()["purelib"] + + +@pytest.mark.skipif( + not _TEST_JAR.is_file(), + reason=( + "Cross-language test-jar is missing; run " + "'mvn package -DskipTests -pl e2e-test/" + "flink-agents-end-to-end-tests-resource-cross-language' first." + ), +) +def test_python_agent_dispatches_java_action_body(tmp_path: Path) -> None: + config = Configuration() + config.set_string("python.pythonpath", sysconfig.get_paths()["purelib"]) + env = StreamExecutionEnvironment.get_execution_environment(config) + env.set_parallelism(1) + env.add_jars(f"file://{_TEST_JAR}") + + input_stream = env.from_collection([1, 2, 3, 4, 5], type_info=Types.LONG()).map( + lambda x: x + ) + + agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env) + output_datastream = ( + agents_env.from_datastream( + input=input_stream, key_selector=SingleKeySelector() + ) + .apply(PythonAgentWithJavaActionAgent()) + .to_datastream(Types.LONG()) + ) + + result_dir = tmp_path / "results" + result_dir.mkdir(parents=True, exist_ok=True) + output_datastream.map(lambda x: str(x), Types.STRING()).add_sink( + StreamingFileSink.for_row_format( + base_path=str(result_dir.absolute()), + encoder=Encoder.simple_string_encoder(), + ).build() + ) + + agents_env.execute() + + actual: list[int] = [] + for file in result_dir.iterdir(): + if file.is_dir(): + for child in file.iterdir(): + with child.open() as f: + actual.extend(int(line.strip()) for line in f if line.strip()) + elif file.is_file(): + with file.open() as f: + actual.extend(int(line.strip()) for line in f if line.strip()) + + actual.sort() + assert actual == [2, 4, 6, 8, 10], f"unexpected outputs: {actual}" diff --git a/python/flink_agents/plan/actions/action.py b/python/flink_agents/plan/actions/action.py index a059df9a2..96f2f5f0a 100644 --- a/python/flink_agents/plan/actions/action.py +++ b/python/flink_agents/plan/actions/action.py @@ -72,15 +72,25 @@ def __serialize_config(self, config: Dict[str, Any]) -> Dict[str, Any] | None: @model_validator(mode="before") def __custom_deserialize(self) -> "Action": config = self["config"] - if config is not None and _CONFIG_TYPE in config: - self["config"].pop(_CONFIG_TYPE) + if config is None or _CONFIG_TYPE not in config: + return self + config_type = self["config"].pop(_CONFIG_TYPE) + if config_type == "java": for name, value in config.items(): - try: - module = importlib.import_module(value[0]) - clazz = getattr(module, value[1]) - self["config"][name] = clazz.model_validate(value[2]) - except Exception: # noqa : PERF203 - self["config"][name] = value + if ( + isinstance(value, dict) + and "@class" in value + and "value" in value + ): + self["config"][name] = value["value"] + return self + for name, value in config.items(): + try: + module = importlib.import_module(value[0]) + clazz = getattr(module, value[1]) + self["config"][name] = clazz.model_validate(value[2]) + except Exception: # noqa : PERF203 + self["config"][name] = value return self def __init__( diff --git a/python/flink_agents/plan/agent_plan.py b/python/flink_agents/plan/agent_plan.py index 01e946e4a..24aca0caf 100644 --- a/python/flink_agents/plan/agent_plan.py +++ b/python/flink_agents/plan/agent_plan.py @@ -226,6 +226,28 @@ def _resolve_event_type(evt: Any) -> str: raise ValueError(msg) +def _action_marker(value: Any) -> tuple | None: + """Return ``(inner_callable, listen_events, target)`` if ``value`` is an @action. + + ``@action`` may set ``_listen_events`` on the outer wrapper (when ``@action`` + is the outer decorator) or on ``__func__`` (when ``@staticmethod`` is outer + and ``@action`` inner). Accept either by checking both candidates. + """ + inner = value.__func__ if isinstance(value, staticmethod) else value + if not callable(inner): + return None + marker = ( + value + if hasattr(value, "_listen_events") + else inner + if hasattr(inner, "_listen_events") + else None + ) + if marker is None: + return None + return inner, marker._listen_events, getattr(marker, "_target", None) + + def _get_actions(agent: Agent) -> List[Action]: """Extract all registered agent actions from an agent. @@ -239,30 +261,37 @@ def _get_actions(agent: Agent) -> List[Action]: List[Action] List of Action defined in the agent. """ - actions = [] - for name, value in agent.__class__.__dict__.items(): - if isinstance(value, staticmethod) and hasattr(value, "_listen_events"): - actions.append( - Action( - name=name, - exec=PythonFunction.from_callable(value.__func__), - listen_event_types=[ - _resolve_event_type(et) - for et in value._listen_events - ], - ) - ) - elif callable(value) and hasattr(value, "_listen_events"): - actions.append( - Action( - name=name, - exec=PythonFunction.from_callable(value), - listen_event_types=[ - _resolve_event_type(et) - for et in value._listen_events - ], + # __dict__ skips inherited @action methods; reject loudly. + agent_class = agent.__class__ + for parent in agent_class.__mro__[1:]: + if parent is Agent or parent is object: + break + for parent_name, parent_value in parent.__dict__.items(): + if _action_marker(parent_value) is not None: + msg = ( + f"Inherited @action '{parent.__qualname__}.{parent_name}' is " + f"not supported; declare on the concrete agent." ) + raise RuntimeError(msg) + + actions = [] + for name, value in agent_class.__dict__.items(): + marker = _action_marker(value) + if marker is None: + continue + inner, listen_events, target = marker + exec_ = ( + _to_plan_function(target) + if target is not None + else PythonFunction.from_callable(inner) + ) + actions.append( + Action( + name=name, + exec=exec_, + listen_event_types=[_resolve_event_type(et) for et in listen_events], ) + ) for name, action_tuple in agent.actions.items(): actions.append( Action( diff --git a/python/flink_agents/plan/function.py b/python/flink_agents/plan/function.py index 55086414a..28bc5089b 100644 --- a/python/flink_agents/plan/function.py +++ b/python/flink_agents/plan/function.py @@ -272,17 +272,29 @@ def set_java_resource_adapter(self, adapter: Any) -> None: def __call__(self, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any: """Invoke the Java method via the JVM resource adapter. - LLM tool calls always arrive as keyword arguments — positional - ``*args`` are ignored because the Java side reorders parameters - by name via reflection. + Positional args route to ``invokeJavaAction`` (action dispatch); + keyword args route to ``invokeJavaTool`` (LLM tool dispatch). """ if self._j_resource_adapter is None: msg = ( "JavaFunction requires the JVM resource adapter; not set " - "on this descriptor. The runtime should inject it via " + "on this descriptor. The runtime injects it via " "set_java_resource_adapter before invocation." ) raise RuntimeError(msg) + if args and kwargs: + msg = ( + "JavaFunction does not support mixing positional and keyword " + "args; pass one or the other (positional = action, kwargs = tool)." + ) + raise TypeError(msg) + if args: + return self._j_resource_adapter.invokeJavaAction( + self.qualname, + self.method_name, + self.parameter_types, + list(args), + ) return self._j_resource_adapter.invokeJavaTool( self.qualname, self.method_name, @@ -291,7 +303,15 @@ def __call__(self, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any: ) def check_signature(self, *args: Tuple[Any, ...]) -> None: - """Check function signature is legal or not.""" + """Check declared Java parameter arity matches expectations.""" + if len(self.parameter_types) != len(args): + msg = ( + f"JavaFunction {self.qualname}.{self.method_name} declares " + f"{len(self.parameter_types)} parameter type(s) " + f"({self.parameter_types!r}) but the action contract " + f"expects {len(args)}." + ) + raise TypeError(msg) def call_python_function(module: str, qualname: str, func_args: Tuple[Any, ...]) -> Any: diff --git a/python/flink_agents/plan/tests/test_action.py b/python/flink_agents/plan/tests/test_action.py index e9a636ee9..5a75f7186 100644 --- a/python/flink_agents/plan/tests/test_action.py +++ b/python/flink_agents/plan/tests/test_action.py @@ -92,3 +92,31 @@ def test_action_deserialize(action: Action) -> None: func = action.exec assert func.module == "flink_agents.plan.tests.test_action" assert func.qualname == "legal_signature" + + +def test_action_deserialize_java_shape_config_unwraps_primitives() -> None: + json_str = json.dumps( + { + "name": "legal", + "exec": { + "func_type": "PythonFunction", + "module": "flink_agents.plan.tests.test_action", + "qualname": "legal_signature", + }, + "listen_event_types": ["_input_event"], + "config": { + "__config_type__": "java", + "timeout_sec": {"@class": "java.lang.Integer", "value": 30}, + "enabled": {"@class": "java.lang.Boolean", "value": True}, + "rate": {"@class": "java.lang.Double", "value": 1.5}, + "label": {"@class": "java.lang.String", "value": "fast"}, + }, + } + ) + action = Action.model_validate_json(json_str) + assert action.config == { + "timeout_sec": 30, + "enabled": True, + "rate": 1.5, + "label": "fast", + } diff --git a/python/flink_agents/plan/tests/test_agent_plan.py b/python/flink_agents/plan/tests/test_agent_plan.py index 001b1d4de..801223e21 100644 --- a/python/flink_agents/plan/tests/test_agent_plan.py +++ b/python/flink_agents/plan/tests/test_agent_plan.py @@ -36,6 +36,7 @@ BaseEmbeddingModelSetup, ) from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.function import JavaFunction from flink_agents.api.resource import ResourceDescriptor, ResourceType from flink_agents.api.runner_context import RunnerContext from flink_agents.api.vector_stores.vector_store import ( @@ -84,6 +85,95 @@ def test_to_agent_invalid_signature() -> None: AgentPlan.from_agent(agent, AgentConfiguration()) +def test_builtin_actions_are_python_native_after_compile() -> None: + agent_plan = AgentPlan.from_agent(AgentForTest(), AgentConfiguration()) + + for name in ("chat_model_action", "tool_call_action", "context_retrieval_action"): + action = agent_plan.actions[name] + assert isinstance(action.exec, PythonFunction) + + +class AgentWithConventionalDecoratorOrder(Agent): + """`@staticmethod` outer, `@action` inner — the conventional Python order. + + The decorator stack puts ``_listen_events`` on the inner function (i.e. + ``staticmethod.__func__``) rather than on the staticmethod wrapper, so + ``_get_actions`` must unwrap before inspecting attributes. + """ + + @staticmethod + @action(InputEvent.EVENT_TYPE) + def handle(event: Event, ctx: RunnerContext) -> None: + ctx.send_event(OutputEvent(output=InputEvent.from_event(event).input)) + + +def test_conventional_staticmethod_outer_decorator_order_is_registered() -> None: + plan = AgentPlan.from_agent( + AgentWithConventionalDecoratorOrder(), AgentConfiguration() + ) + actions = plan.get_actions(InputEvent.EVENT_TYPE) + assert len(actions) == 1, ( + "Action defined with `@staticmethod` outer / `@action` inner was silently " + "dropped — `_get_actions` should unwrap the staticmethod before checking " + "for `_listen_events`." + ) + assert actions[0].name == "handle" + + +class _BaseAgentWithInheritedAction(Agent): + """Base class with an @action — used to verify the inheritance guard.""" + + @action(InputEvent.EVENT_TYPE) + @staticmethod + def shared_action(event: Event, ctx: RunnerContext) -> None: + ctx.send_event(OutputEvent(output="shared")) + + +class _ConcreteAgentInheritingAction(_BaseAgentWithInheritedAction): + """Concrete agent that inherits ``shared_action`` from the base class.""" + + +def test_action_inherited_from_parent_agent_class_is_rejected() -> None: + with pytest.raises(RuntimeError, match="Inherited @action") as exc: + AgentPlan.from_agent(_ConcreteAgentInheritingAction(), AgentConfiguration()) + assert "shared_action" in str(exc.value) + assert "_BaseAgentWithInheritedAction" in str(exc.value) + + +_JAVA_HANDLER_QUALNAME = ( + "org.apache.flink.agents.runtime.operator." + "CrossLanguageActionRuntimeTest$Handlers" +) + + +class AgentWithCrossLanguageDecoratedAction(Agent): + @action( + InputEvent.EVENT_TYPE, + target=JavaFunction( + qualname=_JAVA_HANDLER_QUALNAME, + method_name="handleInput", + parameter_types=[ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ], + ), + ) + @staticmethod + def handle(event: Event, ctx: RunnerContext) -> None: + msg = "cross-language stub" + raise NotImplementedError(msg) + + +def test_decorated_action_with_target_compiles_to_plan_java_function() -> None: + plan = AgentPlan.from_agent( + AgentWithCrossLanguageDecoratedAction(), AgentConfiguration() + ) + action = plan.actions["handle"] + assert action.exec.qualname == _JAVA_HANDLER_QUALNAME + assert action.exec.method_name == "handleInput" + assert action.listen_event_types == [InputEvent.EVENT_TYPE] + + class MyEvent(Event): """Event for testing purposes.""" diff --git a/python/flink_agents/plan/tests/test_agent_plan_cross_language.py b/python/flink_agents/plan/tests/test_agent_plan_cross_language.py new file mode 100644 index 000000000..9a4b6d347 --- /dev/null +++ b/python/flink_agents/plan/tests/test_agent_plan_cross_language.py @@ -0,0 +1,384 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Layer B plan-compile tests for cross-language ``Function`` descriptors (Python side).""" + +import json +import os +from pathlib import Path + +import pytest + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.events.event import Event, InputEvent +from flink_agents.api.function import ( + JavaFunction as ApiJavaFunction, +) +from flink_agents.api.function import ( + PythonFunction as ApiPythonFunction, +) +from flink_agents.api.runner_context import RunnerContext +from flink_agents.plan.agent_plan import AgentPlan +from flink_agents.plan.configuration import AgentConfiguration +from flink_agents.plan.function import ( + JavaFunction as PlanJavaFunction, +) +from flink_agents.plan.function import ( + PythonFunction as PlanPythonFunction, +) + +# python/flink_agents/plan/tests/test_*.py -> repo root is parents[4]. +_REPO_ROOT = Path(__file__).resolve().parents[4] +_SNAPSHOT_DIR = _REPO_ROOT / "e2e-test" / "cross-language-agent-plan-snapshots" + + +def _regenerate_enabled() -> bool: + return os.environ.get("REGENERATE_SNAPSHOTS", "").lower() in {"1", "true", "yes"} + + +def _plan_dump_json(plan: AgentPlan) -> str: + """Stable JSON form of an AgentPlan, indented for diff-friendliness.""" + return plan.model_dump_json(serialize_as_any=True, indent=2) + + +def _dummy_action(event: Event, ctx: RunnerContext) -> None: + """Plain Python callable referenced by Python-target plans.""" + + +def _make_java_function_descriptor() -> ApiJavaFunction: + return ApiJavaFunction( + qualname="com.example.Handlers", + method_name="handle", + parameter_types=[ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ], + ) + + +def _make_python_function_descriptor() -> ApiPythonFunction: + return ApiPythonFunction( + module=_dummy_action.__module__, + qualname=_dummy_action.__qualname__, + ) + + +# ── api → plan promotion (Python side) ────────────────────────────────── + + +def test_compile_agent_with_python_function_descriptor() -> None: + """ApiPythonFunction added via add_action becomes plan PythonFunction.""" + agent = Agent() + pf = _make_python_function_descriptor() + agent.add_action(name="act", events=[InputEvent.EVENT_TYPE], func=pf) + + plan = AgentPlan.from_agent(agent, AgentConfiguration()) + action = plan.actions["act"] + + assert isinstance(action.exec, PlanPythonFunction), ( + f"Expected plan PythonFunction, got {type(action.exec).__name__}" + ) + assert action.exec.module == pf.module + assert action.exec.qualname == pf.qualname + assert action.listen_event_types == [InputEvent.EVENT_TYPE] + + +def test_compile_agent_with_java_function_descriptor() -> None: + """ApiJavaFunction added via add_action becomes plan JavaFunction. + + Python's ``_to_plan_function`` does NOT resolve the Java class — it + keeps ``parameter_types`` as opaque strings. This is the documented + asymmetry from Java, which calls ``Class.forName`` at this point. + """ + agent = Agent() + jf = _make_java_function_descriptor() + agent.add_action(name="act", events=[InputEvent.EVENT_TYPE], func=jf) + + plan = AgentPlan.from_agent(agent, AgentConfiguration()) + action = plan.actions["act"] + + assert isinstance(action.exec, PlanJavaFunction), ( + f"Expected plan JavaFunction, got {type(action.exec).__name__}" + ) + assert action.exec.qualname == jf.qualname + assert action.exec.method_name == jf.method_name + assert list(action.exec.parameter_types) == list(jf.parameter_types) + assert action.listen_event_types == [InputEvent.EVENT_TYPE] + + +def test_python_plan_compile_does_not_validate_java_class_exists() -> None: + """Python plan compile must not require the Java class to exist locally. + + Cross-language descriptors are pure data on the Python side; class + resolution happens later on the Java side at runtime. A + nonexistent FQN must compile cleanly here. + """ + agent = Agent() + fake = ApiJavaFunction( + qualname="com.does.not.Exist", + method_name="ghost", + parameter_types=["java.lang.String", "int"], + ) + agent.add_action(name="act", events=[InputEvent.EVENT_TYPE], func=fake) + + plan = AgentPlan.from_agent(agent, AgentConfiguration()) + assert plan.actions["act"].exec.qualname == "com.does.not.Exist" + + +def test_compile_preserves_action_config() -> None: + agent = Agent() + agent.add_action( + name="act", + events=[InputEvent.EVENT_TYPE], + func=_make_python_function_descriptor(), + timeout_sec=30, + retry=2, + ) + + plan = AgentPlan.from_agent(agent, AgentConfiguration()) + assert plan.actions["act"].config == {"timeout_sec": 30, "retry": 2} + + +def test_compile_rejects_unknown_function_descriptor() -> None: + """Sanity: an unknown ApiFunction subclass should be refused.""" + from flink_agents.api.function import Function as ApiFunction + + class WeirdFunction(ApiFunction): + pass + + agent = Agent() + agent.add_action( + name="act", events=[InputEvent.EVENT_TYPE], func=WeirdFunction() + ) + + with pytest.raises(TypeError, match="Unsupported function descriptor"): + AgentPlan.from_agent(agent, AgentConfiguration()) + + +# ── Plan JSON shape (Python side) ─────────────────────────────────────── + + +def _java_action_plan() -> AgentPlan: + """Minimal plan with a single cross-language Java-target action.""" + agent = Agent() + agent.add_action( + name="handle", + events=[InputEvent.EVENT_TYPE], + func=_make_java_function_descriptor(), + ) + return AgentPlan.from_agent(agent, AgentConfiguration()) + + +def _python_action_plan() -> AgentPlan: + """Minimal plan with a single same-language Python-target action.""" + agent = Agent() + agent.add_action( + name="handle", + events=[InputEvent.EVENT_TYPE], + func=_make_python_function_descriptor(), + ) + return AgentPlan.from_agent(agent, AgentConfiguration()) + + +def test_python_plan_with_java_action_has_expected_exec_shape() -> None: + """Pin the wire shape of a Java-target action's ``exec`` block.""" + plan = _java_action_plan() + parsed = json.loads(_plan_dump_json(plan)) + exec_block = parsed["actions"]["handle"]["exec"] + + assert exec_block == { + "func_type": "JavaFunction", + "qualname": "com.example.Handlers", + "method_name": "handle", + "parameter_types": [ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ], + } + + +def test_python_plan_with_python_action_has_expected_exec_shape() -> None: + """Pin the wire shape of a Python-target action's ``exec`` block.""" + plan = _python_action_plan() + parsed = json.loads(_plan_dump_json(plan)) + exec_block = parsed["actions"]["handle"]["exec"] + + assert exec_block == { + "func_type": "PythonFunction", + "module": _dummy_action.__module__, + "qualname": _dummy_action.__qualname__, + } + + +# ── Plan JSON round-trip (Python side) ────────────────────────────────── + + +def test_python_plan_with_java_action_round_trips_through_json() -> None: + plan = _java_action_plan() + json_str = _plan_dump_json(plan) + restored = AgentPlan.model_validate_json(json_str) + + action = restored.actions["handle"] + assert isinstance(action.exec, PlanJavaFunction) + assert action.exec.qualname == "com.example.Handlers" + assert action.exec.method_name == "handle" + assert list(action.exec.parameter_types) == [ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ] + + +def test_python_plan_with_python_action_round_trips_through_json() -> None: + plan = _python_action_plan() + json_str = _plan_dump_json(plan) + restored = AgentPlan.model_validate_json(json_str) + + action = restored.actions["handle"] + assert isinstance(action.exec, PlanPythonFunction) + assert action.exec.module == _dummy_action.__module__ + assert action.exec.qualname == _dummy_action.__qualname__ + + +# ── Cross-language snapshot (Python writes / Java reads) ──────────────── + + +def test_regenerate_python_plan_with_java_action_snapshot() -> None: + if not _regenerate_enabled(): + pytest.skip("Set REGENERATE_SNAPSHOTS=1 to refresh.") + + target = _SNAPSHOT_DIR / "python" / "agent_plan_with_java_action.json" + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(_plan_dump_json(_java_action_plan()) + "\n") + + +def test_python_plan_with_java_action_snapshot_is_stable() -> None: + snapshot_path = _SNAPSHOT_DIR / "python" / "agent_plan_with_java_action.json" + assert snapshot_path.exists(), ( + f"Python plan snapshot missing from {snapshot_path}. " + f"Regenerate with REGENERATE_SNAPSHOTS=1 and commit alongside the test." + ) + + actual = json.loads(_plan_dump_json(_java_action_plan())) + expected = json.loads(snapshot_path.read_text()) + assert actual == expected, ( + "Python plan-with-Java-action JSON drifted from committed snapshot." + ) + + +def test_python_can_deserialize_java_plan_with_python_action() -> None: + """Java produces a plan referencing a Python action; Python must read it back.""" + snapshot = _SNAPSHOT_DIR / "java" / "agent_plan_with_python_action.json" + assert snapshot.exists(), ( + f"Java plan snapshot missing from {snapshot}. " + f"Regenerate the Java side with -Dregenerate.snapshots=true " + f"and commit alongside this test." + ) + + json_str = snapshot.read_text() + restored = AgentPlan.model_validate_json(json_str) + + action = restored.actions["handle"] + assert isinstance(action.exec, PlanPythonFunction), ( + f"Expected plan PythonFunction, got {type(action.exec).__name__}" + ) + assert action.exec.module == _dummy_action.__module__ + assert action.exec.qualname == _dummy_action.__qualname__ + + +def test_python_plan_with_java_action_matches_runtime_operator_wire_shape() -> None: + handler_qualname = ( + "org.apache.flink.agents.runtime.operator." + "CrossLanguageActionRuntimeTest$Handlers" + ) + expected_parameter_types = [ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ] + + agent = Agent() + agent.add_action( + name="handle", + events=[InputEvent.EVENT_TYPE], + func=ApiJavaFunction( + qualname=handler_qualname, + method_name="handleInput", + parameter_types=expected_parameter_types, + ), + ) + + plan = AgentPlan.from_agent(agent, AgentConfiguration()) + emitted = json.loads(_plan_dump_json(plan)) + + handle_block = emitted["actions"]["handle"] + assert handle_block["name"] == "handle" + assert handle_block["listen_event_types"] == [InputEvent.EVENT_TYPE] + assert handle_block["config"] is None + assert handle_block["exec"] == { + "func_type": "JavaFunction", + "qualname": handler_qualname, + "method_name": "handleInput", + "parameter_types": expected_parameter_types, + } + assert emitted["actions_by_event"][InputEvent.EVENT_TYPE] == ["handle"] + + +def test_python_preserves_conf_data_types_and_event_ordering() -> None: + json_str = json.dumps( + { + "actions": { + "first": { + "name": "first", + "exec": { + "func_type": "PythonFunction", + "module": _dummy_action.__module__, + "qualname": _dummy_action.__qualname__, + }, + "listen_event_types": [InputEvent.EVENT_TYPE], + "config": None, + }, + "second": { + "name": "second", + "exec": { + "func_type": "PythonFunction", + "module": _dummy_action.__module__, + "qualname": _dummy_action.__qualname__, + }, + "listen_event_types": [InputEvent.EVENT_TYPE], + "config": None, + }, + }, + "actions_by_event": {InputEvent.EVENT_TYPE: ["first", "second"]}, + "resource_providers": {}, + "config": { + "conf_data": { + "k_int": 1, + "k_float": 1.5, + "k_bool": True, + "k_str": "v1", + } + }, + } + ) + restored = AgentPlan.model_validate_json(json_str) + + assert restored.config.conf_data == { + "k_int": 1, + "k_float": 1.5, + "k_bool": True, + "k_str": "v1", + } + assert restored.actions_by_event[InputEvent.EVENT_TYPE] == ["first", "second"] diff --git a/python/flink_agents/plan/tests/test_function.py b/python/flink_agents/plan/tests/test_function.py index 6ce4c8536..5c9ec2f79 100644 --- a/python/flink_agents/plan/tests/test_function.py +++ b/python/flink_agents/plan/tests/test_function.py @@ -24,7 +24,9 @@ import pytest from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.runner_context import RunnerContext from flink_agents.plan.function import ( + JavaFunction, PythonFunction, _is_function_cacheable, call_python_function, @@ -113,6 +115,48 @@ def test_function_signature_generic_type_mismatch() -> None: func.check_signature(Tuple[str, ...], Dict[str, Any]) +# ── JavaFunction signature checks ─────────────────────────────────────── + + +def _java_action_function() -> JavaFunction: + return JavaFunction( + qualname="com.example.Handlers", + method_name="handle", + parameter_types=[ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ], + ) + + +def test_java_function_signature_matching_arity_passes() -> None: + _java_action_function().check_signature(Event, RunnerContext) + + +def test_java_function_signature_arity_too_few_raises() -> None: + func = JavaFunction( + qualname="com.example.Handlers", + method_name="handle", + parameter_types=["org.apache.flink.agents.api.Event"], + ) + with pytest.raises(TypeError, match="declares 1 parameter type"): + func.check_signature(Event, RunnerContext) + + +def test_java_function_signature_arity_too_many_raises() -> None: + func = JavaFunction( + qualname="com.example.Handlers", + method_name="handle", + parameter_types=[ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + "java.lang.String", + ], + ) + with pytest.raises(TypeError, match="declares 3 parameter type"): + func.check_signature(Event, RunnerContext) + + current_dir = Path(__file__).parent diff --git a/python/flink_agents/plan/tests/test_resource_provider.py b/python/flink_agents/plan/tests/test_resource_provider.py index 355c76435..60d631853 100644 --- a/python/flink_agents/plan/tests/test_resource_provider.py +++ b/python/flink_agents/plan/tests/test_resource_provider.py @@ -21,7 +21,11 @@ import pytest from flink_agents.api.resource import Resource, ResourceDescriptor, ResourceType -from flink_agents.plan.resource_provider import PythonResourceProvider, ResourceProvider +from flink_agents.plan.resource_provider import ( + JavaResourceProvider, + PythonResourceProvider, + ResourceProvider, +) current_dir = Path(__file__).parent @@ -68,3 +72,32 @@ def test_python_resource_provider_deserialize( expected_json ) assert resource_provider == expected_resource_provider + + +def test_python_can_deserialize_java_resource_provider_wire_shape() -> None: + json_str = json.dumps( + { + "name": "bedrock_chat", + "type": "chat_model", + "descriptor": { + "target_module": "", + "target_clazz": "org.apache.flink.agents.integrations.chatmodels.bedrock.BedrockChatModelSetup", + "arguments": { + "java_clazz": "org.apache.flink.agents.integrations.chatmodels.bedrock.BedrockChatModelSetup", + "model": "anthropic.claude-3-haiku", + "max_tokens": 1024, + }, + }, + "__resource_provider_type__": "JavaResourceProvider", + } + ) + provider = JavaResourceProvider.model_validate_json(json_str) + + assert provider.name == "bedrock_chat" + assert provider.type == ResourceType.CHAT_MODEL + assert provider.descriptor.target_module == "" + assert provider.descriptor.target_clazz == ( + "org.apache.flink.agents.integrations.chatmodels.bedrock.BedrockChatModelSetup" + ) + assert provider.descriptor.arguments["model"] == "anthropic.claude-3-haiku" + assert provider.descriptor.arguments["max_tokens"] == 1024 diff --git a/python/flink_agents/runtime/tests/test_local_runner_cross_language.py b/python/flink_agents/runtime/tests/test_local_runner_cross_language.py new file mode 100644 index 000000000..0bb1502e4 --- /dev/null +++ b/python/flink_agents/runtime/tests/test_local_runner_cross_language.py @@ -0,0 +1,163 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Action-dispatch tests via Python's local runner.""" + +from typing import Any, List, Tuple + +import pytest + +from flink_agents.api.agents.agent import Agent +from flink_agents.api.events.event import Event, InputEvent, OutputEvent +from flink_agents.api.function import JavaFunction as ApiJavaFunction +from flink_agents.api.runner_context import RunnerContext +from flink_agents.plan.configuration import AgentConfiguration +from flink_agents.plan.function import JavaFunction as PlanJavaFunction +from flink_agents.runtime.local_runner import LocalRunner + + +def echo_action(event: Event, ctx: RunnerContext) -> None: + value = InputEvent.from_event(event).input + ctx.send_event(OutputEvent(output=value)) + + +def _make_java_function_descriptor() -> ApiJavaFunction: + return ApiJavaFunction( + qualname="com.example.Handlers", + method_name="handle", + parameter_types=[ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ], + ) + + +def test_local_runner_dispatches_python_function_action() -> None: + agent = Agent() + agent.add_action( + name="echo", events=[InputEvent.EVENT_TYPE], func=echo_action + ) + + runner = LocalRunner(agent, AgentConfiguration()) + runner.run(key="k1", value="hello") + + assert runner.get_outputs() == [{"k1": "hello"}] + + +def test_local_runner_dispatch_of_java_function_action_fails_without_jvm_bridge() -> None: + agent = Agent() + agent.add_action( + name="handle", + events=[InputEvent.EVENT_TYPE], + func=_make_java_function_descriptor(), + ) + + runner = LocalRunner(agent, AgentConfiguration()) + with pytest.raises(RuntimeError, match="JVM resource adapter"): + runner.run(key="k1", value="hello") + + +class _RecordingJavaAdapter: + def __init__(self) -> None: + self.action_calls: List[Tuple[str, str, List[str], list]] = [] + self.tool_calls: List[Tuple[str, str, List[str], dict]] = [] + + def invokeJavaAction( + self, + qualname: str, + method_name: str, + parameter_types: List[str], + args: list, + ) -> Any: + self.action_calls.append( + (qualname, method_name, list(parameter_types), list(args)) + ) + return None + + def invokeJavaTool( + self, + qualname: str, + method_name: str, + parameter_types: List[str], + kwargs: dict, + ) -> Any: + self.tool_calls.append( + (qualname, method_name, list(parameter_types), dict(kwargs)) + ) + return None + + +def test_plan_java_function_routes_positional_to_action_dispatch() -> None: + plan_fn = PlanJavaFunction( + qualname="com.example.Handlers", + method_name="handle", + parameter_types=[ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ], + ) + adapter = _RecordingJavaAdapter() + plan_fn.set_java_resource_adapter(adapter) + + sentinel_event = InputEvent(input="payload") + sentinel_ctx = object() + + plan_fn(sentinel_event, sentinel_ctx) + + assert not adapter.tool_calls + assert len(adapter.action_calls) == 1 + qualname, method_name, parameter_types, args = adapter.action_calls[0] + assert qualname == "com.example.Handlers" + assert method_name == "handle" + assert parameter_types == [ + "org.apache.flink.agents.api.Event", + "org.apache.flink.agents.api.context.RunnerContext", + ] + assert args == [sentinel_event, sentinel_ctx] + + +def test_plan_java_function_routes_kwargs_to_tool_dispatch() -> None: + plan_fn = PlanJavaFunction( + qualname="com.example.Tools", + method_name="multiply", + parameter_types=["int", "int"], + ) + adapter = _RecordingJavaAdapter() + plan_fn.set_java_resource_adapter(adapter) + + plan_fn(a=2, b=3) + + assert not adapter.action_calls + assert len(adapter.tool_calls) == 1 + _, _, _, kwargs = adapter.tool_calls[0] + assert kwargs == {"a": 2, "b": 3} + + +def test_plan_java_function_rejects_mixed_positional_and_keyword_args() -> None: + plan_fn = PlanJavaFunction( + qualname="com.example.Handlers", + method_name="handle", + parameter_types=["org.apache.flink.agents.api.Event"], + ) + adapter = _RecordingJavaAdapter() + plan_fn.set_java_resource_adapter(adapter) + + with pytest.raises(TypeError, match="mixing positional and keyword"): + plan_fn(InputEvent(input="x"), extra=1) + + assert not adapter.action_calls + assert not adapter.tool_calls diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java index bc2aef6f7..7eecc5d26 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/EventRouter.java @@ -186,31 +186,33 @@ Event wrapToInputEvent(IN input, PythonActionExecutor pythonActionExecutor) thro /** * Extracts the downstream output payload from an output {@link Event}. * - *

For a Java {@link OutputEvent}, returns the payload directly. For Python-side output - * events (cross-language unified {@link Event} with output type), the event is JSON-serialized - * and handed to {@link PythonActionExecutor#getOutputFromOutputEvent(String)} for extraction. + *

Dispatch is by pipeline wire format, not action language: + * + *

    + *
  • Java pipelines ({@code inputIsJava}) emit the raw payload directly. + *
  • Python pipelines re-encode through {@link + * PythonActionExecutor#getOutputFromOutputEvent(String)} so the downstream Python sink + * receives cloudpickle bytes. + *
* * @param event the output event (must satisfy {@link EventUtil#isOutputEvent(Event)}). - * @param pythonActionExecutor the Python action executor (used only for Python output events). + * @param pythonActionExecutor used only on Python pipelines. * @return the typed output payload. */ @SuppressWarnings("unchecked") OUT getOutputFromOutputEvent(Event event, PythonActionExecutor pythonActionExecutor) { checkState(EventUtil.isOutputEvent(event)); - if (event instanceof OutputEvent) { - return (OUT) ((OutputEvent) event).getOutput(); - } else { - // Python output events arrive as unified Event with type "_output_event". - // Pass the JSON representation to Python for extraction. - try { - String eventJson = new ObjectMapper().writeValueAsString(event); - Object outputFromOutputEvent = - pythonActionExecutor.getOutputFromOutputEvent(eventJson); - return (OUT) outputFromOutputEvent; - } catch (Exception e) { - throw new IllegalStateException( - "Failed to extract output from event: " + event.getType(), e); - } + OutputEvent typedEvent = + (event instanceof OutputEvent) ? (OutputEvent) event : OutputEvent.fromEvent(event); + if (inputIsJava) { + return (OUT) typedEvent.getOutput(); + } + try { + String eventJson = new ObjectMapper().writeValueAsString(typedEvent); + return (OUT) pythonActionExecutor.getOutputFromOutputEvent(eventJson); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to extract output from event: " + event.getType(), e); } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java index 3c00d96dc..0c37820d4 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java @@ -31,6 +31,7 @@ import pemja.core.PythonInterpreter; import java.lang.reflect.Method; +import java.lang.reflect.Modifier; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -180,6 +181,25 @@ public Object invokeJavaTool( return response.getResult(); } + /** Invoke a Java static action method with positional arguments from Python. */ + public Object invokeJavaAction( + String className, + String methodName, + List parameterTypes, + List arguments) + throws Exception { + Method method = resolveMethod(className, methodName, parameterTypes); + if (!Modifier.isStatic(method.getModifiers())) { + throw new IllegalArgumentException( + "JavaAction target must be a static method. Got instance method: " + + className + + "#" + + methodName); + } + Object[] args = arguments == null ? new Object[0] : arguments.toArray(); + return method.invoke(null, args); + } + private Method resolveMethod(String className, String methodName, List parameterTypes) throws ClassNotFoundException, NoSuchMethodException { ClassLoader classLoader = diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/CrossLanguageActionRuntimeTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/CrossLanguageActionRuntimeTest.java new file mode 100644 index 000000000..6d3eedd31 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/CrossLanguageActionRuntimeTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.plan.PythonFunction; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Layer F1 — feed a Python-shaped JSON plan into the operator harness and confirm a JavaFunction + * action body runs. Java→Python action dispatch goes through Pemja and is Layer F2 scope. + */ +public class CrossLanguageActionRuntimeTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + /** Public nested class so reflection can invoke {@code handleInput} without access errors. */ + public static class Handlers { + public static void handleInput(Event event, RunnerContext context) { + Long value = (Long) InputEvent.fromEvent(event).getInput(); + context.sendEvent(new OutputEvent(value * 2)); + } + } + + @Test + void operatorRunsJavaActionFromPythonShapedPlanJson() throws Exception { + String planJson = pythonShapedPlanJson(); + + AgentPlan plan = MAPPER.readValue(planJson, AgentPlan.class); + + Action handle = plan.getActions().get("handle"); + assertThat(handle).isNotNull(); + assertThat(handle.getExec()).isInstanceOf(JavaFunction.class); + JavaFunction execFn = (JavaFunction) handle.getExec(); + assertThat(execFn.getQualName()).isEqualTo(Handlers.class.getName()); + assertThat(execFn.getMethodName()).isEqualTo("handleInput"); + + try (KeyedOneInputStreamOperatorTestHarness testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory(plan, true), + (KeySelector) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator operator = + (ActionExecutionOperator) testHarness.getOperator(); + + testHarness.processElement(new StreamRecord<>(3L)); + operator.waitInFlightEventsFinished(); + + List> recordOutput = + (List>) testHarness.getRecordOutput(); + assertThat(recordOutput).hasSize(1); + assertThat(recordOutput.get(0).getValue()).isEqualTo(6L); + } + } + + @Test + void pythonFunctionPlanDeserializesAndIsRecognizedByOperatorFactory() throws Exception { + // Pins deserialize + factory-accepts-plan without opening the operator (Pemja libpython + // load is Layer F2). + AgentPlan plan = planWithPythonAction(); + + Action handle = plan.getActions().get("py_handle"); + assertThat(handle).isNotNull(); + assertThat(handle.getExec()).isInstanceOf(PythonFunction.class); + + ActionExecutionOperatorFactory factory = new ActionExecutionOperatorFactory(plan, true); + assertThat(factory).isNotNull(); + } + + /** + * Wire format Python emits via {@code AgentPlan.model_dump_json}; pinned symmetric in Layer B. + */ + private static String pythonShapedPlanJson() { + String qualName = Handlers.class.getName(); + return "{" + + "\"actions\":{" + + " \"handle\":{" + + " \"name\":\"handle\"," + + " \"exec\":{" + + " \"func_type\":\"JavaFunction\"," + + " \"qualname\":\"" + + qualName + + "\"," + + " \"method_name\":\"handleInput\"," + + " \"parameter_types\":[" + + " \"org.apache.flink.agents.api.Event\"," + + " \"org.apache.flink.agents.api.context.RunnerContext\"" + + " ]" + + " }," + + " \"listen_event_types\":[\"_input_event\"]," + + " \"config\":null" + + " }" + + "}," + + "\"actions_by_event\":{" + + " \"_input_event\":[\"handle\"]" + + "}," + + "\"resource_providers\":{}," + + "\"config\":{\"conf_data\":{}}" + + "}"; + } + + private static AgentPlan planWithPythonAction() throws Exception { + java.util.Map> actionsByEvent = new java.util.HashMap<>(); + java.util.Map actions = new java.util.HashMap<>(); + + PythonFunction pythonFn = + new PythonFunction( + "flink_agents.plan.tests.test_agent_plan_cross_language", "_dummy_action"); + Action act = + new Action( + "py_handle", + pythonFn, + java.util.Collections.singletonList(InputEvent.EVENT_TYPE)); + actions.put(act.getName(), act); + actionsByEvent.put(InputEvent.EVENT_TYPE, java.util.Collections.singletonList(act)); + + return new AgentPlan( + actions, + actionsByEvent, + new java.util.HashMap<>(), + new org.apache.flink.agents.plan.AgentConfiguration()); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java index 8f3cbea9c..97f994b42 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/EventRouterTest.java @@ -20,6 +20,7 @@ import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.EventContext; import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; import org.apache.flink.agents.api.listener.EventListener; import org.apache.flink.agents.api.logger.EventLogger; import org.apache.flink.agents.plan.AgentPlan; @@ -27,6 +28,7 @@ import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; import org.apache.flink.agents.runtime.operator.queue.SegmentedQueue; +import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; import org.apache.flink.streaming.api.watermark.Watermark; import org.junit.jupiter.api.Test; import org.mockito.InOrder; @@ -42,8 +44,10 @@ import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** Contract tests for {@link EventRouter}. */ class EventRouterTest { @@ -199,6 +203,55 @@ void processEligibleWatermarksBlocksWhileSegmentHasKeys() throws Exception { assertThat(emitted).containsExactly(wm1); } + /** Java pipeline + typed OutputEvent — raw payload, no PythonActionExecutor. */ + @Test + void getOutputFromOutputEventReturnsRawPayloadForTypedOutputEventOnJavaPipeline() { + AgentPlan plan = new AgentPlan(new HashMap<>(), new HashMap<>()); + EventRouter router = new EventRouter<>(plan, /* inputIsJava */ true); + PythonActionExecutor mockPython = mock(PythonActionExecutor.class); + + Object output = router.getOutputFromOutputEvent(new OutputEvent(42L), mockPython); + + assertThat(output).isEqualTo(42L); + verify(mockPython, never()).getOutputFromOutputEvent(any()); + } + + /** + * Java pipeline + unified Event with {@code _output_event} type (e.g. a Python action body + * emitted it on a Java pipeline) — reconstruct via {@link OutputEvent#fromEvent(Event)}, never + * round-trip through {@link PythonActionExecutor}. + */ + @Test + void getOutputFromOutputEventReturnsRawPayloadForUnifiedEventOnJavaPipeline() { + AgentPlan plan = new AgentPlan(new HashMap<>(), new HashMap<>()); + EventRouter router = new EventRouter<>(plan, /* inputIsJava */ true); + PythonActionExecutor mockPython = mock(PythonActionExecutor.class); + + Map attrs = new HashMap<>(); + attrs.put("output", 84L); + Event unified = new Event(OutputEvent.EVENT_TYPE, attrs); + + Object output = router.getOutputFromOutputEvent(unified, mockPython); + + assertThat(output).isEqualTo(84L); + verify(mockPython, never()).getOutputFromOutputEvent(any()); + } + + /** Python pipeline — re-encode through PythonActionExecutor for the cloudpickle bytes. */ + @Test + void getOutputFromOutputEventRoundsTripsThroughPythonOnPythonPipeline() { + AgentPlan plan = new AgentPlan(new HashMap<>(), new HashMap<>()); + EventRouter router = new EventRouter<>(plan, /* inputIsJava */ false); + PythonActionExecutor mockPython = mock(PythonActionExecutor.class); + byte[] pickled = new byte[] {1, 2, 3}; + when(mockPython.getOutputFromOutputEvent(any())).thenReturn(pickled); + + Object output = router.getOutputFromOutputEvent(new OutputEvent(42L), mockPython); + + assertThat(output).isEqualTo(pickled); + verify(mockPython).getOutputFromOutputEvent(any()); + } + private static BuiltInMetrics makeMetrics() { FlinkAgentsMetricGroupImpl metricGroup = mock(FlinkAgentsMetricGroupImpl.class, RETURNS_DEEP_STUBS);