Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 102 additions & 7 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,21 @@
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.RunConfig;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.artifacts.InMemoryArtifactService;
import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
import com.google.adk.flows.llmflows.ResumabilityConfig;
import com.google.adk.memory.BaseMemoryService;
import com.google.adk.plugins.BasePlugin;
import com.google.adk.plugins.PluginManager;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.InMemorySessionService;
import com.google.adk.sessions.Session;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.FunctionTool;
import com.google.adk.utils.CollectionUtils;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.InlineMe;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.genai.types.AudioTranscriptionConfig;
import com.google.genai.types.Content;
import com.google.genai.types.Modality;
Expand Down Expand Up @@ -67,7 +69,92 @@ public class Runner {
private final PluginManager pluginManager;
private final ResumabilityConfig resumabilityConfig;

/** Creates a new {@code Runner}. */
/** Builder for {@link Runner}. */
public static class Builder {
private BaseAgent agent;
private String appName;
private BaseArtifactService artifactService = new InMemoryArtifactService();
private BaseSessionService sessionService = new InMemorySessionService();
@Nullable private BaseMemoryService memoryService = null;
private List<BasePlugin> plugins = ImmutableList.of();
private ResumabilityConfig resumabilityConfig = new ResumabilityConfig();

@CanIgnoreReturnValue
public Builder agent(BaseAgent agent) {
this.agent = agent;
return this;
}

@CanIgnoreReturnValue
public Builder appName(String appName) {
this.appName = appName;
return this;
}

@CanIgnoreReturnValue
public Builder artifactService(BaseArtifactService artifactService) {
this.artifactService = artifactService;
return this;
}

@CanIgnoreReturnValue
public Builder sessionService(BaseSessionService sessionService) {
this.sessionService = sessionService;
return this;
}

@CanIgnoreReturnValue
public Builder memoryService(BaseMemoryService memoryService) {
this.memoryService = memoryService;
return this;
}

@CanIgnoreReturnValue
public Builder plugins(List<BasePlugin> plugins) {
this.plugins = plugins;
return this;
}

@CanIgnoreReturnValue
public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) {
this.resumabilityConfig = resumabilityConfig;
return this;
}

public Runner build() {
if (agent == null) {
throw new IllegalStateException("Agent must be provided.");
}
if (appName == null) {
throw new IllegalStateException("App name must be provided.");
}
if (artifactService == null) {
throw new IllegalStateException("Artifact service must be provided.");
}
if (sessionService == null) {
throw new IllegalStateException("Session service must be provided.");
}
return new Runner(
agent,
appName,
artifactService,
sessionService,
memoryService,
plugins,
resumabilityConfig);
}
}

public static Builder builder() {
return new Builder();
}

/**
* Creates a new {@code Runner}.
*
* @deprecated Use {@link Runner.Builder} instead.
*/
@Deprecated
public Runner(
BaseAgent agent,
String appName,
Expand All @@ -84,7 +171,12 @@ public Runner(
new ResumabilityConfig());
}

/** Creates a new {@code Runner} with a list of plugins. */
/**
* Creates a new {@code Runner} with a list of plugins.
*
* @deprecated Use {@link Runner.Builder} instead.
*/
@Deprecated
public Runner(
BaseAgent agent,
String appName,
Expand All @@ -102,7 +194,12 @@ public Runner(
new ResumabilityConfig());
}

