Skip to content

Commit 473f9f2

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor:Update BasePlugin to accept LlmRequest.Builder for callbacks instead of immutable LlmRequest
PiperOrigin-RevId: 840731297
1 parent 2906eb5 commit 473f9f2

File tree

9 files changed

+45
-41
lines changed

9 files changed

+45
-41
lines changed

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,18 @@ private Flowable<LlmResponse> callLlm(
204204
.runOnModelErrorCallback(
205205
new CallbackContext(
206206
context, eventForCallbackUsage.actions()),
207-
llmRequest,
207+
llmRequestBuilder,
208208
exception)
209209
.switchIfEmpty(Single.error(exception))
210210
.toFlowable())
211211
.doOnNext(
212212
llmResp -> {
213213
try (Scope innerScope = llmCallSpan.makeCurrent()) {
214214
Telemetry.traceCallLlm(
215-
context, eventForCallbackUsage.id(), llmRequest, llmResp);
215+
context,
216+
eventForCallbackUsage.id(),
217+
llmRequestBuilder.build(),
218+
llmResp);
216219
}
217220
})
218221
.doOnError(
@@ -242,7 +245,7 @@ private Single<Optional<LlmResponse>> handleBeforeModelCallback(
242245
CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());
243246

244247
Maybe<LlmResponse> pluginResult =
245-
context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder.build());
248+
context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder);
246249

247250
LlmAgent agent = (LlmAgent) context.agent();
248251

