Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,18 @@ private Flowable<LlmResponse> callLlm(
.runOnModelErrorCallback(
new CallbackContext(
context, eventForCallbackUsage.actions()),
llmRequest,
llmRequestBuilder,
exception)
.switchIfEmpty(Single.error(exception))
.toFlowable())
.doOnNext(
llmResp -> {
try (Scope innerScope = llmCallSpan.makeCurrent()) {
Telemetry.traceCallLlm(
context, eventForCallbackUsage.id(), llmRequest, llmResp);
context,
eventForCallbackUsage.id(),
llmRequestBuilder.build(),
llmResp);
}
})
.doOnError(
Expand Down Expand Up @@ -242,7 +245,7 @@ private Single<Optional<LlmResponse>> handleBeforeModelCallback(
CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());

Maybe<LlmResponse> pluginResult =
context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder.build());
context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder);

LlmAgent agent = (LlmAgent) context.agent();

Expand Down
9 changes: 5 additions & 4 deletions core/src/main/java/com/google/adk/plugins/BasePlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,12 @@ public Maybe<Content> afterAgentCallback(BaseAgent agent, CallbackContext callba
* Callback executed before a request is sent to the model.
*
* @param callbackContext The context for the current agent call.
* @param llmRequest The prepared request object to be sent to the model.
* @param llmRequest The mutable request builder, allowing modification of the request before it
* is sent to the model.
* @return An optional LlmResponse to trigger an early exit. Returning Empty to proceed normally.
*/
public Maybe<LlmResponse> beforeModelCallback(
CallbackContext callbackContext, LlmRequest llmRequest) {
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
return Maybe.empty();
}

Expand All @@ -147,13 +148,13 @@ public Maybe<LlmResponse> afterModelCallback(
* Callback executed when a model call encounters an error.
*
* @param callbackContext The context for the current agent call.
* @param llmRequest The request that was sent to the model.
* @param llmRequest The mutable request builder for the request that failed.
* @param error The exception that was raised.
* @return An optional LlmResponse to use instead of propagating the error. Returning Empty to
* allow the original error to be raised.
*/
public Maybe<LlmResponse> onModelErrorCallback(
CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) {
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
return Maybe.empty();
}

Expand Down
13 changes: 7 additions & 6 deletions core/src/main/java/com/google/adk/plugins/LoggingPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,15 @@ public Maybe<Content> afterAgentCallback(BaseAgent agent, CallbackContext callba

@Override
public Maybe<LlmResponse> beforeModelCallback(
CallbackContext callbackContext, LlmRequest llmRequest) {
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
return Maybe.fromAction(
() -> {
LlmRequest request = llmRequest.build();
log("🧠 LLM REQUEST");
log(" Model: " + llmRequest.model().orElse("default"));
log(" Model: " + request.model().orElse("default"));
log(" Agent: " + callbackContext.agentName());

llmRequest
request
.getFirstSystemInstruction()
.ifPresent(
sysInstruction -> {
Expand All @@ -170,8 +171,8 @@ public Maybe<LlmResponse> beforeModelCallback(
log(" System Instruction: '" + truncatedInstruction + "'");
});

if (!llmRequest.tools().isEmpty()) {
String toolNames = String.join(", ", llmRequest.tools().keySet());
if (!request.tools().isEmpty()) {
String toolNames = String.join(", ", request.tools().keySet());
log(" Available Tools: [" + toolNames + "]");
}
});
Expand Down Expand Up @@ -211,7 +212,7 @@ public Maybe<LlmResponse> afterModelCallback(

@Override
public Maybe<LlmResponse> onModelErrorCallback(
CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) {
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
return Maybe.fromAction(
() -> {
log("🧠 LLM ERROR");
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/com/google/adk/plugins/PluginManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public Maybe<Content> runAfterAgentCallback(BaseAgent agent, CallbackContext cal
}

public Maybe<LlmResponse> runBeforeModelCallback(
CallbackContext callbackContext, LlmRequest llmRequest) {
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
return runMaybeCallbacks(
plugin -> plugin.beforeModelCallback(callbackContext, llmRequest), "beforeModelCallback");
}
Expand All @@ -139,7 +139,7 @@ public Maybe<LlmResponse> runAfterModelCallback(
}

public Maybe<LlmResponse> runOnModelErrorCallback(
CallbackContext callbackContext, LlmRequest llmRequest, Throwable error) {
CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) {
return runMaybeCallbacks(
plugin -> plugin.onModelErrorCallback(callbackContext, llmRequest, error),
"onModelErrorCallback");
Expand Down
6 changes: 3 additions & 3 deletions core/src/test/java/com/google/adk/plugins/BasePluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ private static class TestPlugin extends BasePlugin {
private final CallbackContext callbackContext = Mockito.mock(CallbackContext.class);
private final Content content = Content.builder().build();
private final Event event = Mockito.mock(Event.class);
private final LlmRequest llmRequest = LlmRequest.builder().build();
private final LlmRequest.Builder llmRequestBuilder = LlmRequest.builder();
private final LlmResponse llmResponse = LlmResponse.builder().build();
private final ToolContext toolContext = Mockito.mock(ToolContext.class);

Expand Down Expand Up @@ -79,7 +79,7 @@ public void afterAgentCallback_returnsEmptyMaybe() {

@Test
public void beforeModelCallback_returnsEmptyMaybe() {
plugin.beforeModelCallback(callbackContext, llmRequest).test().assertResult();
plugin.beforeModelCallback(callbackContext, llmRequestBuilder).test().assertResult();
}

@Test
Expand All @@ -90,7 +90,7 @@ public void afterModelCallback_returnsEmptyMaybe() {
@Test
public void onModelErrorCallback_returnsEmptyMaybe() {
plugin
.onModelErrorCallback(callbackContext, llmRequest, new RuntimeException())
.onModelErrorCallback(callbackContext, llmRequestBuilder, new RuntimeException())
.test()
.assertResult();
}
Expand Down
17 changes: 9 additions & 8 deletions core/src/test/java/com/google/adk/plugins/LoggingPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ public class LoggingPluginTest {
.actions(EventActions.builder().build())
.longRunningToolIds(Optional.empty())
.build();
private final LlmRequest llmRequest =
LlmRequest.builder().model("default").contents(ImmutableList.of()).build();
private final LlmRequest.Builder llmRequestBuilder =
LlmRequest.builder().model("default").contents(ImmutableList.of());
private final LlmResponse llmResponse = LlmResponse.builder().build();
private final ImmutableMap<String, Object> toolArgs = ImmutableMap.of();
private final ImmutableMap<String, Object> toolResult = ImmutableMap.of();
Expand Down Expand Up @@ -175,7 +175,10 @@ public void afterAgentCallback_runsWithoutError() {

@Test
public void beforeModelCallback_runsWithoutError() {
loggingPlugin.beforeModelCallback(mockCallbackContext, llmRequest).test().assertComplete();
loggingPlugin
.beforeModelCallback(mockCallbackContext, llmRequestBuilder)
.test()
.assertComplete();
}

@Test
Expand All @@ -184,8 +187,7 @@ public void beforeModelCallback_longSystemInstruction() {
.beforeModelCallback(
mockCallbackContext,
LlmRequest.builder()
.appendInstructions(ImmutableList.of("all work and no play".repeat(1000)))
.build())
.appendInstructions(ImmutableList.of("all work and no play".repeat(1000))))
.test()
.assertComplete();
}
Expand All @@ -194,8 +196,7 @@ public void beforeModelCallback_longSystemInstruction() {
public void beforeModelCallback_tools() {
loggingPlugin
.beforeModelCallback(
mockCallbackContext,
LlmRequest.builder().appendTools(ImmutableList.of(mockTool)).build())
mockCallbackContext, LlmRequest.builder().appendTools(ImmutableList.of(mockTool)))
.test()
.assertComplete();
}
Expand Down Expand Up @@ -231,7 +232,7 @@ public void afterModelCallback_usageMetadata() {
@Test
public void onModelErrorCallback_runsWithoutError() {
loggingPlugin
.onModelErrorCallback(mockCallbackContext, llmRequest, throwable)
.onModelErrorCallback(mockCallbackContext, llmRequestBuilder, throwable)
.test()
.assertComplete();
}
Expand Down
12 changes: 6 additions & 6 deletions core/src/test/java/com/google/adk/plugins/PluginManagerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,18 @@ public void runAfterAgentCallback_singlePlugin() {
@Test
public void runBeforeModelCallback_singlePlugin() {
CallbackContext mockCallbackContext = mock(CallbackContext.class);
LlmRequest llmRequest = LlmRequest.builder().build();
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder();
LlmResponse llmResponse = LlmResponse.builder().build();

when(plugin1.beforeModelCallback(any(), any())).thenReturn(Maybe.just(llmResponse));
pluginManager.registerPlugin(plugin1);

pluginManager
.runBeforeModelCallback(mockCallbackContext, llmRequest)
.runBeforeModelCallback(mockCallbackContext, llmRequestBuilder)
.test()
.assertResult(llmResponse);

verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequest);
verify(plugin1).beforeModelCallback(mockCallbackContext, llmRequestBuilder);
}

@Test
Expand All @@ -269,19 +269,19 @@ public void runAfterModelCallback_singlePlugin() {
@Test
public void runOnModelErrorCallback_singlePlugin() {
CallbackContext mockCallbackContext = mock(CallbackContext.class);
LlmRequest llmRequest = LlmRequest.builder().build();
LlmRequest.Builder llmRequestBuilder = LlmRequest.builder();
Throwable mockThrowable = mock(Throwable.class);
LlmResponse llmResponse = LlmResponse.builder().build();

when(plugin1.onModelErrorCallback(any(), any(), any())).thenReturn(Maybe.just(llmResponse));
pluginManager.registerPlugin(plugin1);

pluginManager
.runOnModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable)
.runOnModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable)
.test()
.assertResult(llmResponse);

verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequest, mockThrowable);
verify(plugin1).onModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable);
}

@Test
Expand Down
6 changes: 3 additions & 3 deletions dev/src/main/java/com/google/adk/plugins/ReplayPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public Maybe<Content> beforeRunCallback(InvocationContext invocationContext) {

@Override
public Maybe<LlmResponse> beforeModelCallback(
CallbackContext callbackContext, LlmRequest llmRequest) {
CallbackContext callbackContext, LlmRequest.Builder llmRequest) {
if (!isReplayModeOn(callbackContext)) {
return Maybe.empty();
}
Expand Down Expand Up @@ -261,7 +261,7 @@ private Recording getNextRecordingForAgent(InvocationReplayState state, String a
}

private LlmRecording verifyAndGetNextLlmRecordingForAgent(
InvocationReplayState state, String agentName, LlmRequest llmRequest) {
InvocationReplayState state, String agentName, LlmRequest.Builder llmRequest) {
int currentAgentIndex = state.getAgentReplayIndex(agentName);
Recording expectedRecording = getNextRecordingForAgent(state, agentName);

Expand All @@ -278,7 +278,7 @@ private LlmRecording verifyAndGetNextLlmRecordingForAgent(
// Strict verification of LLM request
if (llmRecording.llmRequest().isPresent()) {
verifyLlmRequestMatch(
llmRecording.llmRequest().get(), llmRequest, agentName, currentAgentIndex);
llmRecording.llmRequest().get(), llmRequest.build(), agentName, currentAgentIndex);
}

return llmRecording;
Expand Down
10 changes: 4 additions & 6 deletions dev/src/test/java/com/google/adk/plugins/ReplayPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,15 @@ void beforeModelCallback_withMatchingRecording_returnsRecordedResponse() throws
when(callbackContext.invocationId()).thenReturn("test-invocation");
when(callbackContext.agentName()).thenReturn("test_agent");

LlmRequest request =
var request =
LlmRequest.builder()
.model("gemini-2.0-flash")
.contents(
ImmutableList.of(
Content.builder()
.role("user")
.parts(Part.builder().text("Hello").build())
.build()))
.build();
.build()));

// Step 4: Verify expected response is returned
var result = plugin.beforeModelCallback(callbackContext, request).blockingGet();
Expand Down Expand Up @@ -162,16 +161,15 @@ void beforeModelCallback_requestMismatch_returnsEmpty() throws Exception {
when(callbackContext.invocationId()).thenReturn("test-invocation");
when(callbackContext.agentName()).thenReturn("test_agent");

LlmRequest request =
var request =
LlmRequest.builder()
.model("gemini-2.0-flash") // Different model
.contents(
ImmutableList.of(
Content.builder()
.role("user")
.parts(Part.builder().text("Hello").build())
.build()))
.build();
.build()));

// Step 4: Verify result is empty
var result = plugin.beforeModelCallback(callbackContext, request).blockingGet();
Expand Down