diff --git a/acp/api/acp.api b/acp/api/acp.api index 3e85ba2..98ce982 100644 --- a/acp/api/acp.api +++ b/acp/api/acp.api @@ -501,6 +501,8 @@ public abstract class com/agentclientprotocol/transport/BaseTransport : com/agen } public final class com/agentclientprotocol/transport/StdioTransport : com/agentclientprotocol/transport/BaseTransport { + public fun (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Ljava/lang/String;)V + public synthetic fun (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/io/Source;Lkotlinx/io/Sink;Ljava/lang/String;)V public synthetic fun (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/io/Source;Lkotlinx/io/Sink;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close ()V diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt index 5ed59df..a14cf59 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/transport/StdioTransport.kt @@ -4,12 +4,12 @@ import com.agentclientprotocol.rpc.ACPJson import com.agentclientprotocol.rpc.JsonRpcMessage import com.agentclientprotocol.rpc.decodeJsonRpcMessage import com.agentclientprotocol.transport.Transport.State -import com.agentclientprotocol.util.checkCancelled import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.getAndUpdate -import kotlinx.coroutines.flow.update import kotlinx.io.* import kotlinx.serialization.encodeToString @@ -21,47 +21,94 @@ private val logger = KotlinLogging.logger {} * This transport communicates over standard input/output streams, * which is commonly used for command-line agents. */ -public class StdioTransport( +public class StdioTransport private constructor( private val parentScope: CoroutineScope, private val ioDispatcher: CoroutineDispatcher, - private val input: Source, - private val output: Sink, - private val name: String = StdioTransport::class.simpleName!!, + private val input: Flow, + private val output: suspend (String) -> Unit, + private val name: String, + private val closeHandler: () -> Unit, ) : BaseTransport() { + + /** + * Primary [Flow]-based constructor. + * + * @param parentScope coroutine scope for the transport's lifecycle + * @param ioDispatcher dispatcher used for the read and write coroutines + * @param input cold flow of incoming NDJSON lines. Cancellation of the + * transport cancels collection; use [Flow.onCompletion] to react to + * that cancellation if you need to release upstream resources. + * @param output suspending writer invoked once per outgoing line. The + * implementation owns framing (newline) and flushing semantics. To + * signal that the underlying transport has closed cleanly, throw + * [IllegalStateException] or [IOException]; the write loop will exit + * without firing an error. Any other exception is reported via + * [Transport.onError]. + * @param name optional name used in coroutine names and log messages + */ + public constructor( + parentScope: CoroutineScope, + ioDispatcher: CoroutineDispatcher, + input: Flow, + output: suspend (String) -> Unit, + name: String = StdioTransport::class.simpleName!!, + ) : this(parentScope, ioDispatcher, input, output, name, closeHandler = {}) + + /** + * Back-compat constructor backed by blocking [Source] / [Sink]. + * + * Deprecated because this variant blocks the dispatcher thread for the + * duration of each [Source.readLine]. Under high agent concurrency this can + * saturate the I/O dispatcher (e.g. `Dispatchers.IO`) and cascade into + * freezes if other consumers schedule blocking work on the same pool. + * + * When this constructor is removed, the [Flow]-based constructor should be + * promoted to the primary one and [closeHandler] dropped from its parameter + * list (cancelling [childScope] in [close] is already enough to unwind the + * Flow-based read path). + * + * Callers should adapt their blocking [Source] / [Sink] into a + * [Flow]`` and a `suspend (String) -> Unit` at the call site and + * use the [Flow]-based constructor instead. + */ + @Deprecated( + message = "Blocking Source/Sink pins a dispatcher thread per read and forces an extra closeHandler. " + + "Adapt your I/O into Flow + suspend (String) -> Unit and use the Flow-based constructor.", + level = DeprecationLevel.WARNING, + ) + public constructor( + parentScope: CoroutineScope, + ioDispatcher: CoroutineDispatcher, + input: Source, + output: Sink, + name: String = StdioTransport::class.simpleName!!, + ) : this( + parentScope, ioDispatcher, + sourceAsLineFlow(input), sinkAsLineWriter(output), + name, + closeHandler = makeSourceSinkCloseHandler(input, output), + ) + private val childScope = CoroutineScope(parentScope.coroutineContext + SupervisorJob(parentScope.coroutineContext[Job]) + CoroutineName(name)) private val receiveChannel = Channel(Channel.UNLIMITED) private val sendChannel = Channel(Channel.UNLIMITED) - + override fun start() { if (_state.getAndUpdate { State.STARTING } != State.CREATED) error("Transport is not in ${State.CREATED.name} state") // Start reading messages from input childScope.launch(CoroutineName("$name.join-jobs")) { val readJob = launch(ioDispatcher + CoroutineName("$name.read-from-input")) { try { - while (currentCoroutineContext().isActive) { + // ACP assumes working with ND Json (new line delimited Json) when working over stdio + input.collect { line -> currentCoroutineContext().ensureActive() - // ACP assumes working with ND Json (new line delimited Json) when working over stdio - val line = try { - input.readLine() - } catch (e: IllegalStateException) { - logger.trace(e) { "Input stream closed" } - break - } catch (e: IOException) { - logger.trace(e) { "Input stream likely closed" } - break - } - if (line == null) { - // End of stream - logger.trace { "End of stream" } - break - } val jsonRpcMessage = try { decodeJsonRpcMessage(line) } catch (t: Throwable) { logger.trace(t) { "Failed to decode JSON message: $line" } - continue + return@collect } logger.trace { "Sending message to channel: $jsonRpcMessage" } fireMessage(jsonRpcMessage) @@ -84,9 +131,7 @@ public class StdioTransport( for (message in sendChannel) { val encoded = ACPJson.encodeToString(message) try { - output.writeString(encoded) - output.writeString("\n") - output.flush() + output(encoded) } catch (e: IllegalStateException) { logger.trace(e) { "Output stream closed" } break @@ -129,7 +174,7 @@ public class StdioTransport( } } } - + override fun send(message: JsonRpcMessage) { logger.trace { "Sending message: $message" } val channelResult = sendChannel.trySend(message) @@ -150,7 +195,43 @@ public class StdioTransport( if (sendChannel.close()) logger.trace { "Send channel closed" } if (receiveChannel.close()) logger.trace { "Receive channel closed" } - runCatching { input.close() }.onFailure { logger.warn(it) { "Exception when closing input stream" } } - runCatching { output.close() }.onFailure { logger.warn(it) { "Exception when closing output stream" } } + runCatching { closeHandler() }.onFailure { logger.warn(it) { "Exception in close handler" } } + + // Unwind the read/write coroutines. The Source/Sink back-compat path relies + // on [closeHandler] closing the underlying streams to unblock readLine, but + // the Flow-based path has no equivalent unblock — without cancelling here + // the read job would stay parked inside input.collect and the transport + // would be stuck in CLOSING. + childScope.cancel() } -} \ No newline at end of file +} + +private fun sourceAsLineFlow(source: Source): Flow = flow { + while (true) { + val line = try { + source.readLine() + } catch (e: IllegalStateException) { + logger.trace(e) { "Input stream closed" } + break + } catch (e: IOException) { + logger.trace(e) { "Input stream likely closed" } + break + } + if (line == null) { + logger.trace { "End of stream" } + break + } + emit(line) + } +} + +private fun sinkAsLineWriter(sink: Sink): suspend (String) -> Unit = { line -> + sink.writeString(line) + sink.writeString("\n") + sink.flush() +} + +private fun makeSourceSinkCloseHandler(source: Source, sink: Sink): () -> Unit = { + runCatching { source.close() }.onFailure { logger.warn(it) { "Exception when closing input stream" } } + runCatching { sink.close() }.onFailure { logger.warn(it) { "Exception when closing output stream" } } +} diff --git a/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportFlowTest.kt b/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportFlowTest.kt new file mode 100644 index 0000000..2cd923e --- /dev/null +++ b/acp/src/jvmTest/kotlin/com/agentclientprotocol/transport/StdioTransportFlowTest.kt @@ -0,0 +1,217 @@ +package com.agentclientprotocol.transport + +import com.agentclientprotocol.rpc.JsonRpcNotification +import com.agentclientprotocol.rpc.JsonRpcRequest +import com.agentclientprotocol.rpc.MethodName +import com.agentclientprotocol.rpc.RequestId +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.onEach +import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.io.IOException +import kotlin.test.AfterTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.test.fail +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds + +/** + * Tests for the primary [StdioTransport] [Flow]-based constructor. + * + * The deprecated Source/Sink constructor has its own coverage in + * [StdioTransportTest]; tests here exercise paths that only matter for the Flow + * primary path (input flow termination/error, custom writer exception contract, + * close-without-stream-close). + */ +class StdioTransportFlowTest { + private val scope = CoroutineScope(SupervisorJob()) + + @AfterTest + fun tearDown() { + scope.cancel() + } + + @Test + fun `close on never-completing input reaches CLOSED`(): Unit = runBlocking { + val transport = makeTransport( + input = Channel(Channel.UNLIMITED).receiveAsFlow(), + ) + transport.start() + transport.expectState(Transport.State.STARTED) + + transport.close() + transport.expectState( + Transport.State.CLOSED, + message = "close() did not drive state to CLOSED for a never-completing input Flow", + ) + } + + @Test + fun `send writes encoded line to output`(): Unit = runBlocking { + val written = Channel(Channel.UNLIMITED) + val transport = makeTransport( + output = { line -> written.send(line) }, + ) + transport.start() + transport.expectState(Transport.State.STARTED) + + transport.send(JsonRpcRequest(RequestId.create(7), MethodName("ping"))) + + val line = withTimeout(1.seconds) { written.receive() } + assertTrue(line.contains("\"method\":\"ping\""), "encoded line should carry the method, was: $line") + assertTrue(line.contains("\"id\":7"), "encoded line should carry the id, was: $line") + } + + @Test + fun `input emission fires onMessage`(): Unit = runBlocking { + val inputChannel = Channel(Channel.UNLIMITED) + val transport = makeTransport(input = inputChannel.receiveAsFlow()) + val received = transport.asMessageChannel() + transport.start() + transport.expectState(Transport.State.STARTED) + + inputChannel.send("""{"jsonrpc":"2.0","method":"hello"}""") + + val message = withTimeout(1.seconds) { received.receive() } + assertTrue(message is JsonRpcNotification) + assertEquals(MethodName("hello"), message.method) + } + + @Test + fun `invalid JSON lines are skipped and valid ones still processed`(): Unit = runBlocking { + val inputChannel = Channel(Channel.UNLIMITED) + val transport = makeTransport(input = inputChannel.receiveAsFlow()) + val received = transport.asMessageChannel() + transport.start() + transport.expectState(Transport.State.STARTED) + + inputChannel.send("not json at all") + inputChannel.send("") + inputChannel.send("""{"jsonrpc":"2.0","method":"after-garbage","id":1}""") + + val message = withTimeout(1.seconds) { received.receive() } + assertTrue(message is JsonRpcRequest) + assertEquals(MethodName("after-garbage"), message.method) + } + + @Test + fun `input flow completing drives state to CLOSED`(): Unit = runBlocking { + // Flow that emits one valid message then completes naturally. + val transport = makeTransport( + input = flowOf("""{"jsonrpc":"2.0","method":"once"}"""), + ) + val received = transport.asMessageChannel() + transport.start() + + val message = withTimeout(1.seconds) { received.receive() } + assertNotNull(message) + + transport.expectState(Transport.State.CLOSED, message = "completion of input flow should close transport") + } + + @Test + fun `input flow throwing drives state to CLOSED and fires onError`(): Unit = runBlocking { + val sentinel = IllegalStateException("upstream blew up") + val errors = mutableListOf() + val transport = makeTransport( + input = flow { throw sentinel }, + ).apply { onError { errors.add(it) } } + transport.start() + + transport.expectState(Transport.State.CLOSED, message = "input flow error should close transport") + assertTrue(errors.any { it === sentinel }, "expected upstream error to be reported via onError, got: $errors") + } + + @Test + fun `output IOException closes write loop without firing onError`(): Unit = runBlocking { + val errors = mutableListOf() + val transport = makeTransport( + output = { throw IOException("peer gone") }, + ).apply { onError { errors.add(it) } } + transport.start() + transport.expectState(Transport.State.STARTED) + + transport.send(JsonRpcNotification(method = MethodName("ignored"))) + + transport.expectState(Transport.State.CLOSED, message = "IOException from output should close the transport") + assertTrue(errors.isEmpty(), "IOException should be treated as clean shutdown, got: $errors") + } + + @Test + fun `output unexpected exception fires onError and closes`(): Unit = runBlocking { + val sentinel = RuntimeException("writer bug") + val errors = mutableListOf() + val transport = makeTransport( + output = { throw sentinel }, + ).apply { onError { errors.add(it) } } + transport.start() + transport.expectState(Transport.State.STARTED) + + transport.send(JsonRpcNotification(method = MethodName("ignored"))) + + transport.expectState(Transport.State.CLOSED, message = "unexpected output error should close the transport") + assertTrue(errors.any { it === sentinel }, "expected writer error to be reported via onError, got: $errors") + } + + @Test + fun `parent scope cancellation drives state to CLOSED`(): Unit = runBlocking { + val transport = makeTransport( + input = Channel(Channel.UNLIMITED).receiveAsFlow(), + ) + transport.start() + transport.expectState(Transport.State.STARTED) + + scope.cancel() + transport.expectState(Transport.State.CLOSED, message = "parent scope cancellation should close transport") + } + + @Test + fun `concurrent sends are all delivered to output`(): Unit = runBlocking { + val written = Channel(Channel.UNLIMITED) + val transport = makeTransport( + output = { line -> written.send(line) }, + ) + transport.start() + transport.expectState(Transport.State.STARTED) + + val sendJobs = (1..10).map { i -> + scope.launch { transport.send(JsonRpcNotification(method = MethodName("method$i"))) } + } + sendJobs.joinAll() + + val received = withTimeout(1.seconds) { List(10) { written.receive() } } + (1..10).forEach { i -> + assertTrue(received.any { it.contains("\"method\":\"method$i\"") }, "missing method$i in $received") + } + } + + private fun makeTransport( + input: Flow = Channel(Channel.UNLIMITED).receiveAsFlow(), + output: suspend (String) -> Unit = { /* discard */ }, + ): StdioTransport = StdioTransport( + parentScope = scope, + ioDispatcher = Dispatchers.IO, + input = input, + output = output, + ) + + private suspend fun Transport.expectState(state: Transport.State, timeout: Duration = 1.seconds, message: String? = null) { + val observed = mutableListOf() + try { + withTimeout(timeout) { + this@expectState.state + .onEach { observed.add(it) } + .first { it == state } + } + } catch (_: TimeoutCancellationException) { + fail("Timed out waiting for state $state after $timeout, observed: ${observed.joinToString { it.name }}${message?.let { " — $it" } ?: ""}") + } + } +}