Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.agents.api.agents;

import org.apache.flink.agents.api.configuration.ConfigOption;
import org.apache.flink.api.common.state.StateTtlConfig;

public class AgentExecutionOptions {
public static final ConfigOption<Agent.ErrorHandlingStrategy> ERROR_HANDLING_STRATEGY =
Expand Down Expand Up @@ -47,4 +48,21 @@ public class AgentExecutionOptions {

public static final ConfigOption<Boolean> RAG_ASYNC =
new ConfigOption<>("rag.async", Boolean.class, true);

public static final ConfigOption<Long> SHORT_TERM_MEMORY_STATE_TTL_MS =
new ConfigOption<>("short-term-memory.state-ttl.ms", Long.class, 0L);

public static final ConfigOption<StateTtlConfig.UpdateType>
SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE =
new ConfigOption<>(
"short-term-memory.state-ttl.update-type",
StateTtlConfig.UpdateType.class,
StateTtlConfig.UpdateType.OnReadAndWrite);

public static final ConfigOption<StateTtlConfig.StateVisibility>
SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY =
new ConfigOption<>(
"short-term-memory.state-ttl.visibility",
StateTtlConfig.StateVisibility.class,
StateTtlConfig.StateVisibility.NeverReturnExpired);
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
* public static ResourceDesc openAIResponses() {
* return ResourceDescriptor.Builder.newBuilder(OpenAIResponsesModelConnection.class.getName())
* .addInitialArgument("api_key", System.getenv("OPENAI_API_KEY"))
* .addInitialArgument("api_base_url", System.getenv("OPENAI_API_URL"))
* .addInitialArgument("timeout", 120)
* .addInitialArgument("max_retries", 3)
* .build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public void setup(
public void open() throws Exception {
super.open();

stateManager.initializeKeyedStates(getRuntimeContext());
stateManager.initializeKeyedStates(getRuntimeContext(), agentPlan);
stateManager.initializeOperatorStates(getOperatorStateBackend());

resourceCache = new ResourceCache(agentPlan.getResourceProviders());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
package org.apache.flink.agents.runtime.operator;

import org.apache.flink.agents.api.Event;
import org.apache.flink.agents.api.agents.AgentExecutionOptions;
import org.apache.flink.agents.plan.AgentPlan;
import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
Expand All @@ -37,6 +40,8 @@

import javax.annotation.Nullable;

import java.time.Duration;

import static org.apache.flink.agents.runtime.utils.StateUtil.*;

/**
Expand All @@ -56,9 +61,9 @@
*
* <p>Lifecycle: instantiated by the operator's {@code initializeState()} (the Flink lifecycle runs
* {@code initializeState} before {@code open}). Both {@link
* #initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext)} and {@link
* #initializeOperatorStates(OperatorStateBackend)} are invoked later from the operator's {@code
* open()}. There is no explicit close — the underlying state handles are owned by Flink.
* #initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext, AgentPlan)} and
* {@link #initializeOperatorStates(OperatorStateBackend)} are invoked later from the operator's
* {@code open()}. There is no explicit close — the underlying state handles are owned by Flink.
*
* <p>Design constraint: package-private; no manager-to-manager held references. Cross-cutting data
* flows via method parameters (see for example {@link ActionTaskContextManager#transferContexts}
Expand Down Expand Up @@ -87,7 +92,9 @@ class OperatorStateManager {
*
* @param runtimeContext the operator's runtime context, used to obtain keyed state handles.
*/
void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext runtimeContext)
void initializeKeyedStates(
org.apache.flink.api.common.functions.RuntimeContext runtimeContext,
AgentPlan agentPlan)
throws Exception {
// init sensoryMemState
MapStateDescriptor<String, MemoryObjectImpl.MemoryItem> sensoryMemStateDescriptor =
Expand All @@ -103,6 +110,7 @@ void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext
"shortTermMemory",
TypeInformation.of(String.class),
TypeInformation.of(MemoryObjectImpl.MemoryItem.class));
maybeEnableShortTermMemoryTTL(shortTermMemStateDescriptor, agentPlan);
shortTermMemState = runtimeContext.getMapState(shortTermMemStateDescriptor);

// init sequence number state for per key message ordering
Expand All @@ -121,6 +129,39 @@ void initializeKeyedStates(org.apache.flink.api.common.functions.RuntimeContext
PENDING_INPUT_EVENT_STATE_NAME, TypeInformation.of(Event.class)));
}

/**
* When {@link AgentExecutionOptions#SHORT_TERM_MEMORY_STATE_TTL_MS} is positive, attaches Flink
* {@link StateTtlConfig} to the short-term memory {@link MapStateDescriptor}. Unset, null, or
* non-positive values disable TTL (Flink does not allow zero/negative TTL).
*/
private void maybeEnableShortTermMemoryTTL(
MapStateDescriptor<String, MemoryObjectImpl.MemoryItem> descriptor,
AgentPlan agentPlan) {
Long ttlMs =
agentPlan.getConfig().get(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS);
if (ttlMs == null || ttlMs <= 0) {
return;
}

StateTtlConfig.UpdateType updateType =
agentPlan
.getConfig()
.get(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE);

StateTtlConfig.StateVisibility stateVisibility =
agentPlan
.getConfig()
.get(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY);

StateTtlConfig ttlConfig =
StateTtlConfig.newBuilder(Duration.ofMillis(ttlMs))
.setUpdateType(updateType)
.setStateVisibility(stateVisibility)
.cleanupFullSnapshot()
.build();
descriptor.enableTimeToLive(ttlConfig);
}

/**
* Registers operator-level (non-keyed) state.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.agents.runtime.memory;

import org.apache.flink.agents.api.AgentsExecutionEnvironment;
import org.apache.flink.agents.api.InputEvent;
import org.apache.flink.agents.api.OutputEvent;
import org.apache.flink.agents.api.agents.Agent;
import org.apache.flink.agents.api.agents.AgentExecutionOptions;
import org.apache.flink.agents.api.annotation.Action;
import org.apache.flink.agents.api.context.MemoryObject;
import org.apache.flink.agents.api.context.RunnerContext;
import org.apache.flink.agents.plan.AgentConfiguration;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;

/** Integration test for Short-Term Memory TTL functionality. */
class ShortTermMemoryTTLIntegrationTest {

private static final String MEMORY_KEY = "test_key";

private static final class TestInput {
public String eventKey;
public long sleepMs;

private TestInput() {}

private TestInput(String eventKey, long sleepMs) {
this.eventKey = eventKey;
this.sleepMs = sleepMs;
}
}

public static class TTLTestAgent extends Agent {

@Action(listenEventTypes = {InputEvent.EVENT_TYPE})
public static void input(org.apache.flink.agents.api.Event event, RunnerContext ctx)
throws Exception {
InputEvent inputEvent = (InputEvent) event;
TestInput input = (TestInput) inputEvent.getInput();

MemoryObject shortTermMemory = ctx.getShortTermMemory();
MemoryObject memoryObject = shortTermMemory.get(input.eventKey);

Object existingValue = null;
int currentCount = 0;
if (memoryObject != null && !memoryObject.isNestedObject()) {
existingValue = memoryObject.getValue();
if (existingValue instanceof Integer) {
currentCount = (Integer) existingValue;
} else if (existingValue instanceof Number) {
currentCount = ((Number) existingValue).intValue();
}
}

shortTermMemory.set(input.eventKey, currentCount + 1);
Thread.sleep(input.sleepMs);
ctx.sendEvent(
new OutputEvent(
input.eventKey + "|" + (existingValue == null ? "NEW" : "EXISTING")));
}
}

@Test
void testTTLConfigurationNotApplied() throws Exception {
List<String> results = runScenario(1000L, 0L);

assertEquals(List.of("event1|NEW", "event2|NEW", "event1|EXISTING"), results);
}

@Test
void testTTLConfigurationApplied() throws Exception {
List<String> results = runScenario(1000L, 2000L);

assertEquals(List.of("event1|NEW", "event2|NEW", "event1|NEW"), results);
}

private static List<String> runScenario(long ttlMs, long sleepMs) throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);

AgentsExecutionEnvironment agentEnv =
AgentsExecutionEnvironment.getExecutionEnvironment(env);
AgentConfiguration agentsConfig = (AgentConfiguration) agentEnv.getConfig();
agentsConfig.set(AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_MS, ttlMs);
agentsConfig.set(
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_UPDATE_TYPE,
StateTtlConfig.UpdateType.OnCreateAndWrite);
agentsConfig.set(
AgentExecutionOptions.SHORT_TERM_MEMORY_STATE_TTL_VISIBILITY,
StateTtlConfig.StateVisibility.NeverReturnExpired);

List<TestInput> testData = new ArrayList<>();
testData.add(new TestInput("event1", sleepMs));
testData.add(new TestInput("event2", sleepMs));
testData.add(new TestInput("event1", sleepMs));

DataStream<TestInput> inputStream = env.fromCollection(testData);
DataStream<Object> outputStream =
agentEnv.fromDataStream(inputStream, x -> MEMORY_KEY)
.apply(new TTLTestAgent())
.toDataStream();

List<String> results = new ArrayList<>();
outputStream.map(Object::toString).executeAndCollect().forEachRemaining(results::add);
return results;
}
}
Loading