Skip to content

Commit a625f6c

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor:Refactor Runner to implement a Builder rather than a growing number of constructors
PiperOrigin-RevId: 839375782
1 parent 44d6a21 commit a625f6c

File tree

2 files changed

+137
-36
lines changed

2 files changed

+137
-36
lines changed

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,21 @@
2424
import com.google.adk.agents.LlmAgent;
2525
import com.google.adk.agents.RunConfig;
2626
import com.google.adk.artifacts.BaseArtifactService;
27+
import com.google.adk.artifacts.InMemoryArtifactService;
2728
import com.google.adk.events.Event;
2829
import com.google.adk.events.EventActions;
2930
import com.google.adk.flows.llmflows.ResumabilityConfig;
3031
import com.google.adk.memory.BaseMemoryService;
3132
import com.google.adk.plugins.BasePlugin;
3233
import com.google.adk.plugins.PluginManager;
3334
import com.google.adk.sessions.BaseSessionService;
35+
import com.google.adk.sessions.InMemorySessionService;
3436
import com.google.adk.sessions.Session;
3537
import com.google.adk.tools.BaseTool;
3638
import com.google.adk.tools.FunctionTool;
3739
import com.google.adk.utils.CollectionUtils;
3840
import com.google.common.collect.ImmutableList;
39-
import com.google.errorprone.annotations.InlineMe;
41+
import com.google.errorprone.annotations.CanIgnoreReturnValue;
4042
import com.google.genai.types.AudioTranscriptionConfig;
4143
import com.google.genai.types.Content;
4244
import com.google.genai.types.Modality;
@@ -67,7 +69,92 @@ public class Runner {
6769
private final PluginManager pluginManager;
6870
private final ResumabilityConfig resumabilityConfig;
6971

70-
/** Creates a new {@code Runner}. */
72+
/** Builder for {@link Runner}. */
73+
public static class Builder {
74+
private BaseAgent agent;
75+
private String appName;
76+
private BaseArtifactService artifactService = new InMemoryArtifactService();
77+
private BaseSessionService sessionService = new InMemorySessionService();
78+
@Nullable private BaseMemoryService memoryService = null;
79+
private List<BasePlugin> plugins = ImmutableList.of();
80+
private ResumabilityConfig resumabilityConfig = new ResumabilityConfig();
81+
82+
@CanIgnoreReturnValue
83+
public Builder agent(BaseAgent agent) {
84+
this.agent = agent;
85+
return this;
86+
}
87+
88+
@CanIgnoreReturnValue
89+
public Builder appName(String appName) {
90+
this.appName = appName;
91+
return this;
92+
}
93+
94+
@CanIgnoreReturnValue
95+
public Builder artifactService(BaseArtifactService artifactService) {
96+
this.artifactService = artifactService;
97+
return this;
98+
}
99+
100+
@CanIgnoreReturnValue
101+
public Builder sessionService(BaseSessionService sessionService) {
102+
this.sessionService = sessionService;
103+
return this;
104+
}
105+
106+
@CanIgnoreReturnValue
107+
public Builder memoryService(BaseMemoryService memoryService) {
108+
this.memoryService = memoryService;
109+
return this;
110+
}
111+
112+
@CanIgnoreReturnValue
113+
public Builder plugins(List<BasePlugin> plugins) {
114+
this.plugins = plugins;
115+
return this;
116+
}
117+
118+
@CanIgnoreReturnValue
119+
public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) {
120+
this.resumabilityConfig = resumabilityConfig;
121+
return this;
122+
}
123+
124+
public Runner build() {
125+
if (agent == null) {
126+
throw new IllegalStateException("Agent must be provided.");
127+
}
128+
if (appName == null) {
129+
throw new IllegalStateException("App name must be provided.");
130+
}
131+
if (artifactService == null) {
132+
throw new IllegalStateException("Artifact service must be provided.");
133+
}
134+
if (sessionService == null) {
135+
throw new IllegalStateException("Session service must be provided.");
136+
}
137+
return new Runner(
138+
agent,
139+
appName,
140+
artifactService,
141+
sessionService,
142+
memoryService,
143+
plugins,
144+
resumabilityConfig);
145+
}
146+
}
147+
148+
public static Builder builder() {
149+
return new Builder();
150+
}
151+
152+
/**
153+
* Creates a new {@code Runner}.
154+
*
155+
* @deprecated Use {@link Runner.Builder} instead.
156+
*/
157+
@Deprecated
71158
public Runner(
72159
BaseAgent agent,
73160
String appName,
@@ -84,7 +171,12 @@ public Runner(
84171
new ResumabilityConfig());
85172
}
86173

