Skip to content

Commit d766ffa

Browse files
committed
support parallel agent
1 parent b3ca86e commit d766ffa

File tree

2 files changed

+101
-4
lines changed

2 files changed

+101
-4
lines changed

core/src/main/java/com/google/adk/agents/ParallelAgent.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
2121
import com.google.adk.events.Event;
2222
import io.reactivex.rxjava3.core.Flowable;
23+
import io.reactivex.rxjava3.core.Scheduler;
24+
import io.reactivex.rxjava3.schedulers.Schedulers;
2325
import java.util.ArrayList;
2426
import java.util.List;
2527
import org.slf4j.Logger;
@@ -35,6 +37,7 @@
3537
public class ParallelAgent extends BaseAgent {
3638

3739
private static final Logger logger = LoggerFactory.getLogger(ParallelAgent.class);
40+
private final Scheduler scheduler;
3841

3942
/**
4043
* Constructor for ParallelAgent.
@@ -44,24 +47,34 @@ public class ParallelAgent extends BaseAgent {
4447
* @param subAgents The list of sub-agents to run in parallel.
4548
* @param beforeAgentCallback Optional callback before the agent runs.
4649
* @param afterAgentCallback Optional callback after the agent runs.
50+
* @param scheduler The scheduler to use for parallel execution.
4751
*/
4852
private ParallelAgent(
4953
String name,
5054
String description,
5155
List<? extends BaseAgent> subAgents,
5256
List<Callbacks.BeforeAgentCallback> beforeAgentCallback,
53-
List<Callbacks.AfterAgentCallback> afterAgentCallback) {
57+
List<Callbacks.AfterAgentCallback> afterAgentCallback,
58+
Scheduler scheduler) {
5459

5560
super(name, description, subAgents, beforeAgentCallback, afterAgentCallback);
61+
this.scheduler = scheduler;
5662
}
5763

5864
/** Builder for {@link ParallelAgent}. */
5965
public static class Builder extends BaseAgent.Builder<Builder> {
6066

67+
private Scheduler scheduler = Schedulers.io();
68+
69+
public Builder scheduler(Scheduler scheduler) {
70+
this.scheduler = scheduler;
71+
return this;
72+
}
73+
6174
@Override
6275
public ParallelAgent build() {
6376
return new ParallelAgent(
64-
name, description, subAgents, beforeAgentCallback, afterAgentCallback);
77+
name, description, subAgents, beforeAgentCallback, afterAgentCallback, scheduler);
6578
}
6679
}
6780

@@ -131,7 +144,7 @@ protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
131144

132145
List<Flowable<Event>> agentFlowables = new ArrayList<>();
133146
for (BaseAgent subAgent : currentSubAgents) {
134-
agentFlowables.add(subAgent.runAsync(invocationContext));
147+
agentFlowables.add(subAgent.runAsync(invocationContext).subscribeOn(scheduler));
135148
}
136149
return Flowable.merge(agentFlowables);
137150
}

