From 184556e7f6f089e2f8864dcecefc9971a3cbe96a Mon Sep 17 00:00:00 2001 From: Alex Plate Date: Wed, 20 May 2026 17:56:36 +0300 Subject: [PATCH] Make StdioTransport accept Flow + suspend writer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Promote the Flow-based pair to the primary constructor and keep the existing Source/Sink one as a back-compat secondary, marked @Deprecated: the blocking variant pins a dispatcher thread per readLine and would saturate the I/O pool under high agent concurrency. Fix close() to cancel childScope. Previously close() only closed the channels and invoked closeHandler; the Source/Sink path worked because closing the underlying streams unblocked readLine, but the Flow path has no equivalent unblock — close() would leave the read job parked inside input.collect and the transport stuck in CLOSING. Cancelling childScope cooperatively unwinds the collect, the read job exits, joinAll returns, and the finally block sets state to CLOSED. The Source/Sink path is unaffected (the cancel is a no-op once the job has already exited). Add StdioTransportFlowTest covering the new primary path: round-trip send/receive, JSON-skip behaviour, input-flow completion, input-flow errors, the output exception contract (IOException = clean, anything else = onError), parent-scope cancellation, and the close-deadlock regression itself. Co-Authored-By: Claude Opus 4.7 (1M context) --- acp/api/acp.api | 2 + .../transport/StdioTransport.kt | 143 +++++++++--- .../transport/StdioTransportFlowTest.kt | 217 ++++++++++++++++++ 3 files changed, 331 insertions(+), 31 deletions(-) create mode 100644 acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportFlowTest.kt diff --git a/acp/api/acp.api b/acp/api/acp.api index 3e85ba2..98ce982 100644 --- a/acp/api/acp.api +++ b/acp/api/acp.api @@ -501,6 +501,8 @@ public abstract class com/agentclientprotocol/transport/BaseTransport : com/agen } public final class com/agentclientprotocol/transport/StdioTransport : com/agentclientprotocol/transport/BaseTransport { + public fun (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Ljava/lang/String;)V + public synthetic fun (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/io/Source;Lkotlinx/io/Sink;Ljava/lang/String;)V public synthetic fun (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/io/Source;Lkotlinx/io/Sink;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close ()V diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt index 5ed59df..a14cf59 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt @@ -4,12 +4,12 @@ import com.agentclientprotocol.rpc.ACPJson import com.agentclientprotocol.rpc.JsonRpcMessage import com.agentclientprotocol.rpc.decodeJsonRpcMessage import com.agentclientprotocol.transport.Transport.State -import com.agentclientprotocol.util.checkCancelled import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.getAndUpdate -import kotlinx.coroutines.flow.update import kotlinx.io.* import kotlinx.serialization.encodeToString @@ -21,47 +21,94 @@ private val logger = KotlinLogging.logger {} * This transport communicates over standard input/output streams, * which is commonly used for command-line agents. */ -public class StdioTransport( +public class StdioTransport private constructor( private val parentScope: CoroutineScope, private val ioDispatcher: CoroutineDispatcher, - private val input: Source, - private val output: Sink, - private val name: String = StdioTransport::class.simpleName!!, + private val input: Flow, + private val output: suspend (String) -> Unit, + private val name: String, + private val closeHandler: () -> Unit, ) : BaseTransport() { + + /** + * Primary [Flow]-based constructor. + * + * @param parentScope coroutine scope for the transport's lifecycle + * @param ioDispatcher dispatcher used for the read and write coroutines + * @param input cold flow of incoming NDJSON lines. Cancellation of the + * transport cancels collection; use [Flow.onCompletion] to react to + * that cancellation if you need to release upstream resources. + * @param output suspending writer invoked once per outgoing line. The + * implementation owns framing (newline) and flushing semantics. To + * signal that the underlying transport has closed cleanly, throw + * [IllegalStateException] or [IOException]; the write loop will exit + * without firing an error. Any other exception is reported via + * [Transport.onError]. + * @param name optional name used in coroutine names and log messages + */ + public constructor( + parentScope: CoroutineScope, + ioDispatcher: CoroutineDispatcher, + input: Flow, + output: suspend (String) -> Unit, + name: String = StdioTransport::class.simpleName!!, + ) : this(parentScope, ioDispatcher, input, output, name, closeHandler = {}) + + /** + * Back-compat constructor backed by blocking [Source] / [Sink]. + * + * Deprecated because this variant blocks the dispatcher thread for the + * duration of each [Source.readLine]. Under high agent concurrency this can + * saturate the I/O dispatcher (e.g. `Dispatchers.IO`) and cascade into + * freezes if other consumers schedule blocking work on the same pool. + * + * When this constructor is removed, the [Flow]-based constructor should be + * promoted to the primary one and [closeHandler] dropped from its parameter + * list (cancelling [childScope] in [close] is already enough to unwind the + * Flow-based read path). + * + * Callers should adapt their blocking [Source] / [Sink] into a + * [Flow]`` and a `suspend (String) -> Unit` at the call site and + * use the [Flow]-based constructor instead. + */ + @Deprecated( + message = "Blocking Source/Sink pins a dispatcher thread per read and forces an extra closeHandler. " + + "Adapt your I/O into Flow + suspend (String) -> Unit and use the Flow-based constructor.", + level = DeprecationLevel.WARNING, + ) + public constructor( + parentScope: CoroutineScope, + ioDispatcher: CoroutineDispatcher, + input: Source, + output: Sink, + name: String = StdioTransport::class.simpleName!!, + ) : this( + parentScope, ioDispatcher, + sourceAsLineFlow(input), sinkAsLineWriter(output), + name, + closeHandler = makeSourceSinkCloseHandler(input, output), + ) + private val childScope = CoroutineScope(parentScope.coroutineContext + SupervisorJob(parentScope.coroutineContext[Job]) + CoroutineName(name)) private val receiveChannel = Channel(Channel.UNLIMITED) private val sendChannel = Channel(Channel.UNLIMITED) - + override fun start() { if (_state.getAndUpdate { State.STARTING } != State.CREATED) error("Transport is not in ${State.CREATED.name} state") // Start reading messages from input childScope.launch(CoroutineName("$name.join-jobs")) { val readJob = launch(ioDispatcher + CoroutineName("$name.read-from-input")) { try { - while (currentCoroutineContext().isActive) { + // ACP assumes working with ND Json (new line delimited Json) when working over stdio + input.collect { line -> currentCoroutineContext().ensureActive() - // ACP assumes working with ND Json (new line delimited Json) when working over stdio - val line = try { - input.readLine() - } catch (e: IllegalStateException) { - logger.trace(e) { "Input stream closed" } - break - } catch (e: IOException) { - logger.trace(e) { "Input stream likely closed" } - break - } - if (line == null) { - // End of stream - logger.trace { "End of stream" } - break - } val jsonRpcMessage = try { decodeJsonRpcMessage(line) } catch (t: Throwable) { logger.trace(t) { "Failed to decode JSON message: $line" } - continue + return@collect } logger.trace { "Sending message to channel: $jsonRpcMessage" } fireMessage(jsonRpcMessage) @@ -84,9 +131,7 @@ public class StdioTransport( for (message in sendChannel) { val encoded = ACPJson.encodeToString(message) try { - output.writeString(encoded) - output.writeString("\n") - output.flush() + output(encoded) } catch (e: IllegalStateException) { logger.trace(e) { "Output stream closed" } break @@ -129,7 +174,7 @@ public class StdioTransport( } } } - + override fun send(message: JsonRpcMessage) { logger.trace { "Sending message: $message" } val channelResult = sendChannel.trySend(message) @@ -150,7 +195,43 @@ public class StdioTransport( if (sendChannel.close()) logger.trace { "Send channel closed" } if (receiveChannel.close()) logger.trace { "Receive channel closed" } - runCatching { input.close() }.onFailure { logger.warn(it) { "Exception when closing input stream" } } - runCatching { output.close() }.onFailure { logger.warn(it) { "Exception when closing output stream" } } + runCatching { closeHandler() }.onFailure { logger.warn(it) { "Exception in close handler" } } + + // Unwind the read/write coroutines. The Source/Sink back-compat path relies + // on [closeHandler] closing the underlying streams to unblock readLine, but + // the Flow-based path has no equivalent unblock — without cancelling here + // the read job would stay parked inside input.collect and the transport + // would be stuck in CLOSING. + childScope.cancel() } -} \ No newline at end of file +} + +private fun sourceAsLineFlow(source: Source): Flow = flow { + while (true) { + val line = try { + source.readLine() + } catch (e: IllegalStateException) { + logger.trace(e) { "Input stream closed" } + break + } catch (e: IOException) { + logger.trace(e) { "Input stream likely closed" } + break + } + if (line == null) { + logger.trace { "End of stream" } + break + } + emit(line) + } +} + +private fun sinkAsLineWriter(sink: Sink): suspend (String) -> Unit = { line -> + sink.writeString(line) + sink.writeString("\n") + sink.flush() +} + +private fun makeSourceSinkCloseHandler(source: Source, sink: Sink): () -> Unit = { + runCatching { source.close() }.onFailure { logger.warn(it) { "Exception when closing input stream" } } + runCatching { sink.close() }.onFailure { logger.warn(it) { "Exception when closing output stream" } } +} diff --git a/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportFlowTest.kt b/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportFlowTest.kt new file mode 100644 index 0000000..2cd923e --- /dev/null +++ b/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportFlowTest.kt @@ -0,0 +1,217 @@ +package com.agentclientprotocol.transport + +import com.agentclientprotocol.rpc.JsonRpcNotification +import com.agentclientprotocol.rpc.JsonRpcRequest +import com.agentclientprotocol.rpc.MethodName +import com.agentclientprotocol.rpc.RequestId +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.io.IOException +import kotlin.test.AfterTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.test.fail +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds + +/** + * Tests for the primary [StdioTransport] [Flow]-based constructor. + * + * The deprecated Source/Sink constructor has its own coverage in + * [StdioTransportTest]; tests here exercise paths that only matter for the Flow + * primary path (input flow termination/error, custom writer exception contract, + * close-without-stream-close). + */ +class StdioTransportFlowTest { + private val scope = CoroutineScope(SupervisorJob()) + + @AfterTest + fun tearDown() { + scope.cancel() + } + + @Test + fun `close on never-completing input reaches CLOSED`(): Unit = runBlocking { + val transport = makeTransport( + input = Channel(Channel.UNLIMITED).receiveAsFlow(), + ) + transport.start() + transport.expectState(Transport.State.STARTED) + + transport.close() + transport.expectState( + Transport.State.CLOSED, + message = "close() did not drive state to CLOSED for a never-completing input Flow", + ) + } + + @Test + fun `send writes encoded line to output`(): Unit = runBlocking { + val written = Channel(Channel.UNLIMITED) + val transport = makeTransport( + output = { line -> written.send(line) }, + ) + transport.start() + transport.expectState(Transport.State.STARTED) + + transport.send(JsonRpcRequest(RequestId.create(7), MethodName("ping"))) + + val line = withTimeout(1.seconds) { written.receive() } + assertTrue(line.contains("\"method\":\"ping\""), "encoded line should carry the method, was: $line") + assertTrue(line.contains("\"id\":7"), "encoded line should carry the id, was: $line") + } + + @Test + fun `input emission fires onMessage`(): Unit = runBlocking { + val inputChannel = Channel(Channel.UNLIMITED) + val transport = makeTransport(input = inputChannel.receiveAsFlow()) + val received = transport.asMessageChannel() + transport.start() + transport.expectState(Transport.State.STARTED) + + inputChannel.send("""{"jsonrpc":"2.0","method":"hello"}""") + + val message = withTimeout(1.seconds) { received.receive() } + assertTrue(message is JsonRpcNotification) + assertEquals(MethodName("hello"), message.method) + } + + @Test + fun `invalid JSON lines are skipped and valid ones still processed`(): Unit = runBlocking { + val inputChannel = Channel(Channel.UNLIMITED) + val transport = makeTransport(input = inputChannel.receiveAsFlow()) + val received = transport.asMessageChannel() + transport.start() + transport.expectState(Transport.State.STARTED) + + inputChannel.send("not json at all") + inputChannel.send("") + inputChannel.send("""{"jsonrpc":"2.0","method":"after-garbage","id":1}""") + + val message = withTimeout(1.seconds) { received.receive() } + assertTrue(message is JsonRpcRequest) + assertEquals(MethodName("after-garbage"), message.method) + } + + @Test + fun `input flow completing drives state to CLOSED`(): Unit = runBlocking { + // Flow that emits one valid message then completes naturally. + val transport = makeTransport( + input = flowOf("""{"jsonrpc":"2.0","method":"once"}"""), + ) + val received = transport.asMessageChannel() + transport.start() + + val message = withTimeout(1.seconds) { received.receive() } + assertNotNull(message) + + transport.expectState(Transport.State.CLOSED, message = "completion of input flow should close transport") + } + + @Test + fun `input flow throwing drives state to CLOSED and fires onError`(): Unit = runBlocking { + val sentinel = IllegalStateException("upstream blew up") + val errors = mutableListOf() + val transport = makeTransport( + input = flow { throw sentinel }, + ).apply { onError { errors.add(it) } } + transport.start() + + transport.expectState(Transport.State.CLOSED, message = "input flow error should close transport") + assertTrue(errors.any { it === sentinel }, "expected upstream error to be reported via onError, got: $errors") + } + + @Test + fun `output IOException closes write loop without firing onError`(): Unit = runBlocking { + val errors = mutableListOf() + val transport = makeTransport( + output = { throw IOException("peer gone") }, + ).apply { onError { errors.add(it) } } + transport.start() + transport.expectState(Transport.State.STARTED) + + transport.send(JsonRpcNotification(method = MethodName("ignored"))) + + transport.expectState(Transport.State.CLOSED, message = "IOException from output should close the transport") + assertTrue(errors.isEmpty(), "IOException should be treated as clean shutdown, got: $errors") + } + + @Test + fun `output unexpected exception fires onError and closes`(): Unit = runBlocking { + val sentinel = RuntimeException("writer bug") + val errors = mutableListOf() + val transport = makeTransport( + output = { throw sentinel }, + ).apply { onError { errors.add(it) } } + transport.start() + transport.expectState(Transport.State.STARTED) + + transport.send(JsonRpcNotification(method = MethodName("ignored"))) + + transport.expectState(Transport.State.CLOSED, message = "unexpected output error should close the transport") + assertTrue(errors.any { it === sentinel }, "expected writer error to be reported via onError, got: $errors") + } + + @Test + fun `parent scope cancellation drives state to CLOSED`(): Unit = runBlocking { + val transport = makeTransport( + input = Channel(Channel.UNLIMITED).receiveAsFlow(), + ) + transport.start() + transport.expectState(Transport.State.STARTED) + + scope.cancel() + transport.expectState(Transport.State.CLOSED, message = "parent scope cancellation should close transport") + } + + @Test + fun `concurrent sends are all delivered to output`(): Unit = runBlocking { + val written = Channel(Channel.UNLIMITED) + val transport = makeTransport( + output = { line -> written.send(line) }, + ) + transport.start() + transport.expectState(Transport.State.STARTED) + + val sendJobs = (1..10).map { i -> + scope.launch { transport.send(JsonRpcNotification(method = MethodName("method$i"))) } + } + sendJobs.joinAll() + + val received = withTimeout(1.seconds) { List(10) { written.receive() } } + (1..10).forEach { i -> + assertTrue(received.any { it.contains("\"method\":\"method$i\"") }, "missing method$i in $received") + } + } + + private fun makeTransport( + input: Flow = Channel(Channel.UNLIMITED).receiveAsFlow(), + output: suspend (String) -> Unit = { /* discard */ }, + ): StdioTransport = StdioTransport( + parentScope = scope, + ioDispatcher = Dispatchers.IO, + input = input, + output = output, + ) + + private suspend fun Transport.expectState(state: Transport.State, timeout: Duration = 1.seconds, message: String? = null) { + val observed = mutableListOf() + try { + withTimeout(timeout) { + this@expectState.state + .onEach { observed.add(it) } + .first { it == state } + } + } catch (_: TimeoutCancellationException) { + fail("Timed out waiting for state $state after $timeout, observed: ${observed.joinToString { it.name }}${message?.let { " — $it" } ?: ""}") + } + } +}