Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions acp/api/acp.api
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,8 @@ public abstract class com/agentclientprotocol/transport/BaseTransport : com/agen
}

public final class com/agentclientprotocol/transport/StdioTransport : com/agentclientprotocol/transport/BaseTransport {
public fun <init> (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Ljava/lang/String;)V
public synthetic fun <init> (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/io/Source;Lkotlinx/io/Sink;Ljava/lang/String;)V
public synthetic fun <init> (Lkotlinx/coroutines/CoroutineScope;Lkotlinx/coroutines/CoroutineDispatcher;Lkotlinx/io/Source;Lkotlinx/io/Sink;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun close ()V
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import com.agentclientprotocol.rpc.ACPJson
import com.agentclientprotocol.rpc.JsonRpcMessage
import com.agentclientprotocol.rpc.decodeJsonRpcMessage
import com.agentclientprotocol.transport.Transport.State
import com.agentclientprotocol.util.checkCancelled
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.getAndUpdate
import kotlinx.coroutines.flow.update
import kotlinx.io.*
import kotlinx.serialization.encodeToString

Expand All @@ -21,47 +21,94 @@ private val logger = KotlinLogging.logger {}
* This transport communicates over standard input/output streams,
* which is commonly used for command-line agents.
*/
public class StdioTransport(
public class StdioTransport private constructor(
Comment thread
anna239 marked this conversation as resolved.
private val parentScope: CoroutineScope,
private val ioDispatcher: CoroutineDispatcher,
private val input: Source,
private val output: Sink,
private val name: String = StdioTransport::class.simpleName!!,
private val input: Flow<String>,
private val output: suspend (String) -> Unit,
private val name: String,
private val closeHandler: () -> Unit,
) : BaseTransport() {

/**
* Primary [Flow]-based constructor.
*
* @param parentScope coroutine scope for the transport's lifecycle
* @param ioDispatcher dispatcher used for the read and write coroutines
* @param input cold flow of incoming NDJSON lines. Cancellation of the
* transport cancels collection; use [Flow.onCompletion] to react to
* that cancellation if you need to release upstream resources.
* @param output suspending writer invoked once per outgoing line. The
* implementation owns framing (newline) and flushing semantics. To
* signal that the underlying transport has closed cleanly, throw
* [IllegalStateException] or [IOException]; the write loop will exit
* without firing an error. Any other exception is reported via
* [Transport.onError].
* @param name optional name used in coroutine names and log messages
*/
public constructor(
parentScope: CoroutineScope,
ioDispatcher: CoroutineDispatcher,
input: Flow<String>,
output: suspend (String) -> Unit,
name: String = StdioTransport::class.simpleName!!,
) : this(parentScope, ioDispatcher, input, output, name, closeHandler = {})

/**
* Back-compat constructor backed by blocking [Source] / [Sink].
*
* Deprecated because this variant blocks the dispatcher thread for the
* duration of each [Source.readLine]. Under high agent concurrency this can
* saturate the I/O dispatcher (e.g. `Dispatchers.IO`) and cascade into
* freezes if other consumers schedule blocking work on the same pool.
*
* When this constructor is removed, the [Flow]-based constructor should be
* promoted to the primary one and [closeHandler] dropped from its parameter
* list (cancelling [childScope] in [close] is already enough to unwind the
* Flow-based read path).
*
* Callers should adapt their blocking [Source] / [Sink] into a
* [Flow]`<String>` and a `suspend (String) -> Unit` at the call site and
* use the [Flow]-based constructor instead.
*/
@Deprecated(
message = "Blocking Source/Sink pins a dispatcher thread per read and forces an extra closeHandler. " +
"Adapt your I/O into Flow<String> + suspend (String) -> Unit and use the Flow-based constructor.",
level = DeprecationLevel.WARNING,
)
public constructor(
parentScope: CoroutineScope,
ioDispatcher: CoroutineDispatcher,
input: Source,
output: Sink,
name: String = StdioTransport::class.simpleName!!,
) : this(
parentScope, ioDispatcher,
sourceAsLineFlow(input), sinkAsLineWriter(output),
name,
closeHandler = makeSourceSinkCloseHandler(input, output),
)

private val childScope = CoroutineScope(parentScope.coroutineContext + SupervisorJob(parentScope.coroutineContext[Job]) + CoroutineName(name))

private val receiveChannel = Channel<JsonRpcMessage>(Channel.UNLIMITED)
private val sendChannel = Channel<JsonRpcMessage>(Channel.UNLIMITED)

override fun start() {
if (_state.getAndUpdate { State.STARTING } != State.CREATED) error("Transport is not in ${State.CREATED.name} state")
// Start reading messages from input
childScope.launch(CoroutineName("$name.join-jobs")) {
val readJob = launch(ioDispatcher + CoroutineName("$name.read-from-input")) {
try {
while (currentCoroutineContext().isActive) {
// ACP assumes working with ND Json (new line delimited Json) when working over stdio
input.collect { line ->
currentCoroutineContext().ensureActive()
// ACP assumes working with ND Json (new line delimited Json) when working over stdio
val line = try {
input.readLine()
} catch (e: IllegalStateException) {
logger.trace(e) { "Input stream closed" }
break
} catch (e: IOException) {
logger.trace(e) { "Input stream likely closed" }
break
}
if (line == null) {
// End of stream
logger.trace { "End of stream" }
break
}

val jsonRpcMessage = try {
decodeJsonRpcMessage(line)
} catch (t: Throwable) {
logger.trace(t) { "Failed to decode JSON message: $line" }
continue
return@collect
}
logger.trace { "Sending message to channel: $jsonRpcMessage" }
fireMessage(jsonRpcMessage)
Expand All @@ -84,9 +131,7 @@ public class StdioTransport(
for (message in sendChannel) {
val encoded = ACPJson.encodeToString(message)
try {
output.writeString(encoded)
output.writeString("\n")
output.flush()
output(encoded)
} catch (e: IllegalStateException) {
logger.trace(e) { "Output stream closed" }
break
Expand Down Expand Up @@ -129,7 +174,7 @@ public class StdioTransport(
}
}
}

override fun send(message: JsonRpcMessage) {
logger.trace { "Sending message: $message" }
val channelResult = sendChannel.trySend(message)
Expand All @@ -150,7 +195,43 @@ public class StdioTransport(
if (sendChannel.close()) logger.trace { "Send channel closed" }
if (receiveChannel.close()) logger.trace { "Receive channel closed" }

runCatching { input.close() }.onFailure { logger.warn(it) { "Exception when closing input stream" } }
runCatching { output.close() }.onFailure { logger.warn(it) { "Exception when closing output stream" } }
runCatching { closeHandler() }.onFailure { logger.warn(it) { "Exception in close handler" } }

// Unwind the read/write coroutines. The Source/Sink back-compat path relies
// on [closeHandler] closing the underlying streams to unblock readLine, but
// the Flow-based path has no equivalent unblock — without cancelling here
// the read job would stay parked inside input.collect and the transport
// would be stuck in CLOSING.
childScope.cancel()
}
}
}

private fun sourceAsLineFlow(source: Source): Flow<String> = flow {
while (true) {
val line = try {
source.readLine()
} catch (e: IllegalStateException) {
logger.trace(e) { "Input stream closed" }
break
} catch (e: IOException) {
logger.trace(e) { "Input stream likely closed" }
break
}
if (line == null) {
logger.trace { "End of stream" }
break
}
emit(line)
}
}

private fun sinkAsLineWriter(sink: Sink): suspend (String) -> Unit = { line ->
sink.writeString(line)
sink.writeString("\n")
sink.flush()
}

private fun makeSourceSinkCloseHandler(source: Source, sink: Sink): () -> Unit = {
runCatching { source.close() }.onFailure { logger.warn(it) { "Exception when closing input stream" } }
runCatching { sink.close() }.onFailure { logger.warn(it) { "Exception when closing output stream" } }
}
Loading
Loading