core/src/test/java/com/google/adk/agents/ParallelAgentTest.java

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
import com.google.genai.types.Content;
2626
import com.google.genai.types.Part;
2727
import io.reactivex.rxjava3.core.Flowable;
28+
import io.reactivex.rxjava3.core.Scheduler;
2829
import io.reactivex.rxjava3.schedulers.Schedulers;
30+
import io.reactivex.rxjava3.schedulers.TestScheduler;
31+
import io.reactivex.rxjava3.subscribers.TestSubscriber;
2932
import java.util.List;
3033
import org.junit.Test;
3134
import org.junit.runner.RunWith;
@@ -36,10 +39,16 @@ public final class ParallelAgentTest {
3639

3740
static class TestingAgent extends BaseAgent {
3841
private final long delayMillis;
42+
private final Scheduler scheduler;
3943

4044
private TestingAgent(String name, String description, long delayMillis) {
45+
this(name, description, delayMillis, Schedulers.computation());
46+
}
47+
48+
private TestingAgent(String name, String description, long delayMillis, Scheduler scheduler) {
4149
super(name, description, ImmutableList.of(), null, null);
4250
this.delayMillis = delayMillis;
51+
this.scheduler = scheduler;
4352
}
4453

4554
@Override
@@ -55,7 +64,7 @@ protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
5564
.build());
5665

5766
if (delayMillis > 0) {
58-
return event.delay(delayMillis, MILLISECONDS, Schedulers.computation());
67+
return event.delay(delayMillis, MILLISECONDS, scheduler);
5968
}
6069
return event;
6170
}
@@ -110,4 +119,79 @@ public void runAsync_noSubAgents_returnsEmptyFlowable() {
110119

111120
assertThat(events).isEmpty();
112121
}
122+
123+
static class BlockingAgent extends BaseAgent {
124+
private final long sleepMillis;
125+
126+
private BlockingAgent(String name, long sleepMillis) {
127+
super(name, "Blocking Agent", ImmutableList.of(), null, null);
128+
this.sleepMillis = sleepMillis;
129+
}
130+
131+
@Override
132+
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
133+
return Flowable.fromCallable(
134+
() -> {
135+
Thread.sleep(sleepMillis);
136+
return Event.builder()
137+
.author(name())
138+
.branch(invocationContext.branch().orElse(null))
139+
.invocationId(invocationContext.invocationId())
140+
.content(Content.fromParts(Part.fromText("Done")))
141+
.build();
142+
});
143+
}
144+
145+
@Override
146+
protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
147+
throw new UnsupportedOperationException("Not implemented");
148+
}
149+
}
150+
151+
@Test
152+
public void runAsync_blockingSubAgents_shouldExecuteInParallel() {
153+
long sleepTime = 1000;
154+
BlockingAgent agent1 = new BlockingAgent("agent1", sleepTime);
155+
BlockingAgent agent2 = new BlockingAgent("agent2", sleepTime);
156+
157+
ParallelAgent parallelAgent =
158+
ParallelAgent.builder().name("parallel_agent").subAgents(agent1, agent2).build();
159+
160+
InvocationContext invocationContext = createInvocationContext(parallelAgent);
161+
162+
long startTime = System.currentTimeMillis();
163+
List<Event> events = parallelAgent.runAsync(invocationContext).toList().blockingGet();
164+
long duration = System.currentTimeMillis() - startTime;
165+
166+
assertThat(events).hasSize(2);
167+
// If parallel, duration should be less than 1.5 * sleepTime (1500ms).
168+
assertThat(duration).isAtLeast(sleepTime);
169+
assertThat(duration).isLessThan((long) (1.5 * sleepTime));
170+
}
171+
172+
@Test
173+
public void runAsync_withTestScheduler_usesVirtualTime() {
174+
TestScheduler testScheduler = new TestScheduler();
175+
long delayMillis = 1000;
176+
TestingAgent agent =
177+
new TestingAgent("delayed_agent", "Delayed Agent", delayMillis, testScheduler);
178+
179+
ParallelAgent parallelAgent =
180+
ParallelAgent.builder()
181+
.name("parallel_agent")
182+
.subAgents(agent)
183+
.scheduler(testScheduler)
184+
.build();
185+
186+
InvocationContext invocationContext = createInvocationContext(parallelAgent);
187+
188+
TestSubscriber<Event> testSubscriber = parallelAgent.runAsync(invocationContext).test();
189+
190+
testScheduler.advanceTimeBy(delayMillis - 100, MILLISECONDS);
191+
testSubscriber.assertNoValues();
192+
testSubscriber.assertNotComplete();
193+
testScheduler.advanceTimeBy(200, MILLISECONDS);
194+
testSubscriber.assertValueCount(1);
195+
testSubscriber.assertComplete();
196+
}
113197
}

0 commit comments

Comments
 (0)