/** Creates a new {@code Runner} with a list of plugins and resumability config. */
/**
* Creates a new {@code Runner} with a list of plugins and resumability config.
*
* @deprecated Use {@link Runner.Builder} instead.
*/
@Deprecated
public Runner(
BaseAgent agent,
String appName,
Expand All @@ -123,10 +220,8 @@ public Runner(
/**
* Creates a new {@code Runner}.
*
* @deprecated Use the constructor with {@code BaseMemoryService} instead even if with a null if
* you don't need the memory service.
* @deprecated Use {@link Runner.Builder} instead.
*/
@InlineMe(replacement = "this(agent, appName, artifactService, sessionService, null)")
@Deprecated
public Runner(
BaseAgent agent,
Expand Down
64 changes: 35 additions & 29 deletions core/src/test/java/com/google/adk/runner/RunnerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,10 @@
import com.google.adk.agents.LiveRequestQueue;
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.RunConfig;
import com.google.adk.artifacts.InMemoryArtifactService;
import com.google.adk.events.Event;
import com.google.adk.flows.llmflows.ResumabilityConfig;
import com.google.adk.models.LlmResponse;
import com.google.adk.plugins.BasePlugin;
import com.google.adk.sessions.InMemorySessionService;
import com.google.adk.sessions.Session;
import com.google.adk.testing.TestLlm;
import com.google.adk.testing.TestUtils;
Expand Down Expand Up @@ -76,7 +74,8 @@ public final class RunnerTest {
private final Content pluginContent = createContent("from plugin");
private final TestLlm testLlm = createTestLlm(createLlmResponse(createContent("from llm")));
private final LlmAgent agent = createTestAgentBuilder(testLlm).build();
private final Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
private final Runner runner =
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
private final Session session =
runner.sessionService().createSession("test", "user").blockingGet();
private Tracer originalTracer;
Expand Down Expand Up @@ -159,7 +158,12 @@ public void beforeRunCallback_multiplePluginsFirstOnly() {
BasePlugin plugin2 = mockPlugin("test2");
when(plugin2.beforeRunCallback(any())).thenReturn(Maybe.empty());

Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin1, plugin2));
Runner runner =
Runner.builder()
.agent(agent)
.appName("test")
.plugins(ImmutableList.of(plugin1, plugin2))
.build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
var events =
runner
Expand Down Expand Up @@ -271,7 +275,8 @@ public void onModelErrorCallback_success() {
TestLlm failingTestLlm = createTestLlm(Flowable.error(exception));
LlmAgent agent = createTestAgentBuilder(failingTestLlm).build();

Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
Runner runner =
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
var events =
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
Expand All @@ -289,7 +294,8 @@ public void onModelErrorCallback_error() {
TestLlm failingTestLlm = createTestLlm(Flowable.error(exception));
LlmAgent agent = createTestAgentBuilder(failingTestLlm).build();

Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
Runner runner =
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
runner.runAsync("user", session.id(), createContent("from user")).test().assertError(exception);

Expand All @@ -307,7 +313,8 @@ public void beforeToolCallback_success() {
.tools(ImmutableList.of(failingEchoTool))
.build();

Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
Runner runner =
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
var events =
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
Expand All @@ -330,7 +337,8 @@ public void afterToolCallback_success() {
LlmAgent agent =
createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build();

Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
Runner runner =
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
var events =
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
Expand All @@ -355,7 +363,8 @@ public void onToolErrorCallback_success() {
.tools(ImmutableList.of(failingEchoTool))
.build();

Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
Runner runner =
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
var events =
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
Expand All @@ -377,7 +386,8 @@ public void onToolErrorCallback_error() {
.tools(ImmutableList.of(failingEchoTool))
.build();

Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
Runner runner =
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
runner
.runAsync("user", session.id(), createContent("from user"))
Expand Down Expand Up @@ -666,7 +676,7 @@ public void runLive_success() throws Exception {
public void runLive_withToolExecution() throws Exception {
LlmAgent agentWithTool =
createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build();
Runner runnerWithTool = new InMemoryRunner(agentWithTool, "test", ImmutableList.of());
Runner runnerWithTool = Runner.builder().agent(agentWithTool).appName("test").build();
Session sessionWithTool =
runnerWithTool.sessionService().createSession("test", "user").blockingGet();
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
Expand All @@ -693,7 +703,7 @@ public void runLive_llmError() throws Exception {
Exception exception = new Exception("LLM test error");
TestLlm failingTestLlm = createTestLlm(Flowable.error(exception));
LlmAgent agent = createTestAgentBuilder(failingTestLlm).build();
Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of());
Runner runner = Runner.builder().agent(agent).appName("test").build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
TestSubscriber<Event> testSubscriber =
Expand All @@ -713,7 +723,7 @@ public void runLive_toolError() throws Exception {
.tools(ImmutableList.of(failingEchoTool))
.build();
Runner runnerWithFailingTool =
new InMemoryRunner(agentWithFailingTool, "test", ImmutableList.of());
Runner.builder().agent(agentWithFailingTool).appName("test").build();
Session sessionWithFailingTool =
runnerWithFailingTool.sessionService().createSession("test", "user").blockingGet();
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
Expand Down Expand Up @@ -752,14 +762,12 @@ public void resumabilityConfig_isResumable_isTrueInInvocationContext() {
ArgumentCaptor.forClass(InvocationContext.class);
when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty());
Runner runner =
new Runner(
agent,
"test",
new InMemoryArtifactService(),
new InMemorySessionService(),
/* memoryService= */ null,
ImmutableList.of(plugin),
new ResumabilityConfig(true));
Runner.builder()
.agent(agent)
.appName("test")
.plugins(ImmutableList.of(plugin))
.resumabilityConfig(new ResumabilityConfig(true))
.build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
var unused =
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
Expand All @@ -772,14 +780,12 @@ public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() {
ArgumentCaptor.forClass(InvocationContext.class);
when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty());
Runner runner =
new Runner(
agent,
"test",
new InMemoryArtifactService(),
new InMemorySessionService(),
/* memoryService= */ null,
ImmutableList.of(plugin),
new ResumabilityConfig(false));
Runner.builder()
.agent(agent)
.appName("test")
.plugins(ImmutableList.of(plugin))
.resumabilityConfig(new ResumabilityConfig(false))
.build();
Session session = runner.sessionService().createSession("test", "user").blockingGet();
var unused =
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
Expand Down