Skip to content

Commit 2906eb5

Browse files
google-genai-botcopybara-github
authored andcommitted
fix:Update HITL/Tool workflows to correctly pause and resume runner operations
PiperOrigin-RevId: 840816599
1 parent 4b6cf0b commit 2906eb5

File tree

9 files changed

+666
-322
lines changed

9 files changed

+666
-322
lines changed

core/src/main/java/com/google/adk/agents/InvocationContext.java

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@
1717
package com.google.adk.agents;
1818

1919
import com.google.adk.artifacts.BaseArtifactService;
20+
import com.google.adk.events.Event;
21+
import com.google.adk.flows.llmflows.ResumabilityConfig;
2022
import com.google.adk.memory.BaseMemoryService;
2123
import com.google.adk.models.LlmCallsLimitExceededException;
2224
import com.google.adk.plugins.PluginManager;
2325
import com.google.adk.sessions.BaseSessionService;
2426
import com.google.adk.sessions.Session;
27+
import com.google.common.collect.ImmutableSet;
2528
import com.google.errorprone.annotations.CanIgnoreReturnValue;
2629
import com.google.errorprone.annotations.InlineMe;
2730
import com.google.genai.types.Content;
31+
import com.google.genai.types.FunctionCall;
2832
import java.util.Map;
2933
import java.util.Objects;
3034
import java.util.Optional;
@@ -45,6 +49,7 @@ public class InvocationContext {
4549
private final Session session;
4650
private final Optional<Content> userContent;
4751
private final RunConfig runConfig;
52+
private final ResumabilityConfig resumabilityConfig;
4853
private final InvocationCostManager invocationCostManager = new InvocationCostManager();
4954

5055
private Optional<String> branch;
@@ -64,6 +69,7 @@ private InvocationContext(Builder builder) {
6469
this.userContent = builder.userContent;
6570
this.runConfig = builder.runConfig;
6671
this.endInvocation = builder.endInvocation;
72+
this.resumabilityConfig = builder.resumabilityConfig;
6773
}
6874

6975
/**
@@ -207,6 +213,7 @@ public static InvocationContext copyOf(InvocationContext other) {
207213
.userContent(other.userContent)
208214
.runConfig(other.runConfig)
209215
.endInvocation(other.endInvocation)
216+
.resumabilityConfig(other.resumabilityConfig)
210217
.build();
211218
newContext.activeStreamingTools.putAll(other.activeStreamingTools);
212219
return newContext;
@@ -248,10 +255,8 @@ public String invocationId() {
248255
}
249256

250257
/**
251-
* Sets the branch ID for the current invocation. A branch represents a fork in the conversation
258+
* Sets the [branch] ID for the current invocation. A branch represents a fork in the conversation
252259
* history.
253-
*
254-
* @param branch the branch ID, or null to clear it
255260
*/
256261
public void branch(@Nullable String branch) {
257262
this.branch = Optional.ofNullable(branch);
@@ -270,11 +275,7 @@ public BaseAgent agent() {
270275
return agent;
271276
}
272277

273-
/**
274-
* Sets the agent being invoked. This is useful when delegating to a sub-agent.
275-
*
276-
* @param agent the agent to set
277-
*/
278+
/** Sets the [agent] being invoked. This is useful when delegating to a sub-agent. */
278279
public void agent(BaseAgent agent) {
279280
this.agent = agent;
280281
}
@@ -302,11 +303,7 @@ public boolean endInvocation() {
302303
return endInvocation;
303304
}
304305

305-
/**
306-
* Sets whether this invocation should be ended.
307-
*
308-
* @param endInvocation true if the invocation should end, false otherwise
309-
*/
306+
/** Sets whether this invocation should be ended. */
310307
public void setEndInvocation(boolean endInvocation) {
311308
this.endInvocation = endInvocation;
312309
}
@@ -336,6 +333,28 @@ public void incrementLlmCallsCount() throws LlmCallsLimitExceededException {
336333
this.invocationCostManager.incrementAndEnforceLlmCallsLimit(this.runConfig);
337334
}
338335

336+
/** Returns whether the current invocation is resumable. */
337+
public boolean isResumable() {
338+
return resumabilityConfig.isResumable();
339+
}
340+
341+
/** Returns whether to pause the invocation right after this [event]. */
342+
public boolean shouldPauseInvocation(Event event) {
343+
if (!isResumable()) {
344+
return false;
345+
}
346+
347+
var longRunningToolIds = event.longRunningToolIds().orElse(ImmutableSet.of());
348+
if (longRunningToolIds.isEmpty()) {
349+
return false;
350+
}
351+
352+
return event.functionCalls().stream()
353+
.map(FunctionCall::id)
354+
.flatMap(Optional::stream)
355+
.anyMatch(functionCallId -> longRunningToolIds.contains(functionCallId));
356+
}
357+
339358
private static class InvocationCostManager {
340359
private int numberOfLlmCalls = 0;
341360

@@ -366,6 +385,7 @@ public static class Builder {
366385
private Optional<Content> userContent = Optional.empty();
367386
private RunConfig runConfig = RunConfig.builder().build();
368387
private boolean endInvocation = false;
388+
private ResumabilityConfig resumabilityConfig = new ResumabilityConfig();
369389

370390
/**
371391
* Sets the session service for managing session state.
@@ -553,6 +573,18 @@ public Builder endInvocation(boolean endInvocation) {
553573
return this;
554574
}
555575

576+
/**
577+
* Sets the resumability configuration for the current agent run.
578+
*
579+
* @param resumabilityConfig the resumability configuration.
580+
* @return this builder instance for chaining.
581+
*/
582+
@CanIgnoreReturnValue
583+
public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) {
584+
this.resumabilityConfig = resumabilityConfig;
585+
return this;
586+
}
587+
556588
/**
557589
* Builds the {@link InvocationContext} instance.
558590
*
@@ -584,7 +616,8 @@ public boolean equals(Object o) {
584616
&& Objects.equals(agent, that.agent)
585617
&& Objects.equals(session, that.session)
586618
&& Objects.equals(userContent, that.userContent)
587-
&& Objects.equals(runConfig, that.runConfig);
619+
&& Objects.equals(runConfig, that.runConfig)
620+
&& Objects.equals(resumabilityConfig, that.resumabilityConfig);
588621
}
589622

590623
@Override
@@ -602,6 +635,7 @@ public int hashCode() {
602635
session,
603636
userContent,
604637
runConfig,
605-
endInvocation);
638+
endInvocation,
639+
resumabilityConfig);
606640
}
607641
}

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,16 @@ private void maybeSaveOutputToState(Event event) {
604604

605605
@Override
606606
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
607-
return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState);
607+
return llmFlow
608+
.run(invocationContext)
609+
.concatMap(
610+
event -> {
611+
this.maybeSaveOutputToState(event);
612+
if (invocationContext.shouldPauseInvocation(event)) {
613+
return Flowable.just(event).concatWith(Flowable.empty());
614+
}
615+
return Flowable.just(event);
616+
});
608617
}
609618

610619
@Override

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -156,36 +156,9 @@ protected Flowable<Event> postprocess(
156156
}
157157

158158
return currentLlmResponse.flatMapPublisher(
159-
updatedResponse -> {
160-
Flowable<Event> processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables));
161-
162-
if (updatedResponse.content().isEmpty()
163-
&& updatedResponse.errorCode().isEmpty()
164-
&& !updatedResponse.interrupted().orElse(false)
165-
&& !updatedResponse.turnComplete().orElse(false)) {
166-
return processorEvents;
167-
}
168-
169-
Event modelResponseEvent =
170-
buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse);
171-
172-
Flowable<Event> modelEventStream = Flowable.just(modelResponseEvent);
173-
174-
if (modelResponseEvent.functionCalls().isEmpty()) {
175-
return processorEvents.concatWith(modelEventStream);
176-
}
177-
178-
Maybe<Event> maybeFunctionCallEvent;
179-
if (context.runConfig().streamingMode() == StreamingMode.BIDI) {
180-
maybeFunctionCallEvent =
181-
Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools());
182-
} else {
183-
maybeFunctionCallEvent =
184-
Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools());
185-
}
186-
187-
return processorEvents.concatWith(modelEventStream).concatWith(maybeFunctionCallEvent);
188-
});
159+
updatedResponse ->
160+
buildPostprocessingEvents(
161+
updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest));
189162
}
190163

191164
/**
@@ -623,6 +596,45 @@ public void onError(Throwable e) {
623596
*
624597
* @return A fully constructed {@link Event} representing the LLM response.
625598
*/
599+
private Flowable<Event> buildPostprocessingEvents(
600+
LlmResponse updatedResponse,
601+
List<Iterable<Event>> eventIterables,
602+
InvocationContext context,
603+
Event baseEventForLlmResponse,
604+
LlmRequest llmRequest) {
605+
Flowable<Event> processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables));
606+
if (updatedResponse.content().isEmpty()
607+
&& updatedResponse.errorCode().isEmpty()
608+
&& !updatedResponse.interrupted().orElse(false)
609+
&& !updatedResponse.turnComplete().orElse(false)) {
610+
return processorEvents;
611+
}
612+
613+
Event modelResponseEvent =
614+
buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse);
615+
if (modelResponseEvent.functionCalls().isEmpty()) {
616+
return processorEvents.concatWith(Flowable.just(modelResponseEvent));
617+
}
618+
619+
Maybe<Event> maybeFunctionResponseEvent =
620+
context.runConfig().streamingMode() == StreamingMode.BIDI
621+
? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools())
622+
: Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools());
623+
624+
Flowable<Event> functionEvents =
625+
maybeFunctionResponseEvent.flatMapPublisher(
626+
functionResponseEvent -> {
627+
Optional<Event> toolConfirmationEvent =
628+
Functions.generateRequestConfirmationEvent(
629+
context, modelResponseEvent, functionResponseEvent);
630+
return toolConfirmationEvent.isPresent()
631+
? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent)
632+
: Flowable.just(functionResponseEvent);
633+
});
634+
635+
return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents);
636+
}
637+
626638
private Event buildModelResponseEvent(
627639
Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) {
628640
Event.Builder eventBuilder =
@@ -641,10 +653,13 @@ private Event buildModelResponseEvent(
641653

642654
Event event = eventBuilder.build();
643655

656+
logger.info("event: {} functionCalls: {}", event, event.functionCalls());
657+
644658
if (!event.functionCalls().isEmpty()) {
645659
Functions.populateClientFunctionCallId(event);
646660
Set<String> longRunningToolIds =
647661
Functions.getLongRunningFunctionCalls(event.functionCalls(), llmRequest.tools());
662+
logger.info("longRunningToolIds: {}", longRunningToolIds);
648663
if (!longRunningToolIds.isEmpty()) {
649664
event.setLongRunningToolIds(Optional.of(longRunningToolIds));
650665
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.adk.flows.llmflows;
17+
18+
/**
19+
* An app contains Resumability configuration for the agents.
20+
*
21+
* @param isResumable Whether the app is resumable.
22+
*/
23+
public record ResumabilityConfig(boolean isResumable) {
24+
25+
/** Creates a new {@code ResumabilityConfig} with resumability disabled. */
26+
public ResumabilityConfig() {
27+
this(false);
28+
}
29+
}

0 commit comments

Comments
 (0)