87-
/** Creates a new {@code Runner} with a list of plugins. */
174+
/**
175+
* Creates a new {@code Runner} with a list of plugins.
176+
*
177+
* @deprecated Use {@link Runner.Builder} instead.
178+
*/
179+
@Deprecated
88180
public Runner(
89181
BaseAgent agent,
90182
String appName,
@@ -102,7 +194,12 @@ public Runner(
102194
new ResumabilityConfig());
103195
}
104196

105-
/** Creates a new {@code Runner} with a list of plugins and resumability config. */
197+
/**
198+
* Creates a new {@code Runner} with a list of plugins and resumability config.
199+
*
200+
* @deprecated Use {@link Runner.Builder} instead.
201+
*/
202+
@Deprecated
106203
public Runner(
107204
BaseAgent agent,
108205
String appName,
@@ -123,10 +220,8 @@ public Runner(
123220
/**
124221
* Creates a new {@code Runner}.
125222
*
126-
* @deprecated Use the constructor with {@code BaseMemoryService} instead even if with a null if
127-
* you don't need the memory service.
223+
* @deprecated Use {@link Runner.Builder} instead.
128224
*/
129-
@InlineMe(replacement = "this(agent, appName, artifactService, sessionService, null)")
130225
@Deprecated
131226
public Runner(
132227
BaseAgent agent,

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,10 @@
3333
import com.google.adk.agents.LiveRequestQueue;
3434
import com.google.adk.agents.LlmAgent;
3535
import com.google.adk.agents.RunConfig;
36-
import com.google.adk.artifacts.InMemoryArtifactService;
3736
import com.google.adk.events.Event;
3837
import com.google.adk.flows.llmflows.ResumabilityConfig;
3938
import com.google.adk.models.LlmResponse;
4039
import com.google.adk.plugins.BasePlugin;
41-
import com.google.adk.sessions.InMemorySessionService;
4240
import com.google.adk.sessions.Session;
4341
import com.google.adk.testing.TestLlm;
4442
import com.google.adk.testing.TestUtils;
@@ -76,7 +74,8 @@ public final class RunnerTest {
7674
private final Content pluginContent = createContent("from plugin");
7775
private final TestLlm testLlm = createTestLlm(createLlmResponse(createContent("from llm")));
7876
private final LlmAgent agent = createTestAgentBuilder(testLlm).build();
79-
private final Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
77+
private final Runner runner =
78+
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
8079
private final Session session =
8180
runner.sessionService().createSession("test", "user").blockingGet();
8281
private Tracer originalTracer;
@@ -159,7 +158,12 @@ public void beforeRunCallback_multiplePluginsFirstOnly() {
159158
BasePlugin plugin2 = mockPlugin("test2");
160159
when(plugin2.beforeRunCallback(any())).thenReturn(Maybe.empty());
161160

162-
Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin1, plugin2));
161+
Runner runner =
162+
Runner.builder()
163+
.agent(agent)
164+
.appName("test")
165+
.plugins(ImmutableList.of(plugin1, plugin2))
166+
.build();
163167
Session session = runner.sessionService().createSession("test", "user").blockingGet();
164168
var events =
165169
runner
@@ -271,7 +275,8 @@ public void onModelErrorCallback_success() {
271275
TestLlm failingTestLlm = createTestLlm(Flowable.error(exception));
272276
LlmAgent agent = createTestAgentBuilder(failingTestLlm).build();
273277

274-
Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
278+
Runner runner =
279+
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
275280
Session session = runner.sessionService().createSession("test", "user").blockingGet();
276281
var events =
277282
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
@@ -289,7 +294,8 @@ public void onModelErrorCallback_error() {
289294
TestLlm failingTestLlm = createTestLlm(Flowable.error(exception));
290295
LlmAgent agent = createTestAgentBuilder(failingTestLlm).build();
291296

292-
Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
297+
Runner runner =
298+
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
293299
Session session = runner.sessionService().createSession("test", "user").blockingGet();
294300
runner.runAsync("user", session.id(), createContent("from user")).test().assertError(exception);
295301

@@ -307,7 +313,8 @@ public void beforeToolCallback_success() {
307313
.tools(ImmutableList.of(failingEchoTool))
308314
.build();
309315

310-
Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
316+
Runner runner =
317+
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
311318
Session session = runner.sessionService().createSession("test", "user").blockingGet();
312319
var events =
313320
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
@@ -330,7 +337,8 @@ public void afterToolCallback_success() {
330337
LlmAgent agent =
331338
createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build();
332339

333-
Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
340+
Runner runner =
341+
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
334342
Session session = runner.sessionService().createSession("test", "user").blockingGet();
335343
var events =
336344
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
@@ -355,7 +363,8 @@ public void onToolErrorCallback_success() {
355363
.tools(ImmutableList.of(failingEchoTool))
356364
.build();
357365

358-
Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
366+
Runner runner =
367+
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
359368
Session session = runner.sessionService().createSession("test", "user").blockingGet();
360369
var events =
361370
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
@@ -377,7 +386,8 @@ public void onToolErrorCallback_error() {
377386
.tools(ImmutableList.of(failingEchoTool))
378387
.build();
379388

380-
Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin));
389+
Runner runner =
390+
Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build();
381391
Session session = runner.sessionService().createSession("test", "user").blockingGet();
382392
runner
383393
.runAsync("user", session.id(), createContent("from user"))
@@ -666,7 +676,7 @@ public void runLive_success() throws Exception {
666676
public void runLive_withToolExecution() throws Exception {
667677
LlmAgent agentWithTool =
668678
createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build();
669-
Runner runnerWithTool = new InMemoryRunner(agentWithTool, "test", ImmutableList.of());
679+
Runner runnerWithTool = Runner.builder().agent(agentWithTool).appName("test").build();
670680
Session sessionWithTool =
671681
runnerWithTool.sessionService().createSession("test", "user").blockingGet();
672682
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
@@ -693,7 +703,7 @@ public void runLive_llmError() throws Exception {
693703
Exception exception = new Exception("LLM test error");
694704
TestLlm failingTestLlm = createTestLlm(Flowable.error(exception));
695705
LlmAgent agent = createTestAgentBuilder(failingTestLlm).build();
696-
Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of());
706+
Runner runner = Runner.builder().agent(agent).appName("test").build();
697707
Session session = runner.sessionService().createSession("test", "user").blockingGet();
698708
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
699709
TestSubscriber<Event> testSubscriber =
@@ -713,7 +723,7 @@ public void runLive_toolError() throws Exception {
713723
.tools(ImmutableList.of(failingEchoTool))
714724
.build();
715725
Runner runnerWithFailingTool =
716-
new InMemoryRunner(agentWithFailingTool, "test", ImmutableList.of());
726+
Runner.builder().agent(agentWithFailingTool).appName("test").build();
717727
Session sessionWithFailingTool =
718728
runnerWithFailingTool.sessionService().createSession("test", "user").blockingGet();
719729
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
@@ -752,14 +762,12 @@ public void resumabilityConfig_isResumable_isTrueInInvocationContext() {
752762
ArgumentCaptor.forClass(InvocationContext.class);
753763
when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty());
754764
Runner runner =
755-
new Runner(
756-
agent,
757-
"test",
758-
new InMemoryArtifactService(),
759-
new InMemorySessionService(),
760-
/* memoryService= */ null,
761-
ImmutableList.of(plugin),
762-
new ResumabilityConfig(true));
765+
Runner.builder()
766+
.agent(agent)
767+
.appName("test")
768+
.plugins(ImmutableList.of(plugin))
769+
.resumabilityConfig(new ResumabilityConfig(true))
770+
.build();
763771
Session session = runner.sessionService().createSession("test", "user").blockingGet();
764772
var unused =
765773
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
@@ -772,14 +780,12 @@ public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() {
772780
ArgumentCaptor.forClass(InvocationContext.class);
773781
when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty());
774782
Runner runner =
775-
new Runner(
776-
agent,
777-
"test",
778-
new InMemoryArtifactService(),
779-
new InMemorySessionService(),
780-
/* memoryService= */ null,
781-
ImmutableList.of(plugin),
782-
new ResumabilityConfig(false));
783+
Runner.builder()
784+
.agent(agent)
785+
.appName("test")
786+
.plugins(ImmutableList.of(plugin))
787+
.resumabilityConfig(new ResumabilityConfig(false))
788+
.build();
783789
Session session = runner.sessionService().createSession("test", "user").blockingGet();
784790
var unused =
785791
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();

0 commit comments

Comments
 (0)