From 089a9cb8102770315a71bd7025adfc48e061ff3e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 10 Dec 2025 06:23:30 -0800 Subject: [PATCH] refactor:Update BasePlugin to accept LlmRequest.Builder for callbacks instead of immutable LlmRequest PiperOrigin-RevId: 842696959 --- .../google/adk/flows/llmflows/BaseLlmFlow.java | 9 ++++++--- .../java/com/google/adk/plugins/BasePlugin.java | 9 +++++---- .../com/google/adk/plugins/LoggingPlugin.java | 13 +++++++------ .../com/google/adk/plugins/PluginManager.java | 4 ++-- .../com/google/adk/plugins/BasePluginTest.java | 6 +++--- .../google/adk/plugins/LoggingPluginTest.java | 17 +++++++++-------- .../google/adk/plugins/PluginManagerTest.java | 12 ++++++------ .../com/google/adk/plugins/ReplayPlugin.java | 6 +++--- .../google/adk/plugins/ReplayPluginTest.java | 10 ++++------ 9 files changed, 45 insertions(+), 41 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 87e9df8d..ab3a5974 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -204,7 +204,7 @@ private Flowable callLlm( .runOnModelErrorCallback( new CallbackContext( context, eventForCallbackUsage.actions()), - llmRequest, + llmRequestBuilder, exception) .switchIfEmpty(Single.error(exception)) .toFlowable()) @@ -212,7 +212,10 @@ private Flowable callLlm( llmResp -> { try (Scope innerScope = llmCallSpan.makeCurrent()) { Telemetry.traceCallLlm( - context, eventForCallbackUsage.id(), llmRequest, llmResp); + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp); } }) .doOnError( @@ -242,7 +245,7 @@ private Single> handleBeforeModelCallback( CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); Maybe pluginResult = - context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder.build()); + context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder); LlmAgent agent = (LlmAgent) context.agent(); diff --git a/core/src/main/java/com/google/adk/plugins/BasePlugin.java b/core/src/main/java/com/google/adk/plugins/BasePlugin.java index 7dd22c42..e9ee783b 100644 --- a/core/src/main/java/com/google/adk/plugins/BasePlugin.java +++ b/core/src/main/java/com/google/adk/plugins/BasePlugin.java @@ -122,11 +122,12 @@ public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callba * Callback executed before a request is sent to the model. * * @param callbackContext The context for the current agent call. - * @param llmRequest The prepared request object to be sent to the model. + * @param llmRequest The mutable request builder, allowing modification of the request before it + * is sent to the model. * @return An optional LlmResponse to trigger an early exit. Returning Empty to proceed normally. */ public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest llmRequest) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return Maybe.empty(); } @@ -147,13 +148,13 @@ public Maybe afterModelCallback( * Callback executed when a model call encounters an error. * * @param callbackContext The context for the current agent call. - * @param llmRequest The request that was sent to the model. + * @param llmRequest The mutable request builder for the request that failed. * @param error The exception that was raised. * @return An optional LlmResponse to use instead of propagating the error. Returning Empty to * allow the original error to be raised. */ public Maybe onModelErrorCallback( - CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return Maybe.empty(); } diff --git a/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java b/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java index d0ce3ab4..71c6a4bf 100644 --- a/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/LoggingPlugin.java @@ -151,14 +151,15 @@ public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callba @Override public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest llmRequest) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return Maybe.fromAction( () -> { + LlmRequest request = llmRequest.build(); log("🧠 LLM REQUEST"); - log(" Model: " + llmRequest.model().orElse("default")); + log(" Model: " + request.model().orElse("default")); log(" Agent: " + callbackContext.agentName()); - llmRequest + request .getFirstSystemInstruction() .ifPresent( sysInstruction -> { @@ -170,8 +171,8 @@ public Maybe beforeModelCallback( log(" System Instruction: '" + truncatedInstruction + "'"); }); - if (!llmRequest.tools().isEmpty()) { - String toolNames = String.join(", ", llmRequest.tools().keySet()); + if (!request.tools().isEmpty()) { + String toolNames = String.join(", ", request.tools().keySet()); log(" Available Tools: [" + toolNames + "]"); } }); @@ -211,7 +212,7 @@ public Maybe afterModelCallback( @Override public Maybe onModelErrorCallback( - CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return Maybe.fromAction( () -> { log("🧠 LLM ERROR"); diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 135168e9..e95a5c78 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -127,7 +127,7 @@ public Maybe runAfterAgentCallback(BaseAgent agent, CallbackContext cal } public Maybe runBeforeModelCallback( - CallbackContext callbackContext, LlmRequest llmRequest) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return runMaybeCallbacks( plugin -> plugin.beforeModelCallback(callbackContext, llmRequest), "beforeModelCallback"); } @@ -139,7 +139,7 @@ public Maybe runAfterModelCallback( } public Maybe runOnModelErrorCallback( - CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return runMaybeCallbacks( plugin -> plugin.onModelErrorCallback(callbackContext, llmRequest, error), "onModelErrorCallback"); diff --git a/core/src/test/java/com/google/adk/plugins/BasePluginTest.java b/core/src/test/java/com/google/adk/plugins/BasePluginTest.java index 9a4a243c..58175a2f 100644 --- a/core/src/test/java/com/google/adk/plugins/BasePluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/BasePluginTest.java @@ -43,7 +43,7 @@ private static class TestPlugin extends BasePlugin { private final CallbackContext callbackContext = Mockito.mock(CallbackContext.class); private final Content content = Content.builder().build(); private final Event event = Mockito.mock(Event.class); - private final LlmRequest llmRequest = LlmRequest.builder().build(); + private final LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); private final LlmResponse llmResponse = LlmResponse.builder().build(); private final ToolContext toolContext = Mockito.mock(ToolContext.class); @@ -79,7 +79,7 @@ public void afterAgentCallback_returnsEmptyMaybe() { @Test public void beforeModelCallback_returnsEmptyMaybe() { - plugin.beforeModelCallback(callbackContext, llmRequest).test().assertResult(); + plugin.beforeModelCallback(callbackContext, llmRequestBuilder).test().assertResult(); } @Test @@ -90,7 +90,7 @@ public void afterModelCallback_returnsEmptyMaybe() { @Test public void onModelErrorCallback_returnsEmptyMaybe() { plugin - .onModelErrorCallback(callbackContext, llmRequest, new RuntimeException()) + .onModelErrorCallback(callbackContext, llmRequestBuilder, new RuntimeException()) .test() .assertResult(); } diff --git a/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java b/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java index 4c90c11b..52230f07 100644 --- a/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java @@ -69,8 +69,8 @@ public class LoggingPluginTest { .actions(EventActions.builder().build()) .longRunningToolIds(Optional.empty()) .build(); - private final LlmRequest llmRequest = - LlmRequest.builder().model("default").contents(ImmutableList.of()).build(); + private final LlmRequest.Builder llmRequestBuilder = + LlmRequest.builder().model("default").contents(ImmutableList.of()); private final LlmResponse llmResponse = LlmResponse.builder().build(); private final ImmutableMap toolArgs = ImmutableMap.of(); private final ImmutableMap toolResult = ImmutableMap.of(); @@ -175,7 +175,10 @@ public void afterAgentCallback_runsWithoutError() { @Test public void beforeModelCallback_runsWithoutError() { - loggingPlugin.beforeModelCallback(mockCallbackContext, llmRequest).test().assertComplete(); + loggingPlugin + .beforeModelCallback(mockCallbackContext, llmRequestBuilder) + .test() + .assertComplete(); } @Test @@ -184,8 +187,7 @@ public void beforeModelCallback_longSystemInstruction() { .beforeModelCallback( mockCallbackContext, LlmRequest.builder() - .appendInstructions(ImmutableList.of("all work and no play".repeat(1000))) - .build()) + .appendInstructions(ImmutableList.of("all work and no play".repeat(1000)))) .test() .assertComplete(); } @@ -194,8 +196,7 @@ public void beforeModelCallback_longSystemInstruction() { public void beforeModelCallback_tools() { loggingPlugin .beforeModelCallback( - mockCallbackContext, - LlmRequest.builder().appendTools(ImmutableList.of(mockTool)).build()) + mockCallbackContext, LlmRequest.builder().appendTools(ImmutableList.of(mockTool))) .test() .assertComplete(); } @@ -231,7 +232,7 @@ public void afterModelCallback_usageMetadata() { @Test public void onModelErrorCallback_runsWithoutError() { loggingPlugin - .onModelErrorCallback(mockCallbackContext, llmRequest, throwable) + .onModelErrorCallback(mockCallbackContext, llmRequestBuilder, throwable) .test() .assertComplete(); } diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index 4737d6cd..fe115e46 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -236,18 +236,18 @@ public void runAfterAgentCallback_singlePlugin() { @Test public void runBeforeModelCallback_singlePlugin() { CallbackContext mockCallbackContext = mock(CallbackContext.class); - LlmRequest llmRequest = LlmRequest.builder().build(); + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); LlmResponse llmResponse = LlmResponse.builder().build(); when(plugin1.beforeModelCallback(any(), any())).thenReturn(Maybe.just(llmResponse)); pluginManager.registerPlugin(plugin1); pluginManager - .runBeforeModelCallback(mockCallbackContext, llmRequest) + .runBeforeModelCallback(mockCallbackContext, llmRequestBuilder) .test() .assertResult(llmResponse); - verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequest); + verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequestBuilder); } @Test @@ -269,7 +269,7 @@ public void runAfterModelCallback_singlePlugin() { @Test public void runOnModelErrorCallback_singlePlugin() { CallbackContext mockCallbackContext = mock(CallbackContext.class); - LlmRequest llmRequest = LlmRequest.builder().build(); + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); Throwable mockThrowable = mock(Throwable.class); LlmResponse llmResponse = LlmResponse.builder().build(); @@ -277,11 +277,11 @@ public void runOnModelErrorCallback_singlePlugin() { pluginManager.registerPlugin(plugin1); pluginManager - .runOnModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable) + .runOnModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable) .test() .assertResult(llmResponse); - verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable); + verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable); } @Test diff --git a/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java b/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java index 1164709a..d3224e30 100644 --- a/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java +++ b/dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java @@ -71,7 +71,7 @@ public Maybe beforeRunCallback(InvocationContext invocationContext) { @Override public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest llmRequest) { + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { if (!isReplayModeOn(callbackContext)) { return Maybe.empty(); } @@ -261,7 +261,7 @@ private Recording getNextRecordingForAgent(InvocationReplayState state, String a } private LlmRecording verifyAndGetNextLlmRecordingForAgent( - InvocationReplayState state, String agentName, LlmRequest llmRequest) { + InvocationReplayState state, String agentName, LlmRequest.Builder llmRequest) { int currentAgentIndex = state.getAgentReplayIndex(agentName); Recording expectedRecording = getNextRecordingForAgent(state, agentName); @@ -278,7 +278,7 @@ private LlmRecording verifyAndGetNextLlmRecordingForAgent( // Strict verification of LLM request if (llmRecording.llmRequest().isPresent()) { verifyLlmRequestMatch( - llmRecording.llmRequest().get(), llmRequest, agentName, currentAgentIndex); + llmRecording.llmRequest().get(), llmRequest.build(), agentName, currentAgentIndex); } return llmRecording; diff --git a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java index fe4d2a0b..f29298bc 100644 --- a/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java +++ b/dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java @@ -107,7 +107,7 @@ void beforeModelCallback_withMatchingRecording_returnsRecordedResponse() throws when(callbackContext.invocationId()).thenReturn("test-invocation"); when(callbackContext.agentName()).thenReturn("test_agent"); - LlmRequest request = + var request = LlmRequest.builder() .model("gemini-2.0-flash") .contents( @@ -115,8 +115,7 @@ void beforeModelCallback_withMatchingRecording_returnsRecordedResponse() throws Content.builder() .role("user") .parts(Part.builder().text("Hello").build()) - .build())) - .build(); + .build())); // Step 4: Verify expected response is returned var result = plugin.beforeModelCallback(callbackContext, request).blockingGet(); @@ -162,7 +161,7 @@ void beforeModelCallback_requestMismatch_returnsEmpty() throws Exception { when(callbackContext.invocationId()).thenReturn("test-invocation"); when(callbackContext.agentName()).thenReturn("test_agent"); - LlmRequest request = + var request = LlmRequest.builder() .model("gemini-2.0-flash") // Different model .contents( @@ -170,8 +169,7 @@ void beforeModelCallback_requestMismatch_returnsEmpty() throws Exception { Content.builder() .role("user") .parts(Part.builder().text("Hello").build()) - .build())) - .build(); + .build())); // Step 4: Verify result is empty var result = plugin.beforeModelCallback(callbackContext, request).blockingGet();