core/src/main/java/com/google/adk/plugins/BasePlugin.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,12 @@ public Maybe<Content> afterAgentCallback(BaseAgent agent, CallbackContext callba
122122
* Callback executed before a request is sent to the model.
123123
*
124124
* @param callbackContext The context for the current agent call.
125-
* @param llmRequest The prepared request object to be sent to the model.
125+
* @param llmRequest The mutable request builder, allowing modification of the request before it
126+
* is sent to the model.
126127
* @return An optional LlmResponse to trigger an early exit. Returning Empty to proceed normally.
127128
*/
128129
public Maybe<LlmResponse> beforeModelCallback(
129-
CallbackContext callbackContext, LlmRequest llmRequest) {
130+
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
130131
return Maybe.empty();
131132
}
132133

@@ -147,13 +148,13 @@ public Maybe<LlmResponse> afterModelCallback(
147148
* Callback executed when a model call encounters an error.
148149
*
149150
* @param callbackContext The context for the current agent call.
150-
* @param llmRequest The request that was sent to the model.
151+
* @param llmRequest The mutable request builder for the request that failed.
151152
* @param error The exception that was raised.
152153
* @return An optional LlmResponse to use instead of propagating the error. Returning Empty to
153154
* allow the original error to be raised.
154155
*/
155156
public Maybe<LlmResponse> onModelErrorCallback(
156-
CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) {
157+
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
157158
return Maybe.empty();
158159
}
159160

core/src/main/java/com/google/adk/plugins/LoggingPlugin.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,15 @@ public Maybe<Content> afterAgentCallback(BaseAgent agent, CallbackContext callba
151151

152152
@Override
153153
public Maybe<LlmResponse> beforeModelCallback(
154-
CallbackContext callbackContext, LlmRequest llmRequest) {
154+
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
155155
return Maybe.fromAction(
156156
() -> {
157+
LlmRequest request = llmRequest.build();
157158
log("🧠 LLM REQUEST");
158-
log(" Model: " + llmRequest.model().orElse("default"));
159+
log(" Model: " + request.model().orElse("default"));
159160
log(" Agent: " + callbackContext.agentName());
160161

161-
llmRequest
162+
request
162163
.getFirstSystemInstruction()
163164
.ifPresent(
164165
sysInstruction -> {
@@ -170,8 +171,8 @@ public Maybe<LlmResponse> beforeModelCallback(
170171
log(" System Instruction: '" + truncatedInstruction + "'");
171172
});
172173

173-
if (!llmRequest.tools().isEmpty()) {
174-
String toolNames = String.join(", ", llmRequest.tools().keySet());
174+
if (!request.tools().isEmpty()) {
175+
String toolNames = String.join(", ", request.tools().keySet());
175176
log(" Available Tools: [" + toolNames + "]");
176177
}
177178
});
@@ -211,7 +212,7 @@ public Maybe<LlmResponse> afterModelCallback(
211212

212213
@Override
213214
public Maybe<LlmResponse> onModelErrorCallback(
214-
CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) {
215+
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
215216
return Maybe.fromAction(
216217
() -> {
217218
log("🧠 LLM ERROR");

core/src/main/java/com/google/adk/plugins/PluginManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public Maybe<Content> runAfterAgentCallback(BaseAgent agent, CallbackContext cal
127127
}
128128

129129
public Maybe<LlmResponse> runBeforeModelCallback(
130-
CallbackContext callbackContext, LlmRequest llmRequest) {
130+
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
131131
return runMaybeCallbacks(
132132
plugin -> plugin.beforeModelCallback(callbackContext, llmRequest), "beforeModelCallback");
133133
}
@@ -139,7 +139,7 @@ public Maybe<LlmResponse> runAfterModelCallback(
139139
}
140140

141141
public Maybe<LlmResponse> runOnModelErrorCallback(
142-
CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) {
142+
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
143143
return runMaybeCallbacks(
144144
plugin -> plugin.onModelErrorCallback(callbackContext, llmRequest, error),
145145
"onModelErrorCallback");

core/src/test/java/com/google/adk/plugins/BasePluginTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ private static class TestPlugin extends BasePlugin {
4343
private final CallbackContext callbackContext = Mockito.mock(CallbackContext.class);
4444
private final Content content = Content.builder().build();
4545
private final Event event = Mockito.mock(Event.class);
46-
private final LlmRequest llmRequest = LlmRequest.builder().build();
46+
private final LlmRequest.Builder llmRequestBuilder = LlmRequest.builder();
4747
private final LlmResponse llmResponse = LlmResponse.builder().build();
4848
private final ToolContext toolContext = Mockito.mock(ToolContext.class);
4949

@@ -79,7 +79,7 @@ public void afterAgentCallback_returnsEmptyMaybe() {
7979

8080
@Test
8181
public void beforeModelCallback_returnsEmptyMaybe() {
82-
plugin.beforeModelCallback(callbackContext, llmRequest).test().assertResult();
82+
plugin.beforeModelCallback(callbackContext, llmRequestBuilder).test().assertResult();
8383
}
8484

8585
@Test
@@ -90,7 +90,7 @@ public void afterModelCallback_returnsEmptyMaybe() {
9090
@Test
9191
public void onModelErrorCallback_returnsEmptyMaybe() {
9292
plugin
93-
.onModelErrorCallback(callbackContext, llmRequest, new RuntimeException())
93+
.onModelErrorCallback(callbackContext, llmRequestBuilder, new RuntimeException())
9494
.test()
9595
.assertResult();
9696
}

core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ public class LoggingPluginTest {
6969
.actions(EventActions.builder().build())
7070
.longRunningToolIds(Optional.empty())
7171
.build();
72-
private final LlmRequest llmRequest =
73-
LlmRequest.builder().model("default").contents(ImmutableList.of()).build();
72+
private final LlmRequest.Builder llmRequestBuilder =
73+
LlmRequest.builder().model("default").contents(ImmutableList.of());
7474
private final LlmResponse llmResponse = LlmResponse.builder().build();
7575
private final ImmutableMap<String, Object> toolArgs = ImmutableMap.of();
7676
private final ImmutableMap<String, Object> toolResult = ImmutableMap.of();
@@ -175,7 +175,10 @@ public void afterAgentCallback_runsWithoutError() {
175175

176176
@Test
177177
public void beforeModelCallback_runsWithoutError() {
178-
loggingPlugin.beforeModelCallback(mockCallbackContext, llmRequest).test().assertComplete();
178+
loggingPlugin
179+
.beforeModelCallback(mockCallbackContext, llmRequestBuilder)
180+
.test()
181+
.assertComplete();
179182
}
180183

181184
@Test
@@ -184,8 +187,7 @@ public void beforeModelCallback_longSystemInstruction() {
184187
.beforeModelCallback(
185188
mockCallbackContext,
186189
LlmRequest.builder()
187-
.appendInstructions(ImmutableList.of("all work and no play".repeat(1000)))
188-
.build())
190+
.appendInstructions(ImmutableList.of("all work and no play".repeat(1000))))
189191
.test()
190192
.assertComplete();
191193
}
@@ -194,8 +196,7 @@ public void beforeModelCallback_longSystemInstruction() {
194196
public void beforeModelCallback_tools() {
195197
loggingPlugin
196198
.beforeModelCallback(
197-
mockCallbackContext,
198-
LlmRequest.builder().appendTools(ImmutableList.of(mockTool)).build())
199+
mockCallbackContext, LlmRequest.builder().appendTools(ImmutableList.of(mockTool)))
199200
.test()
200201
.assertComplete();
201202
}
@@ -231,7 +232,7 @@ public void afterModelCallback_usageMetadata() {
231232
@Test
232233
public void onModelErrorCallback_runsWithoutError() {
233234
loggingPlugin
234-
.onModelErrorCallback(mockCallbackContext, llmRequest, throwable)
235+
.onModelErrorCallback(mockCallbackContext, llmRequestBuilder, throwable)
235236
.test()
236237
.assertComplete();
237238
}

core/src/test/java/com/google/adk/plugins/PluginManagerTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,18 +236,18 @@ public void runAfterAgentCallback_singlePlugin() {
236236
@Test
237237
public void runBeforeModelCallback_singlePlugin() {
238238
CallbackContext mockCallbackContext = mock(CallbackContext.class);
239-
LlmRequest llmRequest = LlmRequest.builder().build();
239+
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder();
240240
LlmResponse llmResponse = LlmResponse.builder().build();
241241

242242
when(plugin1.beforeModelCallback(any(), any())).thenReturn(Maybe.just(llmResponse));
243243
pluginManager.registerPlugin(plugin1);
244244

245245
pluginManager
246-
.runBeforeModelCallback(mockCallbackContext, llmRequest)
246+
.runBeforeModelCallback(mockCallbackContext, llmRequestBuilder)
247247
.test()
248248
.assertResult(llmResponse);
249249

250-
verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequest);
250+
verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequestBuilder);
251251
}
252252

253253
@Test
@@ -269,19 +269,19 @@ public void runAfterModelCallback_singlePlugin() {
269269
@Test
270270
public void runOnModelErrorCallback_singlePlugin() {
271271
CallbackContext mockCallbackContext = mock(CallbackContext.class);
272-
LlmRequest llmRequest = LlmRequest.builder().build();
272+
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder();
273273
Throwable mockThrowable = mock(Throwable.class);
274274
LlmResponse llmResponse = LlmResponse.builder().build();
275275

276276
when(plugin1.onModelErrorCallback(any(), any(), any())).thenReturn(Maybe.just(llmResponse));
277277
pluginManager.registerPlugin(plugin1);
278278

279279
pluginManager
280-
.runOnModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable)
280+
.runOnModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable)
281281
.test()
282282
.assertResult(llmResponse);
283283

284-
verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable);
284+
verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable);
285285
}
286286

287287
@Test

dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {
7171

7272
@Override
7373
public Maybe<LlmResponse> beforeModelCallback(
74-
CallbackContext callbackContext, LlmRequest llmRequest) {
74+
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
7575
if (!isReplayModeOn(callbackContext)) {
7676
return Maybe.empty();
7777
}
@@ -261,7 +261,7 @@ private Recording getNextRecordingForAgent(InvocationReplayState state, String a
261261
}
262262

263263
private LlmRecording verifyAndGetNextLlmRecordingForAgent(
264-
InvocationReplayState state, String agentName, LlmRequest llmRequest) {
264+
InvocationReplayState state, String agentName, LlmRequest.Builder llmRequest) {
265265
int currentAgentIndex = state.getAgentReplayIndex(agentName);
266266
Recording expectedRecording = getNextRecordingForAgent(state, agentName);
267267

@@ -278,7 +278,7 @@ private LlmRecording verifyAndGetNextLlmRecordingForAgent(
278278
// Strict verification of LLM request
279279
if (llmRecording.llmRequest().isPresent()) {
280280
verifyLlmRequestMatch(
281-
llmRecording.llmRequest().get(), llmRequest, agentName, currentAgentIndex);
281+
llmRecording.llmRequest().get(), llmRequest.build(), agentName, currentAgentIndex);
282282
}
283283

284284
return llmRecording;

dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,15 @@ void beforeModelCallback_withMatchingRecording_returnsRecordedResponse() throws
107107
when(callbackContext.invocationId()).thenReturn("test-invocation");
108108
when(callbackContext.agentName()).thenReturn("test_agent");
109109

110-
LlmRequest request =
110+
var request =
111111
LlmRequest.builder()
112112
.model("gemini-2.0-flash")
113113
.contents(
114114
ImmutableList.of(
115115
Content.builder()
116116
.role("user")
117117
.parts(Part.builder().text("Hello").build())
118-
.build()))
119-
.build();
118+
.build()));
120119

121120
// Step 4: Verify expected response is returned
122121
var result = plugin.beforeModelCallback(callbackContext, request).blockingGet();
@@ -162,16 +161,15 @@ void beforeModelCallback_requestMismatch_returnsEmpty() throws Exception {
162161
when(callbackContext.invocationId()).thenReturn("test-invocation");
163162
when(callbackContext.agentName()).thenReturn("test_agent");
164163

165-
LlmRequest request =
164+
var request =
166165
LlmRequest.builder()
167166
.model("gemini-2.0-flash") // Different model
168167
.contents(
169168
ImmutableList.of(
170169
Content.builder()
171170
.role("user")
172171
.parts(Part.builder().text("Hello").build())
173-
.build()))
174-
.build();
172+
.build()));
175173

176174
// Step 4: Verify result is empty
177175
var result = plugin.beforeModelCallback(callbackContext, request).blockingGet();

0 commit comments

Comments
 (0)