diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 364ca3bc0..e3e30d5f9 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -70,6 +70,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private bool _disposed; private readonly int? _optionsPort; private readonly string? _optionsHost; + private readonly string? _effectiveConnectionToken; private int? _actualPort; private int? _negotiatedProtocolVersion; private List? _modelsCache; @@ -123,16 +124,20 @@ public CopilotClient(CopilotClientOptions? options = null) _options = options ?? new(); // Validate mutually exclusive options - if (!string.IsNullOrEmpty(_options.CliUrl) && _options.CliPath != null) + if (!string.IsNullOrEmpty(_options.CliUrl) && (_options.UseStdio == true || _options.CliPath != null)) { - throw new ArgumentException("CliUrl is mutually exclusive with CliPath"); + throw new ArgumentException("CliUrl is mutually exclusive with UseStdio and CliPath"); } - // When CliUrl is provided, disable UseStdio (we connect to an external server, not spawn one) + // When CliUrl is provided, force TCP mode (we connect to an external server, not spawn one) if (!string.IsNullOrEmpty(_options.CliUrl)) { _options.UseStdio = false; } + else + { + _options.UseStdio ??= true; + } // Validate auth options with external server if (!string.IsNullOrEmpty(_options.CliUrl) && (!string.IsNullOrEmpty(_options.GitHubToken) || _options.UseLoggedInUser != null)) @@ -140,6 +145,22 @@ public CopilotClient(CopilotClientOptions? options = null) throw new ArgumentException("GitHubToken and UseLoggedInUser cannot be used with CliUrl (external server manages its own auth)"); } + if (_options.TcpConnectionToken is not null) + { + if (_options.TcpConnectionToken.Length == 0) + { + throw new ArgumentException("TcpConnectionToken must be a non-empty string"); + } + if (_options.UseStdio == true) + { + throw new ArgumentException("TcpConnectionToken cannot be used with UseStdio = true"); + } + } + + var sdkSpawnsCli = _options.UseStdio == false && string.IsNullOrEmpty(_options.CliUrl); + _effectiveConnectionToken = _options.TcpConnectionToken + ?? (sdkSpawnsCli ? Guid.NewGuid().ToString() : null); + _logger = _options.Logger ?? NullLogger.Instance; _onListModels = _options.OnListModels; @@ -216,7 +237,7 @@ async Task StartCoreAsync(CancellationToken ct) else { // Child process (stdio or TCP) - var (cliProcess, portOrNull, stderrBuffer) = await StartCliServerAsync(_options, _logger, ct); + var (cliProcess, portOrNull, stderrBuffer) = await StartCliServerAsync(_options, _effectiveConnectionToken, _logger, ct); _actualPort = portOrNull; result = ConnectToServerAsync(cliProcess, portOrNull is null ? null : "localhost", portOrNull, stderrBuffer, ct); } @@ -1124,10 +1145,23 @@ private void ConfigureSessionFsHandlers(CopilotSession session, Func( - connection.Rpc, "ping", [new PingRequest()], connection.StderrBuffer, cancellationToken); + int? serverVersion; + try + { + var connectResponse = await InvokeRpcAsync( + connection.Rpc, "connect", [new ConnectRequest { Token = _effectiveConnectionToken }], connection.StderrBuffer, cancellationToken); + serverVersion = (int)connectResponse.ProtocolVersion; + } + catch (RemoteRpcException ex) when (ex.ErrorCode == RemoteRpcException.MethodNotFoundErrorCode) + { + // Legacy server without `connect`; fall back to `ping`. A token, if any, + // is silently dropped — the legacy server can't enforce one. + var pingResponse = await InvokeRpcAsync( + connection.Rpc, "ping", [new PingRequest()], connection.StderrBuffer, cancellationToken); + serverVersion = pingResponse.ProtocolVersion; + } - if (!pingResponse.ProtocolVersion.HasValue) + if (!serverVersion.HasValue) { throw new InvalidOperationException( $"SDK protocol version mismatch: SDK supports versions {MinProtocolVersion}-{maxVersion}, " + @@ -1135,19 +1169,18 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio $"Please update your server to ensure compatibility."); } - var serverVersion = pingResponse.ProtocolVersion.Value; - if (serverVersion < MinProtocolVersion || serverVersion > maxVersion) + if (serverVersion.Value < MinProtocolVersion || serverVersion.Value > maxVersion) { throw new InvalidOperationException( $"SDK protocol version mismatch: SDK supports versions {MinProtocolVersion}-{maxVersion}, " + - $"but server reports version {serverVersion}. " + + $"but server reports version {serverVersion.Value}. " + $"Please update your SDK or server to ensure compatibility."); } - _negotiatedProtocolVersion = serverVersion; + _negotiatedProtocolVersion = serverVersion.Value; } - private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, ILogger logger, CancellationToken cancellationToken) + private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, string? connectionToken, ILogger logger, CancellationToken cancellationToken) { // Use explicit path, COPILOT_CLI_PATH env var (from options.Environment or process env), or bundled CLI - no PATH fallback var envCliPath = options.Environment is not null && options.Environment.TryGetValue("COPILOT_CLI_PATH", out var envValue) ? envValue @@ -1165,7 +1198,7 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio args.AddRange(["--headless", "--no-auto-update", "--log-level", options.LogLevel]); - if (options.UseStdio) + if (options.UseStdio == true) { args.Add("--stdio"); } @@ -1199,7 +1232,7 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio FileName = fileName, Arguments = string.Join(" ", processArgs.Select(ProcessArgumentEscaper.Escape)), UseShellExecute = false, - RedirectStandardInput = options.UseStdio, + RedirectStandardInput = options.UseStdio == true, RedirectStandardOutput = true, RedirectStandardError = true, WorkingDirectory = options.Cwd, @@ -1223,6 +1256,11 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio startInfo.Environment["COPILOT_SDK_AUTH_TOKEN"] = options.GitHubToken; } + if (!string.IsNullOrEmpty(connectionToken)) + { + startInfo.Environment["COPILOT_CONNECTION_TOKEN"] = connectionToken; + } + // Set telemetry environment variables if configured if (options.Telemetry is { } telemetry) { @@ -1260,7 +1298,7 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio }, cancellationToken); var detectedLocalhostTcpPort = (int?)null; - if (!options.UseStdio) + if (options.UseStdio != true) { // Wait for port announcement using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); @@ -1326,7 +1364,7 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? Stream inputStream, outputStream; NetworkStream? networkStream = null; - if (_options.UseStdio) + if (_options.UseStdio == true) { if (cliProcess == null) { diff --git a/dotnet/src/Generated/Rpc.cs b/dotnet/src/Generated/Rpc.cs index 60566e40a..295c146b9 100644 --- a/dotnet/src/Generated/Rpc.cs +++ b/dotnet/src/Generated/Rpc.cs @@ -46,6 +46,30 @@ internal sealed class PingRequest public string? Message { get; set; } } +/// RPC data type for Connect operations. +internal sealed class ConnectResult +{ + /// Always true on success. + [JsonPropertyName("ok")] + public bool Ok { get; set; } + + /// Server protocol version number. + [JsonPropertyName("protocolVersion")] + public long ProtocolVersion { get; set; } + + /// Server package version. + [JsonPropertyName("version")] + public string Version { get; set; } = string.Empty; +} + +/// RPC data type for Connect operations. +internal sealed class ConnectRequest +{ + /// Connection token; required when the server was started with COPILOT_CONNECTION_TOKEN. + [JsonPropertyName("token")] + public string? Token { get; set; } +} + /// Billing information. public sealed class ModelBilling { @@ -3130,6 +3154,13 @@ public async Task PingAsync(string? message = null, CancellationToke return await CopilotClient.InvokeRpcAsync(_rpc, "ping", [request], cancellationToken); } + /// Calls "connect". + internal async Task ConnectAsync(string? token = null, CancellationToken cancellationToken = default) + { + var request = new ConnectRequest { Token = token }; + return await CopilotClient.InvokeRpcAsync(_rpc, "connect", [request], cancellationToken); + } + /// Models APIs. public ServerModelsApi Models { get; } @@ -4257,6 +4288,8 @@ public static void RegisterClientSessionApiHandlers(JsonRpc rpc, Func internal sealed class RemoteRpcException(string message, int errorCode, Exception? innerException = null) : Exception(message, innerException) { + /// JSON-RPC 2.0 reserved error code: requested method does not exist. + public const int MethodNotFoundErrorCode = -32601; + public int ErrorCode { get; } = errorCode; } diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 383a3fb1e..c68f7dee6 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -70,6 +70,7 @@ protected CopilotClientOptions(CopilotClientOptions? other) OnListModels = other.OnListModels; SessionFs = other.SessionFs; SessionIdleTimeoutSeconds = other.SessionIdleTimeoutSeconds; + TcpConnectionToken = other.TcpConnectionToken; } /// @@ -90,8 +91,11 @@ protected CopilotClientOptions(CopilotClientOptions? other) public int Port { get; set; } /// /// Whether to use stdio transport for communication with the CLI server. + /// Defaults to true when neither nor + /// switches the client into TCP mode. Setting this to true is mutually + /// exclusive with . /// - public bool UseStdio { get; set; } = true; + public bool? UseStdio { get; set; } /// /// URL of an existing CLI server to connect to instead of starting a new one. /// @@ -175,6 +179,13 @@ public string? GithubToken /// public int? SessionIdleTimeoutSeconds { get; set; } + /// + /// Connection token for the headless CLI server (TCP only). When the SDK spawns its own + /// CLI in TCP mode and this is omitted, a GUID is generated automatically so the loopback + /// listener is safe by default. Cannot be combined with = true. + /// + public string? TcpConnectionToken { get; set; } + /// /// Creates a shallow clone of this instance. /// diff --git a/dotnet/test/ConnectionTokenTests.cs b/dotnet/test/ConnectionTokenTests.cs new file mode 100644 index 000000000..499c9d36e --- /dev/null +++ b/dotnet/test/ConnectionTokenTests.cs @@ -0,0 +1,147 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using System; +using GitHub.Copilot.SDK.Test.Harness; +using Xunit; + +namespace GitHub.Copilot.SDK.Test; + +/// +/// Custom fixture that spawns a CLI in TCP mode with an explicit connection token, so +/// sibling clients can attempt to connect to the same port with the right/wrong/no token. +/// +public class ConnectionTokenTestFixture : IAsyncLifetime +{ + public E2ETestContext Ctx { get; private set; } = null!; + public CopilotClient GoodClient { get; private set; } = null!; + public int Port { get; private set; } + + public const string Token = "right-token"; + + public async Task InitializeAsync() + { + Ctx = await E2ETestContext.CreateAsync(); + GoodClient = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions + { + TcpConnectionToken = Token, + }); + + await GoodClient.StartAsync(); + Port = GoodClient.ActualPort + ?? throw new InvalidOperationException("GoodClient is not using TCP mode; ActualPort is null"); + } + + public async Task DisposeAsync() + { + if (GoodClient is not null) + { + await GoodClient.ForceStopAsync(); + } + + await Ctx.DisposeAsync(); + } +} + +public class ConnectionTokenTests : IClassFixture +{ + private readonly ConnectionTokenTestFixture _fixture; + + public ConnectionTokenTests(ConnectionTokenTestFixture fixture) + { + _fixture = fixture; + } + + [Fact] + public async Task Connects_With_The_Matching_Token() + { + var pong = await _fixture.GoodClient.PingAsync("hi"); + Assert.Equal("pong: hi", pong.Message); + } + + [Fact] + public async Task Rejects_A_Wrong_Token() + { + var wrongClient = new CopilotClient(new CopilotClientOptions + { + CliUrl = $"localhost:{_fixture.Port}", + TcpConnectionToken = "wrong", + }); + + try + { + var ex = await Assert.ThrowsAnyAsync(() => wrongClient.StartAsync()); + Assert.Contains("AUTHENTICATION_FAILED", GetFullMessage(ex)); + } + finally + { + // Best-effort cleanup; ignore stop errors when the client failed to start. + try { await wrongClient.ForceStopAsync(); } catch (Exception) { } + } + } + + [Fact] + public async Task Rejects_A_Missing_Token_When_One_Is_Required() + { + var noTokenClient = new CopilotClient(new CopilotClientOptions + { + CliUrl = $"localhost:{_fixture.Port}", + }); + + try + { + var ex = await Assert.ThrowsAnyAsync(() => noTokenClient.StartAsync()); + Assert.Contains("AUTHENTICATION_FAILED", GetFullMessage(ex)); + } + finally + { + // Best-effort cleanup; ignore stop errors when the client failed to start. + try { await noTokenClient.ForceStopAsync(); } catch (Exception) { } + } + } + + private static string GetFullMessage(Exception ex) + { + var messages = new List(); + for (var cur = ex; cur is not null; cur = cur.InnerException) + { + messages.Add(cur.Message); + } + return string.Join(" | ", messages); + } +} + +/// +/// When the SDK spawns its own CLI in TCP mode without an explicit token, it auto-generates +/// a GUID and round-trips it through the spawned CLI. +/// +public class ConnectionTokenAutoGeneratedTests : IAsyncLifetime +{ + private E2ETestContext _ctx = null!; + private CopilotClient _client = null!; + + public async Task InitializeAsync() + { + _ctx = await E2ETestContext.CreateAsync(); + _client = _ctx.CreateClient(useStdio: false); + } + + public async Task DisposeAsync() + { + if (_client is not null) + { + await _client.ForceStopAsync(); + } + + await _ctx.DisposeAsync(); + } + + [Fact] + public async Task The_SDK_Auto_Generated_Guid_Round_Trips_Through_The_Spawned_CLI() + { + await _client.StartAsync(); + var pong = await _client.PingAsync("hi"); + Assert.Equal("pong: hi", pong.Message); + } +} diff --git a/dotnet/test/E2E/ClientOptionsE2ETests.cs b/dotnet/test/E2E/ClientOptionsE2ETests.cs index bdbc57470..e1e009c4d 100644 --- a/dotnet/test/E2E/ClientOptionsE2ETests.cs +++ b/dotnet/test/E2E/ClientOptionsE2ETests.cs @@ -337,6 +337,11 @@ function handleMessage(message) { requests.push({ method: message.method, params: message.params }); saveCapture(); + if (message.method === "connect") { + writeResponse(message.id, { ok: true, protocolVersion: 3, version: "fake" }); + return; + } + if (message.method === "ping") { writeResponse(message.id, { message: "pong", protocolVersion: 3 }); return; diff --git a/dotnet/test/E2E/MultiClientCommandsElicitationE2ETests.cs b/dotnet/test/E2E/MultiClientCommandsElicitationE2ETests.cs index be1221848..4e3711650 100644 --- a/dotnet/test/E2E/MultiClientCommandsElicitationE2ETests.cs +++ b/dotnet/test/E2E/MultiClientCommandsElicitationE2ETests.cs @@ -18,10 +18,15 @@ public class MultiClientCommandsElicitationFixture : IAsyncLifetime public E2ETestContext Ctx { get; private set; } = null!; public CopilotClient Client1 { get; private set; } = null!; + public const string SharedToken = "multi-client-cmd-shared-token"; + public async Task InitializeAsync() { Ctx = await E2ETestContext.CreateAsync(); - Client1 = Ctx.CreateClient(useStdio: false); + Client1 = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions + { + TcpConnectionToken = SharedToken, + }); } public async Task DisposeAsync() @@ -80,6 +85,7 @@ public async Task InitializeAsync() _client2 = new CopilotClient(new CopilotClientOptions { CliUrl = $"localhost:{port}", + TcpConnectionToken = MultiClientCommandsElicitationFixture.SharedToken, }); } @@ -221,6 +227,7 @@ public async Task Capabilities_Changed_Fires_When_Elicitation_Provider_Disconnec _client3 = new CopilotClient(new CopilotClientOptions { CliUrl = $"localhost:{port}", + TcpConnectionToken = MultiClientCommandsElicitationFixture.SharedToken, }); // Client3 joins WITH elicitation handler diff --git a/dotnet/test/E2E/MultiClientE2ETests.cs b/dotnet/test/E2E/MultiClientE2ETests.cs index 115e13e96..e5d7c4b69 100644 --- a/dotnet/test/E2E/MultiClientE2ETests.cs +++ b/dotnet/test/E2E/MultiClientE2ETests.cs @@ -21,10 +21,15 @@ public class MultiClientTestFixture : IAsyncLifetime public E2ETestContext Ctx { get; private set; } = null!; public CopilotClient Client1 { get; private set; } = null!; + public const string SharedToken = "multi-client-shared-token"; + public async Task InitializeAsync() { Ctx = await E2ETestContext.CreateAsync(); - Client1 = Ctx.CreateClient(useStdio: false); + Client1 = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions + { + TcpConnectionToken = SharedToken, + }); } public async Task DisposeAsync() @@ -78,6 +83,7 @@ public async Task InitializeAsync() _client2 = new CopilotClient(new CopilotClientOptions { CliUrl = $"localhost:{port}", + TcpConnectionToken = MultiClientTestFixture.SharedToken, }); } @@ -336,6 +342,7 @@ public async Task Disconnecting_Client_Removes_Its_Tools() _client2 = new CopilotClient(new CopilotClientOptions { CliUrl = $"localhost:{port}", + TcpConnectionToken = MultiClientTestFixture.SharedToken, }); // Now only stable_tool should be available diff --git a/dotnet/test/E2E/PendingWorkResumeE2ETests.cs b/dotnet/test/E2E/PendingWorkResumeE2ETests.cs index a6d511eda..fa654e7a0 100644 --- a/dotnet/test/E2E/PendingWorkResumeE2ETests.cs +++ b/dotnet/test/E2E/PendingWorkResumeE2ETests.cs @@ -15,6 +15,7 @@ public class PendingWorkResumeE2ETests(E2ETestFixture fixture, ITestOutputHelper : E2ETestBase(fixture, "pending_work_resume", output) { private static readonly TimeSpan PendingWorkTimeout = TimeSpan.FromSeconds(60); + private const string SharedToken = "pending-work-resume-shared-token"; [Fact] public async Task Should_Continue_Pending_Permission_Request_After_Resume() @@ -23,11 +24,11 @@ public async Task Should_Continue_Pending_Permission_Request_After_Resume() var releaseOriginalPermission = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var resumedToolInvoked = false; - await using var server = Ctx.CreateClient(useStdio: false); + await using var server = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions { TcpConnectionToken = SharedToken }); await server.StartAsync(); var cliUrl = GetCliUrl(server); - using var suspendedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl }); + using var suspendedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = SharedToken }); var session1 = await suspendedClient.CreateSessionAsync(new SessionConfig { Tools = [AIFunctionFactory.Create(ResumePermissionTool, "resume_permission_tool")], @@ -54,7 +55,7 @@ await session1.SendAsync(new MessageOptions await suspendedClient.ForceStopAsync(); - await using var resumedTcpClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl }); + await using var resumedTcpClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = SharedToken }); var session2 = await resumedTcpClient.ResumeSessionAsync(sessionId, new ResumeSessionConfig { ContinuePendingWork = true, @@ -106,11 +107,11 @@ public async Task Should_Continue_Pending_External_Tool_Request_After_Resume() var originalToolStarted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseOriginalTool = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - await using var server = Ctx.CreateClient(useStdio: false); + await using var server = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions { TcpConnectionToken = SharedToken }); await server.StartAsync(); var cliUrl = GetCliUrl(server); - using var suspendedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl }); + using var suspendedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = SharedToken }); var session1 = await suspendedClient.CreateSessionAsync(new SessionConfig { Tools = [AIFunctionFactory.Create(BlockingExternalTool, "resume_external_tool")], @@ -131,7 +132,7 @@ await session1.SendAsync(new MessageOptions Assert.Equal("beta", await originalToolStarted.Task.WaitAsync(PendingWorkTimeout)); await suspendedClient.ForceStopAsync(); - await using var resumedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl }); + await using var resumedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = SharedToken }); var session2 = await resumedClient.ResumeSessionAsync(sessionId, new ResumeSessionConfig { ContinuePendingWork = true, @@ -171,11 +172,11 @@ public async Task Should_Continue_Parallel_Pending_External_Tool_Requests_After_ var releaseOriginalToolA = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseOriginalToolB = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - await using var server = Ctx.CreateClient(useStdio: false); + await using var server = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions { TcpConnectionToken = SharedToken }); await server.StartAsync(); var cliUrl = GetCliUrl(server); - using var suspendedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl }); + using var suspendedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = SharedToken }); var session1 = await suspendedClient.CreateSessionAsync(new SessionConfig { Tools = @@ -205,7 +206,7 @@ await Task.WhenAll( await suspendedClient.ForceStopAsync(); - await using var resumedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl }); + await using var resumedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = SharedToken }); var session2 = await resumedClient.ResumeSessionAsync(sessionId, new ResumeSessionConfig { ContinuePendingWork = true, @@ -256,12 +257,12 @@ async Task BlockingToolB([Description("Value to look up")] string value) [Fact] public async Task Should_Resume_Successfully_When_No_Pending_Work_Exists() { - await using var server = Ctx.CreateClient(useStdio: false); + await using var server = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions { TcpConnectionToken = SharedToken }); await server.StartAsync(); var cliUrl = GetCliUrl(server); string sessionId; - await using (var firstClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl })) + await using (var firstClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = SharedToken })) { var firstSession = await firstClient.CreateSessionAsync(new SessionConfig { @@ -275,7 +276,7 @@ public async Task Should_Resume_Successfully_When_No_Pending_Work_Exists() await firstSession.DisposeAsync(); } - await using var resumedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl }); + await using var resumedClient = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = SharedToken }); var resumedSession = await resumedClient.ResumeSessionAsync(sessionId, new ResumeSessionConfig { ContinuePendingWork = true, diff --git a/dotnet/test/E2E/SessionFsE2ETests.cs b/dotnet/test/E2E/SessionFsE2ETests.cs index 540a918b8..a9475c664 100644 --- a/dotnet/test/E2E/SessionFsE2ETests.cs +++ b/dotnet/test/E2E/SessionFsE2ETests.cs @@ -94,7 +94,7 @@ public async Task Should_Reject_SetProvider_When_Sessions_Already_Exist() var providerRoot = CreateProviderRoot(); try { - await using var client1 = CreateSessionFsClient(providerRoot, useStdio: false); + await using var client1 = CreateSessionFsClient(providerRoot, useStdio: false, tcpConnectionToken: "session-fs-shared-token"); var createSessionFsHandler = (Func)(s => new TestSessionFsHandler(s.SessionId, providerRoot)); _ = await client1.CreateSessionAsync(new SessionConfig @@ -113,6 +113,7 @@ public async Task Should_Reject_SetProvider_When_Sessions_Already_Exist() CliUrl = $"localhost:{port}", LogLevel = "error", SessionFs = SessionFsConfig, + TcpConnectionToken = "session-fs-shared-token", }); try @@ -446,7 +447,7 @@ public async Task Should_Persist_Plan_Md_Via_SessionFs() } } - private CopilotClient CreateSessionFsClient(string providerRoot, bool useStdio = true) + private CopilotClient CreateSessionFsClient(string providerRoot, bool useStdio = true, string? tcpConnectionToken = null) { Directory.CreateDirectory(providerRoot); return Ctx.CreateClient( @@ -454,6 +455,7 @@ private CopilotClient CreateSessionFsClient(string providerRoot, bool useStdio = options: new CopilotClientOptions { SessionFs = SessionFsConfig, + TcpConnectionToken = tcpConnectionToken, }); } diff --git a/dotnet/test/E2E/SuspendE2ETests.cs b/dotnet/test/E2E/SuspendE2ETests.cs index 4759245b9..af9d8284f 100644 --- a/dotnet/test/E2E/SuspendE2ETests.cs +++ b/dotnet/test/E2E/SuspendE2ETests.cs @@ -50,12 +50,13 @@ public async Task Should_Suspend_Idle_Session_Without_Throwing() [Fact] public async Task Should_Allow_Resume_And_Continue_Conversation_After_Suspend() { - await using var server = Ctx.CreateClient(useStdio: false); + const string sharedToken = "suspend-shared-token"; + await using var server = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions { TcpConnectionToken = sharedToken }); await server.StartAsync(); var cliUrl = GetCliUrl(server); string sessionId; - await using (var client1 = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl })) + await using (var client1 = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = sharedToken })) { var session1 = await client1.CreateSessionAsync(new SessionConfig { @@ -76,7 +77,7 @@ await session1.SendAndWaitAsync(new MessageOptions // A different client should be able to pick the session back up. The previous // turn was completed before suspend, so there is no pending work to continue. - await using var client2 = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl }); + await using var client2 = Ctx.CreateClient(options: new CopilotClientOptions { CliUrl = cliUrl, TcpConnectionToken = sharedToken }); var session2 = await client2.ResumeSessionAsync(sessionId, new ResumeSessionConfig { OnPermissionRequest = PermissionHandler.ApproveAll, diff --git a/dotnet/test/Harness/E2ETestContext.cs b/dotnet/test/Harness/E2ETestContext.cs index 177c20009..bc431b31f 100644 --- a/dotnet/test/Harness/E2ETestContext.cs +++ b/dotnet/test/Harness/E2ETestContext.cs @@ -169,7 +169,7 @@ public IReadOnlyDictionary GetEnvironment() return env!; } - public CopilotClient CreateClient(bool useStdio = true, CopilotClientOptions? options = null, bool autoInjectGitHubToken = true) + public CopilotClient CreateClient(bool? useStdio = null, CopilotClientOptions? options = null, bool autoInjectGitHubToken = true) { options ??= new CopilotClientOptions(); diff --git a/go/client.go b/go/client.go index c446e66d8..686d74be0 100644 --- a/go/client.go +++ b/go/client.go @@ -113,11 +113,18 @@ type Client struct { processErrorPtr *error osProcess atomic.Pointer[os.Process] negotiatedProtocolVersion int - onListModels func(ctx context.Context) ([]ModelInfo, error) + // effectiveConnectionToken is the token sent in `connect`; auto-generated when + // the SDK spawns its own CLI in TCP mode. + effectiveConnectionToken string + onListModels func(ctx context.Context) ([]ModelInfo, error) // RPC provides typed server-scoped RPC methods. // This field is nil until the client is connected via Start(). RPC *rpc.ServerRpc + + // internalRPC provides SDK-internal RPC methods (handshake helpers etc.). + // Lowercase = not exported; external callers cannot reach it. + internalRPC *rpc.InternalServerRpc } // NewClient creates a new Copilot CLI client with the given options. @@ -164,6 +171,11 @@ func NewClient(options *ClientOptions) *Client { panic("GitHubToken and UseLoggedInUser cannot be used with CLIUrl (external server manages its own auth)") } + // Validate token vs stdio + if options.TCPConnectionToken != "" && options.UseStdio != nil && *options.UseStdio { + panic("TCPConnectionToken cannot be used with UseStdio: true") + } + // Parse CLIUrl if provided if options.CLIUrl != "" { host, port := parseCliUrl(options.CLIUrl) @@ -234,6 +246,14 @@ func NewClient(options *ClientOptions) *Client { } } + // Resolve the effective connection token: explicit value if set; else if the SDK + // spawns its own CLI in TCP mode, generate a UUID; otherwise empty. + if options != nil && options.TCPConnectionToken != "" { + client.effectiveConnectionToken = options.TCPConnectionToken + } else if !client.useStdio && !client.isExternalServer { + client.effectiveConnectionToken = uuid.NewString() + } + client.options = opts return client } @@ -426,6 +446,7 @@ func (c *Client) Stop() error { } c.RPC = nil + c.internalRPC = nil return errors.Join(errs...) } @@ -497,6 +518,7 @@ func (c *Client) ForceStop() { } c.RPC = nil + c.internalRPC = nil } func (c *Client) ensureConnected(ctx context.Context) error { @@ -1331,25 +1353,49 @@ func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) { // minProtocolVersion is the minimum protocol version this SDK can communicate with. const minProtocolVersion = 2 -// verifyProtocolVersion verifies that the server's protocol version is within the supported range -// and stores the negotiated version. +// verifyProtocolVersion sends the `connect` handshake (carrying the optional token) and +// verifies the server's protocol version. Falls back to `ping` against legacy servers +// that don't implement `connect`. func (c *Client) verifyProtocolVersion(ctx context.Context) error { + if c.client == nil { + return fmt.Errorf("client not connected") + } maxVersion := GetSdkProtocolVersion() - pingResult, err := c.Ping(ctx, "") + + var serverVersion *int + tokenPtr := (*string)(nil) + if c.effectiveConnectionToken != "" { + t := c.effectiveConnectionToken + tokenPtr = &t + } + connectResult, err := c.internalRPC.Connect(ctx, &rpc.ConnectRequest{Token: tokenPtr}) if err != nil { - return err + var rpcErr *jsonrpc2.Error + if errors.As(err, &rpcErr) && rpcErr.Code == jsonrpc2.ErrMethodNotFound.Code { + // Legacy server without `connect`; fall back to `ping`. A token, if any, + // is silently dropped — the legacy server can't enforce one. + pingResult, perr := c.Ping(ctx, "") + if perr != nil { + return perr + } + serverVersion = pingResult.ProtocolVersion + } else { + return err + } + } else { + v := int(connectResult.ProtocolVersion) + serverVersion = &v } - if pingResult.ProtocolVersion == nil { + if serverVersion == nil { return fmt.Errorf("SDK protocol version mismatch: SDK supports versions %d-%d, but server does not report a protocol version. Please update your server to ensure compatibility", minProtocolVersion, maxVersion) } - serverVersion := *pingResult.ProtocolVersion - if serverVersion < minProtocolVersion || serverVersion > maxVersion { - return fmt.Errorf("SDK protocol version mismatch: SDK supports versions %d-%d, but server reports version %d. Please update your SDK or server to ensure compatibility", minProtocolVersion, maxVersion, serverVersion) + if *serverVersion < minProtocolVersion || *serverVersion > maxVersion { + return fmt.Errorf("SDK protocol version mismatch: SDK supports versions %d-%d, but server reports version %d. Please update your SDK or server to ensure compatibility", minProtocolVersion, maxVersion, *serverVersion) } - c.negotiatedProtocolVersion = serverVersion + c.negotiatedProtocolVersion = *serverVersion return nil } @@ -1422,6 +1468,10 @@ func (c *Client) startCLIServer(ctx context.Context) error { c.process.Env = append(c.process.Env, "COPILOT_SDK_AUTH_TOKEN="+c.options.GitHubToken) } + if c.effectiveConnectionToken != "" { + c.process.Env = append(c.process.Env, "COPILOT_CONNECTION_TOKEN="+c.effectiveConnectionToken) + } + if c.options.Telemetry != nil { t := c.options.Telemetry c.process.Env = append(c.process.Env, "COPILOT_OTEL_ENABLED=true") @@ -1477,6 +1527,7 @@ func (c *Client) startCLIServer(ctx context.Context) error { }() }) c.RPC = rpc.NewServerRpc(c.client) + c.internalRPC = rpc.NewInternalServerRpc(c.client) c.setupNotificationHandler() c.client.Start() @@ -1602,6 +1653,7 @@ func (c *Client) connectViaTcp(ctx context.Context) error { }() }) c.RPC = rpc.NewServerRpc(c.client) + c.internalRPC = rpc.NewInternalServerRpc(c.client) c.setupNotificationHandler() c.client.Start() diff --git a/go/internal/e2e/client_options_e2e_test.go b/go/internal/e2e/client_options_e2e_test.go index 12f331530..0b06470f1 100644 --- a/go/internal/e2e/client_options_e2e_test.go +++ b/go/internal/e2e/client_options_e2e_test.go @@ -446,6 +446,10 @@ function handleMessage(message) { } requests.push({ method: message.method, params: message.params }); saveCapture(); + if (message.method === "connect") { + writeResponse(message.id, { ok: true, protocolVersion: 3, version: "fake" }); + return; + } if (message.method === "ping") { writeResponse(message.id, { message: "pong", protocolVersion: 3, timestamp: Date.now() }); return; diff --git a/go/internal/e2e/commands_and_elicitation_e2e_test.go b/go/internal/e2e/commands_and_elicitation_e2e_test.go index 5b2f340d5..3ae14d649 100644 --- a/go/internal/e2e/commands_and_elicitation_e2e_test.go +++ b/go/internal/e2e/commands_and_elicitation_e2e_test.go @@ -15,6 +15,7 @@ func TestCommandsE2E(t *testing.T) { ctx := testharness.NewTestContext(t) client1 := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = sharedTcpToken }) t.Cleanup(func() { client1.ForceStop() }) @@ -33,7 +34,8 @@ func TestCommandsE2E(t *testing.T) { } client2 := copilot.NewClient(&copilot.ClientOptions{ - CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + TCPConnectionToken: sharedTcpToken, }) t.Cleanup(func() { client2.ForceStop() }) @@ -509,6 +511,7 @@ func TestUIElicitationMultiClientE2E(t *testing.T) { ctx := testharness.NewTestContext(t) client1 := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = sharedTcpToken }) t.Cleanup(func() { client1.ForceStop() }) @@ -558,7 +561,8 @@ func TestUIElicitationMultiClientE2E(t *testing.T) { // Client2 joins with elicitation handler — should trigger capabilities.changed client2 := copilot.NewClient(&copilot.ClientOptions{ - CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + TCPConnectionToken: sharedTcpToken, }) session2, err := client2.ResumeSession(t.Context(), session1.SessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, @@ -620,7 +624,8 @@ func TestUIElicitationMultiClientE2E(t *testing.T) { // Client3 (dedicated for this test) joins with elicitation handler client3 := copilot.NewClient(&copilot.ClientOptions{ - CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + TCPConnectionToken: sharedTcpToken, }) _, err = client3.ResumeSession(t.Context(), session1.SessionID, &copilot.ResumeSessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, diff --git a/go/internal/e2e/connection_token_test.go b/go/internal/e2e/connection_token_test.go new file mode 100644 index 000000000..269c5ae5a --- /dev/null +++ b/go/internal/e2e/connection_token_test.go @@ -0,0 +1,114 @@ +package e2e + +import ( + "fmt" + "strings" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +func TestConnectionToken(t *testing.T) { + t.Run("explicit token round-trips successfully", func(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = "right-token" + }) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Start failed: %v", err) + } + + resp, err := client.Ping(t.Context(), "hi") + if err != nil { + t.Fatalf("Ping failed: %v", err) + } + if resp.Message != "pong: hi" { + t.Errorf("expected message 'pong: hi', got %q", resp.Message) + } + }) + + t.Run("auto-generated token round-trips successfully", func(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + }) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Start failed: %v", err) + } + + resp, err := client.Ping(t.Context(), "hi") + if err != nil { + t.Fatalf("Ping failed: %v", err) + } + if resp.Message != "pong: hi" { + t.Errorf("expected message 'pong: hi', got %q", resp.Message) + } + }) + + t.Run("sibling client with wrong token is rejected", func(t *testing.T) { + ctx := testharness.NewTestContext(t) + good := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = "right-token" + }) + t.Cleanup(func() { good.ForceStop() }) + + if err := good.Start(t.Context()); err != nil { + t.Fatalf("good client Start failed: %v", err) + } + port := good.ActualPort() + if port == 0 { + t.Fatalf("expected non-zero port from TCP mode client") + } + + bad := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: fmt.Sprintf("localhost:%d", port), + TCPConnectionToken: "wrong", + }) + t.Cleanup(func() { bad.ForceStop() }) + + err := bad.Start(t.Context()) + if err == nil { + t.Fatalf("expected sibling client with wrong token to fail") + } + if !strings.Contains(err.Error(), "AUTHENTICATION_FAILED") { + t.Errorf("expected AUTHENTICATION_FAILED error, got: %v", err) + } + }) + + t.Run("sibling client with no token is rejected", func(t *testing.T) { + ctx := testharness.NewTestContext(t) + good := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = "right-token" + }) + t.Cleanup(func() { good.ForceStop() }) + + if err := good.Start(t.Context()); err != nil { + t.Fatalf("good client Start failed: %v", err) + } + port := good.ActualPort() + if port == 0 { + t.Fatalf("expected non-zero port from TCP mode client") + } + + none := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: fmt.Sprintf("localhost:%d", port), + }) + t.Cleanup(func() { none.ForceStop() }) + + err := none.Start(t.Context()) + if err == nil { + t.Fatalf("expected sibling client with no token to fail") + } + if !strings.Contains(err.Error(), "AUTHENTICATION_FAILED") { + t.Errorf("expected AUTHENTICATION_FAILED error, got: %v", err) + } + }) +} diff --git a/go/internal/e2e/multi_client_e2e_test.go b/go/internal/e2e/multi_client_e2e_test.go index 71721b69a..4426912c5 100644 --- a/go/internal/e2e/multi_client_e2e_test.go +++ b/go/internal/e2e/multi_client_e2e_test.go @@ -18,6 +18,7 @@ func TestMultiClientE2E(t *testing.T) { ctx := testharness.NewTestContext(t) client1 := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = sharedTcpToken }) t.Cleanup(func() { client1.ForceStop() }) @@ -36,7 +37,8 @@ func TestMultiClientE2E(t *testing.T) { } client2 := copilot.NewClient(&copilot.ClientOptions{ - CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + TCPConnectionToken: sharedTcpToken, }) t.Cleanup(func() { client2.ForceStop() }) @@ -475,7 +477,8 @@ func TestMultiClientE2E(t *testing.T) { // Recreate client2 for cleanup (but don't rejoin the session) client2 = copilot.NewClient(&copilot.ClientOptions{ - CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + TCPConnectionToken: sharedTcpToken, }) // Now only stable_tool should be available diff --git a/go/internal/e2e/pending_work_resume_e2e_test.go b/go/internal/e2e/pending_work_resume_e2e_test.go index c52f6e588..aa1786f66 100644 --- a/go/internal/e2e/pending_work_resume_e2e_test.go +++ b/go/internal/e2e/pending_work_resume_e2e_test.go @@ -45,6 +45,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { suspendedClient := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) session1, err := suspendedClient.CreateSession(t.Context(), &copilot.SessionConfig{ Tools: []copilot.Tool{originalTool}, @@ -111,6 +112,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { resumedClient := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) t.Cleanup(func() { resumedClient.ForceStop() }) @@ -189,6 +191,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { suspendedClient := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) session1, err := suspendedClient.CreateSession(t.Context(), &copilot.SessionConfig{ Tools: []copilot.Tool{originalTool}, @@ -226,6 +229,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { resumedClient := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) t.Cleanup(func() { resumedClient.ForceStop() }) @@ -301,6 +305,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { suspendedClient := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) session1, err := suspendedClient.CreateSession(t.Context(), &copilot.SessionConfig{ Tools: []copilot.Tool{originalA, originalB}, @@ -345,6 +350,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { resumedClient := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) t.Cleanup(func() { resumedClient.ForceStop() }) @@ -411,6 +417,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { firstClient := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) defer firstClient.ForceStop() @@ -438,6 +445,7 @@ func TestPendingWorkResumeE2E(t *testing.T) { resumedClient := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) t.Cleanup(func() { resumedClient.ForceStop() }) @@ -475,11 +483,20 @@ func serverCliURL(t *testing.T, server *copilot.Client) string { return fmt.Sprintf("localhost:%d", port) } +// sharedTcpToken is the connection token used by startTcpServer and any sibling +// client that connects via the resulting CLI URL. Tests use a fixed token rather +// than the auto-generated one because the second client is constructed without +// access to the first client's internal state. +const sharedTcpToken = "tcp-shared-test-token" + // startTcpServer starts a TCP-mode server client and returns its CLI URL. // It triggers an initial connection so ActualPort is populated. func startTcpServer(t *testing.T, ctx *testharness.TestContext) (*copilot.Client, string) { t.Helper() - server := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.UseStdio = copilot.Bool(false) }) + server := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = sharedTcpToken + }) t.Cleanup(func() { server.ForceStop() }) // Trigger connection so we can read the port. CreateSession+Disconnect is the // established pattern (see multi_client_test.go). diff --git a/go/internal/e2e/suspend_e2e_test.go b/go/internal/e2e/suspend_e2e_test.go index 2c02d6090..3c70874a5 100644 --- a/go/internal/e2e/suspend_e2e_test.go +++ b/go/internal/e2e/suspend_e2e_test.go @@ -52,6 +52,7 @@ func TestSuspendE2E(t *testing.T) { client1 := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) t.Cleanup(func() { client1.ForceStop() }) @@ -77,6 +78,7 @@ func TestSuspendE2E(t *testing.T) { client2 := ctx.NewClient(func(opts *copilot.ClientOptions) { opts.CLIUrl = cliURL opts.CLIPath = "" + opts.TCPConnectionToken = sharedTcpToken }) t.Cleanup(func() { client2.ForceStop() }) diff --git a/go/rpc/generated_rpc.go b/go/rpc/generated_rpc.go index c1a8e488e..dd5ff61b8 100644 --- a/go/rpc/generated_rpc.go +++ b/go/rpc/generated_rpc.go @@ -26,6 +26,8 @@ type RPCTypes struct { AuthInfoType AuthInfoType `json:"AuthInfoType"` CommandsHandlePendingCommandRequest CommandsHandlePendingCommandRequest `json:"CommandsHandlePendingCommandRequest"` CommandsHandlePendingCommandResult CommandsHandlePendingCommandResult `json:"CommandsHandlePendingCommandResult"` + ConnectRequest ConnectRequest `json:"ConnectRequest"` + ConnectResult ConnectResult `json:"ConnectResult"` CurrentModel CurrentModel `json:"CurrentModel"` DiscoveredMCPServer DiscoveredMCPServer `json:"DiscoveredMcpServer"` DiscoveredMCPServerSource MCPServerSource `json:"DiscoveredMcpServerSource"` @@ -348,6 +350,22 @@ type CommandsHandlePendingCommandResult struct { Success bool `json:"success"` } +// Internal: ConnectRequest is an internal SDK API and is not part of the public surface. +type ConnectRequest struct { + // Connection token; required when the server was started with COPILOT_CONNECTION_TOKEN + Token *string `json:"token,omitempty"` +} + +// Internal: ConnectResult is an internal SDK API and is not part of the public surface. +type ConnectResult struct { + // Always true on success + Ok bool `json:"ok"` + // Server protocol version number + ProtocolVersion int64 `json:"protocolVersion"` + // Server package version + Version string `json:"version"` +} + type CurrentModel struct { // Currently active model identifier ModelID *string `json:"modelId,omitempty"` @@ -2745,6 +2763,35 @@ func NewServerRpc(client *jsonrpc2.Client) *ServerRpc { return r } +type internalServerApi struct { + client *jsonrpc2.Client +} + +// InternalServerRpc provides internal SDK server-scoped RPC methods (handshake helpers etc.). Not part of the public API. +type InternalServerRpc struct { + common internalServerApi // Reuse a single struct instead of allocating one for each service on the heap. + +} + +// Internal: Connect is part of the SDK's internal handshake/plumbing; external callers should not use it. +func (a *InternalServerRpc) Connect(ctx context.Context, params *ConnectRequest) (*ConnectResult, error) { + raw, err := a.common.client.Request("connect", params) + if err != nil { + return nil, err + } + var result ConnectResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + +func NewInternalServerRpc(client *jsonrpc2.Client) *InternalServerRpc { + r := &InternalServerRpc{} + r.common = internalServerApi{client: client} + return r +} + type sessionApi struct { client *jsonrpc2.Client sessionID string diff --git a/go/types.go b/go/types.go index 73b039a30..bc5cecba5 100644 --- a/go/types.go +++ b/go/types.go @@ -30,6 +30,11 @@ type ClientOptions struct { // UseStdio controls whether to use stdio transport instead of TCP. // Default: nil (use default = true, i.e. stdio). Use Bool(false) to explicitly select TCP. UseStdio *bool + // TCPConnectionToken is the token sent in the `connect` handshake when using TCP transport. + // Only meaningful in TCP mode. When the SDK spawns its own CLI in TCP mode and this is + // empty, an auto-generated UUID is used so the loopback listener is safe by default. + // Combining this with UseStdio=true is rejected (stdio is pre-authenticated by transport). + TCPConnectionToken string // CLIUrl is the URL of an existing Copilot CLI server to connect to over TCP // Format: "host:port", "http://host:port", or just "port" (defaults to localhost) // Examples: "localhost:8080", "http://127.0.0.1:9000", "8080" diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index 903d8e271..9bd21becb 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -9,7 +9,7 @@ "version": "0.1.8", "license": "MIT", "dependencies": { - "@github/copilot": "^1.0.40", + "@github/copilot": "^1.0.41-0", "vscode-jsonrpc": "^8.2.1", "zod": "^4.3.6" }, @@ -663,26 +663,26 @@ } }, "node_modules/@github/copilot": { - "version": "1.0.40", - "resolved": "https://registry.npmjs.org/@github/copilot/-/copilot-1.0.40.tgz", - "integrity": "sha512-s35c/9R5q8O2ZQi/rzU9TBpum/DU06dczDSfmepCTisRHVTTKWS703J7kZhZYqL6OIGqnhmLfx4A7afT8YVKKA==", + "version": "1.0.41-0", + "resolved": "https://registry.npmjs.org/@github/copilot/-/copilot-1.0.41-0.tgz", + "integrity": "sha512-gLyCadBZdJeJtHJI3XdN8wAmLMEUdXfCa3EcVnbdbV1NHZDAJhr7h41l7a49pqRAmJyLUKlk1Lokk7U+OD3tgw==", "license": "SEE LICENSE IN LICENSE.md", "bin": { "copilot": "npm-loader.js" }, "optionalDependencies": { - "@github/copilot-darwin-arm64": "1.0.40", - "@github/copilot-darwin-x64": "1.0.40", - "@github/copilot-linux-arm64": "1.0.40", - "@github/copilot-linux-x64": "1.0.40", - "@github/copilot-win32-arm64": "1.0.40", - "@github/copilot-win32-x64": "1.0.40" + "@github/copilot-darwin-arm64": "1.0.41-0", + "@github/copilot-darwin-x64": "1.0.41-0", + "@github/copilot-linux-arm64": "1.0.41-0", + "@github/copilot-linux-x64": "1.0.41-0", + "@github/copilot-win32-arm64": "1.0.41-0", + "@github/copilot-win32-x64": "1.0.41-0" } }, "node_modules/@github/copilot-darwin-arm64": { - "version": "1.0.40", - "resolved": "https://registry.npmjs.org/@github/copilot-darwin-arm64/-/copilot-darwin-arm64-1.0.40.tgz", - "integrity": "sha512-syHKff/G53VzosHRXG6pX+MEc00tMdly0SZS4jC0fFUSY2/6R7lgi4IEZt57Q6WKdjBiqn8EvEQ94efdRdPEzA==", + "version": "1.0.41-0", + "resolved": "https://registry.npmjs.org/@github/copilot-darwin-arm64/-/copilot-darwin-arm64-1.0.41-0.tgz", + "integrity": "sha512-lrrH1oMbTOF1W/YxH6rvoEHOymxmXaMx4aDzm190hU0Yh6Cuu0BJGFvgG8nE9bqcv5O8W7eEBr26jDlGtnZiwg==", "cpu": [ "arm64" ], @@ -696,9 +696,9 @@ } }, "node_modules/@github/copilot-darwin-x64": { - "version": "1.0.40", - "resolved": "https://registry.npmjs.org/@github/copilot-darwin-x64/-/copilot-darwin-x64-1.0.40.tgz", - "integrity": "sha512-ai6PgHLx5SgC7Ht3Hy2tNepdnAnqcWaPPtYFaP2UGS69r1O87JBA/pB2QHVZP3vX34/4RyoOB/WQ1kiT5pvcpg==", + "version": "1.0.41-0", + "resolved": "https://registry.npmjs.org/@github/copilot-darwin-x64/-/copilot-darwin-x64-1.0.41-0.tgz", + "integrity": "sha512-4418VtSSkEgn4BcwCFg+0UDhGCfQgGTx16r/PiWbuUOgIBzts3FfVzWMWTuXyxk7kl2Ib8k7KSd/7rNpjcrzBw==", "cpu": [ "x64" ], @@ -712,9 +712,9 @@ } }, "node_modules/@github/copilot-linux-arm64": { - "version": "1.0.40", - "resolved": "https://registry.npmjs.org/@github/copilot-linux-arm64/-/copilot-linux-arm64-1.0.40.tgz", - "integrity": "sha512-mAh8GmGmkkUiFFzBIvXYmReSIw5c3NWHme0iYT2v/72RWNAwRq4HLKP9NhHapZqUPyi7jQnRxllYPybZDFS6Mw==", + "version": "1.0.41-0", + "resolved": "https://registry.npmjs.org/@github/copilot-linux-arm64/-/copilot-linux-arm64-1.0.41-0.tgz", + "integrity": "sha512-5xjgp3Ak5QJ68byNbsgBpdK1V6T5t8EGu0pUwEJMNMMXxqvL9f7gPcnCGdTtV2DS4Q3adkziV/gpBSSQ5HY8hg==", "cpu": [ "arm64" ], @@ -728,9 +728,9 @@ } }, "node_modules/@github/copilot-linux-x64": { - "version": "1.0.40", - "resolved": "https://registry.npmjs.org/@github/copilot-linux-x64/-/copilot-linux-x64-1.0.40.tgz", - "integrity": "sha512-gCo6RgpXwa39FZSdj5dMkN/z0xT/NS29MpkeYQ/S34SSoUsSbIKUWcuoMjRiXpaDWWuXQsFjXcqnwNs67JOejA==", + "version": "1.0.41-0", + "resolved": "https://registry.npmjs.org/@github/copilot-linux-x64/-/copilot-linux-x64-1.0.41-0.tgz", + "integrity": "sha512-oWPkj0bSjBjtAqonMEZD7EuSByBNXwtceMw8y7uGOfs6jQXfhDGzCCB6NGb+lcftVNtWDKFCUtx+x8Fbt4O37w==", "cpu": [ "x64" ], @@ -744,9 +744,9 @@ } }, "node_modules/@github/copilot-win32-arm64": { - "version": "1.0.40", - "resolved": "https://registry.npmjs.org/@github/copilot-win32-arm64/-/copilot-win32-arm64-1.0.40.tgz", - "integrity": "sha512-59ANg2xfeeWwl8UNqVe4sk2OAizgMjKVRgTALYExHMc8p5HJyL8nrQ9np6pIkO/De0UpR8rUzk4D8oCqU7XLIQ==", + "version": "1.0.41-0", + "resolved": "https://registry.npmjs.org/@github/copilot-win32-arm64/-/copilot-win32-arm64-1.0.41-0.tgz", + "integrity": "sha512-MaPg4tFWTiRuyv+j0ymJbZp8UPK+RIXNMpekR7FRf8/Uz+NiJgTTxTDjFi4ytRJU5UNrUezkVAk5Xduq/CaIew==", "cpu": [ "arm64" ], @@ -760,9 +760,9 @@ } }, "node_modules/@github/copilot-win32-x64": { - "version": "1.0.40", - "resolved": "https://registry.npmjs.org/@github/copilot-win32-x64/-/copilot-win32-x64-1.0.40.tgz", - "integrity": "sha512-dKS5L7SwzXZI/gaY9LIn8eoTGEPgczGLIgt1KwjfkHgk3achHwIuEjxSqPRVD1Z7q08uuOcuwVzOlwE8V38zNQ==", + "version": "1.0.41-0", + "resolved": "https://registry.npmjs.org/@github/copilot-win32-x64/-/copilot-win32-x64-1.0.41-0.tgz", + "integrity": "sha512-ykRuDWjJEgSywMFJl1yaefssaklCVSVhprx2NcSVh6tIGupvvzVAM6nL6Mj6nyKpG6FKGHanedBeL6SJc935cw==", "cpu": [ "x64" ], diff --git a/nodejs/package.json b/nodejs/package.json index 90e4d3667..eb41d5346 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -56,7 +56,7 @@ "author": "GitHub", "license": "MIT", "dependencies": { - "@github/copilot": "^1.0.40", + "@github/copilot": "^1.0.41-0", "vscode-jsonrpc": "^8.2.1", "zod": "^4.3.6" }, diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 51f1a138a..34f0c62d3 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -20,11 +20,17 @@ import { dirname, join } from "node:path"; import { fileURLToPath } from "node:url"; import { createMessageConnection, + ErrorCodes, MessageConnection, + ResponseError, StreamMessageReader, StreamMessageWriter, } from "vscode-jsonrpc/node.js"; -import { createServerRpc, registerClientSessionApiHandlers } from "./generated/rpc.js"; +import { + createServerRpc, + createInternalServerRpc, + registerClientSessionApiHandlers, +} from "./generated/rpc.js"; import { getSdkProtocolVersion } from "./sdkProtocolVersion.js"; import { CopilotSession, NO_RESULT_PERMISSION_V2_ERROR } from "./session.js"; import { createSessionFsAdapter } from "./sessionFsProvider.js"; @@ -221,6 +227,7 @@ export class CopilotClient { | "telemetry" | "onGetTraceContext" | "sessionFs" + | "tcpConnectionToken" > > & { cliPath?: string; @@ -231,6 +238,8 @@ export class CopilotClient { }; private isExternalServer: boolean = false; private forceStopping: boolean = false; + /** Token sent in `connect`; auto-generated when the SDK spawns its own CLI in TCP mode. */ + private effectiveConnectionToken?: string; private onListModels?: () => Promise | ModelInfo[]; private onGetTraceContext?: TraceContextProvider; private modelsCache: ModelInfo[] | null = null; @@ -241,6 +250,7 @@ export class CopilotClient { Set<(event: SessionLifecycleEvent) => void> > = new Map(); private _rpc: ReturnType | null = null; + private _internalRpc: ReturnType | null = null; private processExitPromise: Promise | null = null; // Rejects when CLI process exits private negotiatedProtocolVersion: number | null = null; /** Connection-level session filesystem config, set via constructor option. */ @@ -260,6 +270,20 @@ export class CopilotClient { return this._rpc; } + /** + * Internal RPC surface (e.g. handshake helpers). Not part of the public API. + * @internal + */ + private get internalRpc(): ReturnType { + if (!this.connection) { + throw new Error("Client is not connected. Call start() first."); + } + if (!this._internalRpc) { + this._internalRpc = createInternalServerRpc(this.connection); + } + return this._internalRpc; + } + /** * Creates a new CopilotClient instance. * @@ -300,6 +324,23 @@ export class CopilotClient { ); } + if (options.tcpConnectionToken !== undefined) { + if ( + typeof options.tcpConnectionToken !== "string" || + options.tcpConnectionToken.length === 0 + ) { + throw new Error("tcpConnectionToken must be a non-empty string"); + } + if (options.useStdio === true) { + throw new Error("tcpConnectionToken cannot be used with useStdio: true"); + } + } + + const willUseStdio = options.cliUrl ? false : (options.useStdio ?? true); + const sdkSpawnsCli = !willUseStdio && !options.cliUrl && !options.isChildProcess; + this.effectiveConnectionToken = + options.tcpConnectionToken ?? (sdkSpawnsCli ? randomUUID() : undefined); + if (options.sessionFs) { this.validateSessionFsConfig(options.sessionFs); } @@ -1066,22 +1107,34 @@ export class CopilotClient { } /** - * Verify that the server's protocol version is within the supported range - * and store the negotiated version. + * Send the `connect` handshake (carrying the optional token) and verify the + * server's protocol version. Falls back to `ping` against legacy servers + * that don't implement `connect`. */ private async verifyProtocolVersion(): Promise { + if (!this.connection) { + throw new Error("Client not connected"); + } const maxVersion = getSdkProtocolVersion(); + const raceAgainstExit = (p: Promise): Promise => + this.processExitPromise ? Promise.race([p, this.processExitPromise]) : p; - // Race ping against process exit to detect early CLI failures - let pingResult: Awaited>; - if (this.processExitPromise) { - pingResult = await Promise.race([this.ping(), this.processExitPromise]); - } else { - pingResult = await this.ping(); + let serverVersion: number | undefined; + try { + const result = await raceAgainstExit( + this.internalRpc.connect({ token: this.effectiveConnectionToken }) + ); + serverVersion = result.protocolVersion; + } catch (err) { + if (err instanceof ResponseError && err.code === ErrorCodes.MethodNotFound) { + // Legacy server without `connect`; fall back to `ping`. A token, if any, + // is silently dropped — the legacy server can't enforce one. + serverVersion = (await raceAgainstExit(this.ping())).protocolVersion; + } else { + throw err; + } } - const serverVersion = pingResult.protocolVersion; - if (serverVersion === undefined) { throw new Error( `SDK protocol version mismatch: SDK supports versions ${MIN_PROTOCOL_VERSION}-${maxVersion}, but server does not report a protocol version. ` + @@ -1439,6 +1492,10 @@ export class CopilotClient { envWithoutNodeDebug.COPILOT_SDK_AUTH_TOKEN = this.options.gitHubToken; } + if (this.effectiveConnectionToken) { + envWithoutNodeDebug.COPILOT_CONNECTION_TOKEN = this.effectiveConnectionToken; + } + if (!this.options.cliPath) { throw new Error( "Path to Copilot CLI is required. Please provide it via the cliPath option, or use cliUrl to rely on a remote CLI." diff --git a/nodejs/src/generated/rpc.ts b/nodejs/src/generated/rpc.ts index bc218c777..6836324ab 100644 --- a/nodejs/src/generated/rpc.ts +++ b/nodejs/src/generated/rpc.ts @@ -395,6 +395,30 @@ export interface CommandsHandlePendingCommandResult { success: boolean; } +/** @internal */ +export interface ConnectRequest { + /** + * Connection token; required when the server was started with COPILOT_CONNECTION_TOKEN + */ + token?: string; +} + +/** @internal */ +export interface ConnectResult { + /** + * Always true on success + */ + ok: true; + /** + * Server protocol version number + */ + protocolVersion: number; + /** + * Server package version + */ + version: string; +} + export interface CurrentModel { /** * Currently active model identifier @@ -2521,6 +2545,18 @@ export function createServerRpc(connection: MessageConnection) { }; } +/** + * Create typed server-scoped RPC methods that are part of the SDK's internal + * surface (e.g. handshake helpers). Not exported on the public client API. + * @internal + */ +export function createInternalServerRpc(connection: MessageConnection) { + return { + connect: async (params: ConnectRequest): Promise => + connection.sendRequest("connect", params), + }; +} + /** Create typed session-scoped RPC methods. */ export function createSessionRpc(connection: MessageConnection, sessionId: string) { return { diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 8a98fc692..4aee16006 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -194,6 +194,14 @@ export interface CopilotClientOptions { * @default undefined (disabled) */ sessionIdleTimeoutSeconds?: number; + + /** + * Connection token for the headless CLI server (TCP only). When the SDK + * spawns its own CLI in TCP mode and this is omitted, a UUID is generated + * automatically so the loopback listener is safe by default. Rejected with + * `useStdio: true` (stdio is pre-authenticated by transport). + */ + tcpConnectionToken?: string; } /** diff --git a/nodejs/test/e2e/client_options.e2e.test.ts b/nodejs/test/e2e/client_options.e2e.test.ts index 823e22c01..73ddc09db 100644 --- a/nodejs/test/e2e/client_options.e2e.test.ts +++ b/nodejs/test/e2e/client_options.e2e.test.ts @@ -81,6 +81,11 @@ function handleMessage(message) { requests.push({ method: message.method, params: message.params }); saveCapture(); + if (message.method === "connect") { + writeResponse(message.id, { ok: true, protocolVersion: 3, version: "fake" }); + return; + } + if (message.method === "ping") { writeResponse(message.id, { message: "pong", protocolVersion: 3 }); return; diff --git a/nodejs/test/e2e/commands.e2e.test.ts b/nodejs/test/e2e/commands.e2e.test.ts index b98c6c6d0..5ab6a9bbe 100644 --- a/nodejs/test/e2e/commands.e2e.test.ts +++ b/nodejs/test/e2e/commands.e2e.test.ts @@ -9,15 +9,19 @@ import { createSdkTestContext } from "./harness/sdkTestContext.js"; describe("Commands", async () => { // Use TCP mode so a second client can connect to the same CLI process - const ctx = await createSdkTestContext({ useStdio: false }); + const tcpConnectionToken = "commands-test-token"; + const ctx = await createSdkTestContext({ + useStdio: false, + copilotClientOptions: { tcpConnectionToken }, + }); const client1 = ctx.copilotClient; // Trigger connection so we can read the port const initSession = await client1.createSession({ onPermissionRequest: approveAll }); await initSession.disconnect(); - const actualPort = (client1 as unknown as { actualPort: number }).actualPort; - const client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + const { actualPort } = client1 as unknown as { actualPort: number }; + const client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}`, tcpConnectionToken }); afterAll(async () => { await client2.stop(); diff --git a/nodejs/test/e2e/connection_token.test.ts b/nodejs/test/e2e/connection_token.test.ts new file mode 100644 index 000000000..50813778c --- /dev/null +++ b/nodejs/test/e2e/connection_token.test.ts @@ -0,0 +1,49 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { afterAll, describe, expect, it } from "vitest"; +import { CopilotClient } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +describe("Connection token", async () => { + const ctx = await createSdkTestContext({ + useStdio: false, + copilotClientOptions: { tcpConnectionToken: "right-token" }, + }); + const goodClient = ctx.copilotClient; + await goodClient.start(); + const port = (goodClient as unknown as { actualPort: number }).actualPort; + + const wrongClient = new CopilotClient({ + cliUrl: `localhost:${port}`, + tcpConnectionToken: "wrong", + }); + const noTokenClient = new CopilotClient({ cliUrl: `localhost:${port}` }); + + afterAll(async () => { + await wrongClient.forceStop(); + await noTokenClient.forceStop(); + }); + + it("connects with the matching token", async () => { + await expect(goodClient.ping("hi")).resolves.toMatchObject({ message: "pong: hi" }); + }); + + it("rejects a wrong token", async () => { + await expect(wrongClient.start()).rejects.toThrow(/AUTHENTICATION_FAILED/); + }); + + it("rejects a missing token when one is required", async () => { + await expect(noTokenClient.start()).rejects.toThrow(/AUTHENTICATION_FAILED/); + }); +}); + +describe("Connection token (auto-generated)", async () => { + const { copilotClient } = await createSdkTestContext({ useStdio: false }); + + it("the SDK-auto-generated UUID round-trips through the spawned CLI", async () => { + await copilotClient.start(); + await expect(copilotClient.ping("hi")).resolves.toMatchObject({ message: "pong: hi" }); + }); +}); diff --git a/nodejs/test/e2e/multi-client.e2e.test.ts b/nodejs/test/e2e/multi-client.e2e.test.ts index f23ae4459..14e1a3754 100644 --- a/nodejs/test/e2e/multi-client.e2e.test.ts +++ b/nodejs/test/e2e/multi-client.e2e.test.ts @@ -10,7 +10,11 @@ import { createSdkTestContext } from "./harness/sdkTestContext"; describe("Multi-client broadcast", async () => { // Use TCP mode so a second client can connect to the same CLI process - const ctx = await createSdkTestContext({ useStdio: false }); + const tcpConnectionToken = "multi-client-test-token"; + const ctx = await createSdkTestContext({ + useStdio: false, + copilotClientOptions: { tcpConnectionToken }, + }); const client1 = ctx.copilotClient; // Trigger connection so we can read the port @@ -18,7 +22,7 @@ describe("Multi-client broadcast", async () => { await initSession.disconnect(); const actualPort = (client1 as unknown as { actualPort: number }).actualPort; - let client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + let client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}`, tcpConnectionToken }); afterAll(async () => { await client2.stop(); @@ -297,7 +301,7 @@ describe("Multi-client broadcast", async () => { process.removeListener("unhandledRejection", suppressDisposed); // Recreate client2 for cleanup in afterAll (but don't rejoin the session) - client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}`, tcpConnectionToken }); // Now only stable_tool should be available const afterResponse = await session1.sendAndWait({ diff --git a/nodejs/test/e2e/pending_work_resume.e2e.test.ts b/nodejs/test/e2e/pending_work_resume.e2e.test.ts index 10f8de026..b81cdb5e0 100644 --- a/nodejs/test/e2e/pending_work_resume.e2e.test.ts +++ b/nodejs/test/e2e/pending_work_resume.e2e.test.ts @@ -123,6 +123,7 @@ function waitForPermissionRequest(session: CopilotSession): Promise { const { env, workDir } = await createSdkTestContext(); + const SHARED_TOKEN = "pending-work-resume-shared-test-token"; function createTcpServer(): CopilotClient { const server = new CopilotClient({ @@ -130,6 +131,7 @@ describe("Pending work resume", async () => { env, cliPath: process.env.COPILOT_CLI_PATH, useStdio: false, + tcpConnectionToken: SHARED_TOKEN, }); onTestFinished(async () => { try { @@ -142,7 +144,7 @@ describe("Pending work resume", async () => { } function createConnectingClient(cliUrl: string): CopilotClient { - const client = new CopilotClient({ cliUrl }); + const client = new CopilotClient({ cliUrl, tcpConnectionToken: SHARED_TOKEN }); onTestFinished(async () => { try { await client.forceStop(); diff --git a/nodejs/test/e2e/session_fs.e2e.test.ts b/nodejs/test/e2e/session_fs.e2e.test.ts index fdadc3db2..a28a2713c 100644 --- a/nodejs/test/e2e/session_fs.e2e.test.ts +++ b/nodejs/test/e2e/session_fs.e2e.test.ts @@ -88,15 +88,16 @@ describe("Session Fs", async () => { }); it("should reject setProvider when sessions already exist", async () => { + const tcpConnectionToken = "session-fs-test-token"; const client = new CopilotClient({ useStdio: false, // Use TCP so we can connect from a second client + tcpConnectionToken, env, }); onTestFinished(() => client.forceStop()); await client.createSession({ onPermissionRequest: approveAll, createSessionFsHandler }); - // Get the port the first client's runtime is listening on - const port = (client as unknown as { actualPort: number }).actualPort; + const { actualPort: port } = client as unknown as { actualPort: number }; // Second client tries to connect with a session fs — should fail // because sessions already exist on the runtime. @@ -104,6 +105,7 @@ describe("Session Fs", async () => { env, logLevel: "error", cliUrl: `localhost:${port}`, + tcpConnectionToken, sessionFs: sessionFsConfig, }); onTestFinished(() => client2.forceStop()); diff --git a/nodejs/test/e2e/suspend.e2e.test.ts b/nodejs/test/e2e/suspend.e2e.test.ts index cc7977d79..3ca4c4e3f 100644 --- a/nodejs/test/e2e/suspend.e2e.test.ts +++ b/nodejs/test/e2e/suspend.e2e.test.ts @@ -59,6 +59,7 @@ function onTestFinishedForceStop(client: CopilotClient): void { describe("Suspend RPC", async () => { const { copilotClient: client, env, workDir } = await createSdkTestContext(); + const SHARED_TOKEN = "suspend-shared-test-token"; function createTcpServer(): CopilotClient { const server = new CopilotClient({ @@ -66,13 +67,14 @@ describe("Suspend RPC", async () => { env, cliPath: process.env.COPILOT_CLI_PATH, useStdio: false, + tcpConnectionToken: SHARED_TOKEN, }); onTestFinishedForceStop(server); return server; } function createConnectingClient(cliUrl: string): CopilotClient { - const connectedClient = new CopilotClient({ cliUrl }); + const connectedClient = new CopilotClient({ cliUrl, tcpConnectionToken: SHARED_TOKEN }); onTestFinishedForceStop(connectedClient); return connectedClient; } diff --git a/nodejs/test/e2e/ui_elicitation.e2e.test.ts b/nodejs/test/e2e/ui_elicitation.e2e.test.ts index ced735d88..8651c5bd2 100644 --- a/nodejs/test/e2e/ui_elicitation.e2e.test.ts +++ b/nodejs/test/e2e/ui_elicitation.e2e.test.ts @@ -53,15 +53,19 @@ describe("UI Elicitation Callback", async () => { describe("UI Elicitation Multi-Client Capabilities", async () => { // Use TCP mode so a second client can connect to the same CLI process - const ctx = await createSdkTestContext({ useStdio: false }); + const tcpConnectionToken = "ui-elicitation-test-token"; + const ctx = await createSdkTestContext({ + useStdio: false, + copilotClientOptions: { tcpConnectionToken }, + }); const client1 = ctx.copilotClient; // Trigger connection so we can read the port const initSession = await client1.createSession({ onPermissionRequest: approveAll }); await initSession.disconnect(); - const actualPort = (client1 as unknown as { actualPort: number }).actualPort; - const client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + const { actualPort } = client1 as unknown as { actualPort: number }; + const client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}`, tcpConnectionToken }); afterAll(async () => { await client2.stop(); @@ -134,7 +138,10 @@ describe("UI Elicitation Multi-Client Capabilities", async () => { }); // Use a dedicated client so we can stop it without affecting shared client2 - const client3 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + const client3 = new CopilotClient({ + cliUrl: `localhost:${actualPort}`, + tcpConnectionToken, + }); // Client3 joins WITH elicitation handler await client3.resumeSession(session1.sessionId, { diff --git a/python/copilot/client.py b/python/copilot/client.py index d37525343..881e172c4 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -29,12 +29,14 @@ from types import TracebackType from typing import Any, Literal, TypedDict, cast, overload -from ._jsonrpc import JsonRpcClient, ProcessExitedError +from ._jsonrpc import JsonRpcClient, JsonRpcError, ProcessExitedError from ._sdk_protocol_version import get_sdk_protocol_version from ._telemetry import get_trace_context, trace_context from .generated.rpc import ( ClientSessionApiHandlers, + ConnectRequest, ServerRpc, + _InternalServerRpc, register_client_session_api_handlers, ) from .generated.session_events import ( @@ -126,6 +128,14 @@ class SubprocessConfig: use_stdio: bool = True """Use stdio transport (``True``, default) or TCP (``False``).""" + tcp_connection_token: str | None = None + """Connection token for the headless CLI server (TCP only). + + Only meaningful when ``use_stdio=False``. When the SDK spawns the CLI in TCP mode and + this is omitted, a UUID is generated automatically so the loopback listener is safe by + default. Combining this with ``use_stdio=True`` raises :class:`ValueError`. + """ + port: int = 0 """TCP port for the CLI server (only when ``use_stdio=False``). 0 means random.""" @@ -173,6 +183,10 @@ class ExternalServerConfig: _: KW_ONLY + tcp_connection_token: str | None = None + """Connection token sent in the ``connect`` handshake. Required when the server was + started with a token; ignored by legacy servers without ``connect`` support.""" + session_fs: SessionFsConfig | None = None """Connection-level session filesystem provider configuration.""" @@ -880,12 +894,25 @@ def __init__( self._actual_host: str = "localhost" self._is_external_server: bool = isinstance(config, ExternalServerConfig) + if config.tcp_connection_token is not None and len(config.tcp_connection_token) == 0: + raise ValueError("tcp_connection_token must be a non-empty string") + if isinstance(config, ExternalServerConfig): self._actual_host, actual_port = self._parse_cli_url(config.url) self._actual_port: int | None = actual_port + self._effective_connection_token: str | None = config.tcp_connection_token else: self._actual_port = None + if config.tcp_connection_token is not None and config.use_stdio: + raise ValueError("tcp_connection_token cannot be used with use_stdio=True") + if config.use_stdio: + self._effective_connection_token = None + elif config.tcp_connection_token is not None: + self._effective_connection_token = config.tcp_connection_token + else: + self._effective_connection_token = str(uuid.uuid4()) + # Resolve CLI path: explicit > COPILOT_CLI_PATH env var > bundled binary effective_env = config.env if config.env is not None else os.environ if config.cli_path is None: @@ -2177,11 +2204,27 @@ def _dispatch_lifecycle_event(self, event: SessionLifecycleEvent) -> None: pass # Ignore handler errors async def _verify_protocol_version(self) -> None: - """Verify that the server's protocol version is within the supported range - and store the negotiated version.""" + """Send the ``connect`` handshake (with the optional token) and verify + the server's protocol version. Falls back to ``ping`` for legacy servers + that don't implement ``connect``.""" + if not self._client: + raise RuntimeError("Client not connected") max_version = get_sdk_protocol_version() - ping_result = await self.ping() - server_version = ping_result.protocolVersion + + server_version: int | None + try: + connect_result = await _InternalServerRpc(self._client).connect( + ConnectRequest(token=self._effective_connection_token) + ) + server_version = connect_result.protocol_version + except JsonRpcError as err: + if err.code == -32601: + # Legacy server without `connect`; fall back to `ping`. A token, if any, + # is silently dropped — the legacy server can't enforce one. + ping_result = await self.ping() + server_version = ping_result.protocolVersion + else: + raise if server_version is None: raise RuntimeError( @@ -2333,6 +2376,9 @@ async def _start_cli_server(self) -> None: if cfg.github_token: env["COPILOT_SDK_AUTH_TOKEN"] = cfg.github_token + if self._effective_connection_token: + env["COPILOT_CONNECTION_TOKEN"] = self._effective_connection_token + # Set OpenTelemetry environment variables if telemetry config is provided telemetry = cfg.telemetry if telemetry is not None: diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index 1b39aa4d3..fc3eb7bdf 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -245,6 +245,51 @@ def to_dict(self) -> dict: result["success"] = from_bool(self.success) return result +# Internal: this type is an internal SDK API and is not part of the public surface. +@dataclass +class ConnectRequest: + token: str | None = None + """Connection token; required when the server was started with COPILOT_CONNECTION_TOKEN""" + + @staticmethod + def from_dict(obj: Any) -> 'ConnectRequest': + assert isinstance(obj, dict) + token = from_union([from_str, from_none], obj.get("token")) + return ConnectRequest(token) + + def to_dict(self) -> dict: + result: dict = {} + if self.token is not None: + result["token"] = from_union([from_str, from_none], self.token) + return result + +# Internal: this type is an internal SDK API and is not part of the public surface. +@dataclass +class ConnectResult: + ok: bool + """Always true on success""" + + protocol_version: int + """Server protocol version number""" + + version: str + """Server package version""" + + @staticmethod + def from_dict(obj: Any) -> 'ConnectResult': + assert isinstance(obj, dict) + ok = from_bool(obj.get("ok")) + protocol_version = from_int(obj.get("protocolVersion")) + version = from_str(obj.get("version")) + return ConnectResult(ok, protocol_version, version) + + def to_dict(self) -> dict: + result: dict = {} + result["ok"] = from_bool(self.ok) + result["protocolVersion"] = from_int(self.protocol_version) + result["version"] = from_str(self.version) + return result + @dataclass class CurrentModel: model_id: str | None = None @@ -5562,6 +5607,8 @@ class RPC: auth_info_type: AuthInfoType commands_handle_pending_command_request: CommandsHandlePendingCommandRequest commands_handle_pending_command_result: CommandsHandlePendingCommandResult + connect_request: ConnectRequest + connect_result: ConnectResult current_model: CurrentModel discovered_mcp_server: DiscoveredMCPServer discovered_mcp_server_source: MCPServerSource @@ -5788,6 +5835,8 @@ def from_dict(obj: Any) -> 'RPC': auth_info_type = AuthInfoType(obj.get("AuthInfoType")) commands_handle_pending_command_request = CommandsHandlePendingCommandRequest.from_dict(obj.get("CommandsHandlePendingCommandRequest")) commands_handle_pending_command_result = CommandsHandlePendingCommandResult.from_dict(obj.get("CommandsHandlePendingCommandResult")) + connect_request = ConnectRequest.from_dict(obj.get("ConnectRequest")) + connect_result = ConnectResult.from_dict(obj.get("ConnectResult")) current_model = CurrentModel.from_dict(obj.get("CurrentModel")) discovered_mcp_server = DiscoveredMCPServer.from_dict(obj.get("DiscoveredMcpServer")) discovered_mcp_server_source = MCPServerSource(obj.get("DiscoveredMcpServerSource")) @@ -5998,7 +6047,7 @@ def from_dict(obj: Any) -> 'RPC': workspaces_list_files_result = WorkspacesListFilesResult.from_dict(obj.get("WorkspacesListFilesResult")) workspaces_read_file_request = WorkspacesReadFileRequest.from_dict(obj.get("WorkspacesReadFileRequest")) workspaces_read_file_result = WorkspacesReadFileResult.from_dict(obj.get("WorkspacesReadFileResult")) - return RPC(account_get_quota_request, account_get_quota_result, account_quota_snapshot, agent_get_current_result, agent_info, agent_list, agent_reload_result, agent_select_request, agent_select_result, auth_info_type, commands_handle_pending_command_request, commands_handle_pending_command_result, current_model, discovered_mcp_server, discovered_mcp_server_source, discovered_mcp_server_type, embedded_blob_resource_contents, embedded_text_resource_contents, extension, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, filter_mapping_string, filter_mapping_value, fleet_start_request, fleet_start_result, handle_pending_tool_call_request, handle_pending_tool_call_result, history_compact_context_window, history_compact_result, history_truncate_request, history_truncate_result, instructions_get_sources_result, instructions_sources, instructions_sources_location, instructions_sources_type, log_request, log_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_oauth_login_request, mcp_oauth_login_result, mcp_server, mcp_server_config, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_local, mcp_server_config_local_type, mcp_server_list, mcp_server_source, mcp_server_status, model, model_billing, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_policy, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, name_get_result, name_set_request, permission_decision, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_request_result, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_approve_all_request, permissions_set_approve_all_result, ping_request, ping_result, plan_read_result, plan_update_request, plugin, plugin_list, server_skill, server_skill_list, session_auth_status, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_log_level, session_mode, sessions_fork_request, sessions_fork_result, shell_exec_request, shell_exec_result, shell_kill_request, shell_kill_result, shell_kill_signal, skill, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, task_agent_info, task_agent_info_execution_mode, task_agent_info_status, task_info, task_list, tasks_cancel_request, tasks_cancel_result, task_shell_info, task_shell_info_attachment_mode, task_shell_info_execution_mode, task_shell_info_status, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_remove_request, tasks_remove_result, tasks_start_agent_request, tasks_start_agent_result, tool, tool_list, tools_list_request, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_handle_pending_elicitation_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, workspaces_create_file_request, workspaces_get_workspace_result, workspaces_list_files_result, workspaces_read_file_request, workspaces_read_file_result) + return RPC(account_get_quota_request, account_get_quota_result, account_quota_snapshot, agent_get_current_result, agent_info, agent_list, agent_reload_result, agent_select_request, agent_select_result, auth_info_type, commands_handle_pending_command_request, commands_handle_pending_command_result, connect_request, connect_result, current_model, discovered_mcp_server, discovered_mcp_server_source, discovered_mcp_server_type, embedded_blob_resource_contents, embedded_text_resource_contents, extension, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, filter_mapping_string, filter_mapping_value, fleet_start_request, fleet_start_result, handle_pending_tool_call_request, handle_pending_tool_call_result, history_compact_context_window, history_compact_result, history_truncate_request, history_truncate_result, instructions_get_sources_result, instructions_sources, instructions_sources_location, instructions_sources_type, log_request, log_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_oauth_login_request, mcp_oauth_login_result, mcp_server, mcp_server_config, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_local, mcp_server_config_local_type, mcp_server_list, mcp_server_source, mcp_server_status, model, model_billing, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_policy, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, name_get_result, name_set_request, permission_decision, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_request_result, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_approve_all_request, permissions_set_approve_all_result, ping_request, ping_result, plan_read_result, plan_update_request, plugin, plugin_list, server_skill, server_skill_list, session_auth_status, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_log_level, session_mode, sessions_fork_request, sessions_fork_result, shell_exec_request, shell_exec_result, shell_kill_request, shell_kill_result, shell_kill_signal, skill, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, task_agent_info, task_agent_info_execution_mode, task_agent_info_status, task_info, task_list, tasks_cancel_request, tasks_cancel_result, task_shell_info, task_shell_info_attachment_mode, task_shell_info_execution_mode, task_shell_info_status, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_remove_request, tasks_remove_result, tasks_start_agent_request, tasks_start_agent_result, tool, tool_list, tools_list_request, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_handle_pending_elicitation_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, workspaces_create_file_request, workspaces_get_workspace_result, workspaces_list_files_result, workspaces_read_file_request, workspaces_read_file_result) def to_dict(self) -> dict: result: dict = {} @@ -6014,6 +6063,8 @@ def to_dict(self) -> dict: result["AuthInfoType"] = to_enum(AuthInfoType, self.auth_info_type) result["CommandsHandlePendingCommandRequest"] = to_class(CommandsHandlePendingCommandRequest, self.commands_handle_pending_command_request) result["CommandsHandlePendingCommandResult"] = to_class(CommandsHandlePendingCommandResult, self.commands_handle_pending_command_result) + result["ConnectRequest"] = to_class(ConnectRequest, self.connect_request) + result["ConnectResult"] = to_class(ConnectResult, self.connect_result) result["CurrentModel"] = to_class(CurrentModel, self.current_model) result["DiscoveredMcpServer"] = to_class(DiscoveredMCPServer, self.discovered_mcp_server) result["DiscoveredMcpServerSource"] = to_enum(MCPServerSource, self.discovered_mcp_server_source) @@ -6381,6 +6432,17 @@ async def ping(self, params: PingRequest, *, timeout: float | None = None) -> Pi return PingResult.from_dict(await self._client.request("ping", params_dict, **_timeout_kwargs(timeout))) +class _InternalServerRpc: + """Internal SDK server-scoped RPC methods (handshake helpers etc.). Not part of the public API.""" + def __init__(self, client: "JsonRpcClient"): + self._client = client + + async def connect(self, params: ConnectRequest, *, timeout: float | None = None) -> ConnectResult: + """:meta private: Internal SDK API; not part of the public surface.""" + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + return ConnectResult.from_dict(await self._client.request("connect", params_dict, **_timeout_kwargs(timeout))) + + class AuthApi: def __init__(self, client: "JsonRpcClient", session_id: str): self._client = client diff --git a/python/e2e/test_client_options_e2e.py b/python/e2e/test_client_options_e2e.py index 0a002c98f..b5f3a6011 100644 --- a/python/e2e/test_client_options_e2e.py +++ b/python/e2e/test_client_options_e2e.py @@ -110,6 +110,10 @@ def _get_available_port() -> int: } requests.push({ method: message.method, params: message.params }); saveCapture(); + if (message.method === "connect") { + writeResponse(message.id, { ok: true, protocolVersion: 3, version: "fake" }); + return; + } if (message.method === "ping") { writeResponse(message.id, { message: "pong", protocolVersion: 3, timestamp: Date.now() }); return; diff --git a/python/e2e/test_commands_e2e.py b/python/e2e/test_commands_e2e.py index 39c6463f7..a1c44b7b3 100644 --- a/python/e2e/test_commands_e2e.py +++ b/python/e2e/test_commands_e2e.py @@ -62,6 +62,7 @@ async def setup(self): env=self._get_env(), use_stdio=False, github_token=github_token, + tcp_connection_token="py-tcp-shared-test-token", ) ) @@ -74,7 +75,11 @@ async def setup(self): actual_port = self._client1.actual_port assert actual_port is not None - self._client2 = CopilotClient(ExternalServerConfig(url=f"localhost:{actual_port}")) + self._client2 = CopilotClient( + ExternalServerConfig( + url=f"localhost:{actual_port}", tcp_connection_token="py-tcp-shared-test-token" + ) + ) async def teardown(self, test_failed: bool = False): for c in (self._client2, self._client1): diff --git a/python/e2e/test_connection_token.py b/python/e2e/test_connection_token.py new file mode 100644 index 000000000..814af5965 --- /dev/null +++ b/python/e2e/test_connection_token.py @@ -0,0 +1,168 @@ +"""E2E Connection Token Tests + +Tests for the optional TCP ``connect`` token handshake. Mirrors the Node SDK's +``connection_token.test.ts``. +""" + +import os +import shutil +import tempfile + +import pytest +import pytest_asyncio + +from copilot import CopilotClient +from copilot.client import ExternalServerConfig, SubprocessConfig +from copilot.session import PermissionHandler + +from .testharness.proxy import CapiProxy + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +class ConnectionTokenContext: + """Spawns a TCP CLI server with an explicit connection token.""" + + def __init__(self, token: str | None): + self.token = token + self.cli_path: str = "" + self.home_dir: str = "" + self.work_dir: str = "" + self.proxy_url: str = "" + self._proxy: CapiProxy | None = None + self._client: CopilotClient | None = None + + async def setup(self): + from .testharness.context import get_cli_path_for_tests + + self.cli_path = get_cli_path_for_tests() + self.home_dir = tempfile.mkdtemp(prefix="copilot-token-config-") + self.work_dir = tempfile.mkdtemp(prefix="copilot-token-work-") + + self._proxy = CapiProxy() + self.proxy_url = await self._proxy.start() + + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + + self._client = CopilotClient( + SubprocessConfig( + cli_path=self.cli_path, + cwd=self.work_dir, + env=self.get_env(), + use_stdio=False, + tcp_connection_token=self.token, + github_token=github_token, + ) + ) + + # Trigger the spawn + connect handshake so the server is listening. + await self._client.start() + + async def teardown(self): + if self._client: + try: + await self._client.stop() + except Exception: + # Best-effort cleanup; ignore stop errors during teardown. + pass + self._client = None + if self._proxy: + await self._proxy.stop(skip_writing_cache=True) + self._proxy = None + if self.home_dir and os.path.exists(self.home_dir): + shutil.rmtree(self.home_dir, ignore_errors=True) + if self.work_dir and os.path.exists(self.work_dir): + shutil.rmtree(self.work_dir, ignore_errors=True) + + def get_env(self) -> dict: + env = os.environ.copy() + env.update( + { + "COPILOT_API_URL": self.proxy_url, + "COPILOT_HOME": self.home_dir, + "XDG_CONFIG_HOME": self.home_dir, + "XDG_STATE_HOME": self.home_dir, + } + ) + return env + + @property + def client(self) -> CopilotClient: + if not self._client: + raise RuntimeError("Context not set up") + return self._client + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def explicit_token_ctx(): + ctx = ConnectionTokenContext(token="right-token") + await ctx.setup() + yield ctx + await ctx.teardown() + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def auto_token_ctx(): + ctx = ConnectionTokenContext(token=None) + await ctx.setup() + yield ctx + await ctx.teardown() + + +class TestConnectionToken: + async def test_explicit_token_round_trips(self, explicit_token_ctx: ConnectionTokenContext): + """Client started with an explicit token can ping successfully.""" + # Sanity-check that the token was forwarded to the spawned CLI and the + # `connect` handshake succeeded; a real ping must round-trip. + response = await explicit_token_ctx.client.ping("hi") + assert response.message == "pong: hi" + + # Bonus: a fresh session round-trip also exercises the live connection. + session = await explicit_token_ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + await session.disconnect() + + async def test_auto_generated_token_round_trips(self, auto_token_ctx: ConnectionTokenContext): + """When the SDK spawns its own CLI in TCP mode without an explicit token, + the auto-generated UUID is forwarded and the `connect` handshake succeeds.""" + response = await auto_token_ctx.client.ping("hi") + assert response.message == "pong: hi" + + async def test_wrong_token_is_rejected(self, explicit_token_ctx: ConnectionTokenContext): + """A sibling client connecting with the wrong token is rejected.""" + port = explicit_token_ctx.client.actual_port + assert port is not None + + wrong = CopilotClient( + ExternalServerConfig(url=f"localhost:{port}", tcp_connection_token="wrong") + ) + try: + with pytest.raises(Exception, match="AUTHENTICATION_FAILED"): + await wrong.start() + finally: + try: + await wrong.force_stop() + except Exception: + # Best-effort cleanup; client startup is expected to fail above, + # so force_stop may raise if no process/session was established. + pass + + async def test_missing_token_is_rejected(self, explicit_token_ctx: ConnectionTokenContext): + """A sibling client with no token is rejected when the server requires one.""" + port = explicit_token_ctx.client.actual_port + assert port is not None + + no_token = CopilotClient(ExternalServerConfig(url=f"localhost:{port}")) + try: + with pytest.raises(Exception, match="AUTHENTICATION_FAILED"): + await no_token.start() + finally: + try: + await no_token.force_stop() + except Exception: + # Best-effort cleanup; client startup is expected to fail above, + # so force_stop may raise if no process/session was established. + pass diff --git a/python/e2e/test_multi_client_e2e.py b/python/e2e/test_multi_client_e2e.py index f57de28d4..b9ecbc5a2 100644 --- a/python/e2e/test_multi_client_e2e.py +++ b/python/e2e/test_multi_client_e2e.py @@ -59,6 +59,7 @@ async def setup(self): env=self.get_env(), use_stdio=False, github_token=github_token, + tcp_connection_token="py-tcp-shared-test-token", ) ) @@ -72,7 +73,11 @@ async def setup(self): actual_port = self._client1.actual_port assert actual_port is not None, "Client 1 should have an actual port after connecting" - self._client2 = CopilotClient(ExternalServerConfig(url=f"localhost:{actual_port}")) + self._client2 = CopilotClient( + ExternalServerConfig( + url=f"localhost:{actual_port}", tcp_connection_token="py-tcp-shared-test-token" + ) + ) async def teardown(self, test_failed: bool = False): if self._client2: @@ -422,7 +427,11 @@ def ephemeral_tool(params: InputParams, invocation: ToolInvocation) -> str: # Recreate client2 for future tests (but don't rejoin the session) actual_port = mctx.client1.actual_port - mctx._client2 = CopilotClient(ExternalServerConfig(url=f"localhost:{actual_port}")) + mctx._client2 = CopilotClient( + ExternalServerConfig( + url=f"localhost:{actual_port}", tcp_connection_token="py-tcp-shared-test-token" + ) + ) # Now only stable_tool should be available await session1.send( diff --git a/python/e2e/test_pending_work_resume_e2e.py b/python/e2e/test_pending_work_resume_e2e.py index 28d45bbec..d1c3b812f 100644 --- a/python/e2e/test_pending_work_resume_e2e.py +++ b/python/e2e/test_pending_work_resume_e2e.py @@ -40,6 +40,7 @@ def _make_subprocess_client(ctx: E2ETestContext, *, use_stdio: bool = True) -> C env=ctx.get_env(), github_token=github_token, use_stdio=use_stdio, + tcp_connection_token="py-tcp-shared-test-token", ) ) @@ -147,7 +148,9 @@ async def hold_permission(request, _invocation): def original_tool_handler(args): return f"ORIGINAL_SHOULD_NOT_RUN_{args.get('value', '')}" - suspended_client = CopilotClient(ExternalServerConfig(url=cli_url)) + suspended_client = CopilotClient( + ExternalServerConfig(url=cli_url, tcp_connection_token="py-tcp-shared-test-token") + ) session1 = await suspended_client.create_session( on_permission_request=hold_permission, tools=[_make_pending_tool("resume_permission_tool", original_tool_handler)], @@ -171,7 +174,11 @@ def resumed_tool_handler(args): resumed_tool_invoked = True return f"PERMISSION_RESUMED_{args['value'].upper()}" - resumed_client = CopilotClient(ExternalServerConfig(url=cli_url)) + resumed_client = CopilotClient( + ExternalServerConfig( + url=cli_url, tcp_connection_token="py-tcp-shared-test-token" + ) + ) try: session2 = await resumed_client.resume_session( session_id, @@ -226,7 +233,9 @@ async def blocking_external_tool(args): tool_started.set_result(value) return await release_original - suspended_client = CopilotClient(ExternalServerConfig(url=cli_url)) + suspended_client = CopilotClient( + ExternalServerConfig(url=cli_url, tcp_connection_token="py-tcp-shared-test-token") + ) session1 = await suspended_client.create_session( on_permission_request=PermissionHandler.approve_all, tools=[_make_pending_tool("resume_external_tool", blocking_external_tool)], @@ -245,7 +254,11 @@ async def blocking_external_tool(args): await suspended_client.force_stop() - resumed_client = CopilotClient(ExternalServerConfig(url=cli_url)) + resumed_client = CopilotClient( + ExternalServerConfig( + url=cli_url, tcp_connection_token="py-tcp-shared-test-token" + ) + ) try: session2 = await resumed_client.resume_session( session_id, @@ -298,7 +311,9 @@ async def tool_b(args): tool_b_started.set_result(args["value"]) return await release_b - suspended_client = CopilotClient(ExternalServerConfig(url=cli_url)) + suspended_client = CopilotClient( + ExternalServerConfig(url=cli_url, tcp_connection_token="py-tcp-shared-test-token") + ) session1 = await suspended_client.create_session( on_permission_request=PermissionHandler.approve_all, tools=[ @@ -327,7 +342,11 @@ async def tool_b(args): await suspended_client.force_stop() - resumed_client = CopilotClient(ExternalServerConfig(url=cli_url)) + resumed_client = CopilotClient( + ExternalServerConfig( + url=cli_url, tcp_connection_token="py-tcp-shared-test-token" + ) + ) try: session2 = await resumed_client.resume_session( session_id, @@ -376,7 +395,9 @@ async def test_should_resume_successfully_when_no_pending_work_exists( try: cli_url = f"localhost:{server.actual_port}" - first_client = CopilotClient(ExternalServerConfig(url=cli_url)) + first_client = CopilotClient( + ExternalServerConfig(url=cli_url, tcp_connection_token="py-tcp-shared-test-token") + ) try: first_session = await first_client.create_session( on_permission_request=PermissionHandler.approve_all, @@ -390,7 +411,9 @@ async def test_should_resume_successfully_when_no_pending_work_exists( finally: await _safe_force_stop(first_client) - resumed_client = CopilotClient(ExternalServerConfig(url=cli_url)) + resumed_client = CopilotClient( + ExternalServerConfig(url=cli_url, tcp_connection_token="py-tcp-shared-test-token") + ) try: resumed_session = await resumed_client.resume_session( session_id, diff --git a/python/e2e/test_suspend_e2e.py b/python/e2e/test_suspend_e2e.py index 37587baff..e87659d93 100644 --- a/python/e2e/test_suspend_e2e.py +++ b/python/e2e/test_suspend_e2e.py @@ -37,6 +37,7 @@ def _make_subprocess_client(ctx: E2ETestContext, *, use_stdio: bool = True) -> C env=ctx.get_env(), github_token=github_token, use_stdio=use_stdio, + tcp_connection_token="py-tcp-shared-test-token", ) ) @@ -101,7 +102,9 @@ async def test_should_allow_resume_and_continue_conversation_after_suspend( cli_url = f"localhost:{server.actual_port}" session_id: str - first_client = CopilotClient(ExternalServerConfig(url=cli_url)) + first_client = CopilotClient( + ExternalServerConfig(url=cli_url, tcp_connection_token="py-tcp-shared-test-token") + ) try: session1 = await first_client.create_session( on_permission_request=PermissionHandler.approve_all @@ -116,7 +119,9 @@ async def test_should_allow_resume_and_continue_conversation_after_suspend( finally: await _safe_force_stop(first_client) - resumed_client = CopilotClient(ExternalServerConfig(url=cli_url)) + resumed_client = CopilotClient( + ExternalServerConfig(url=cli_url, tcp_connection_token="py-tcp-shared-test-token") + ) try: session2 = await resumed_client.resume_session( session_id, diff --git a/python/e2e/test_ui_elicitation_multi_client_e2e.py b/python/e2e/test_ui_elicitation_multi_client_e2e.py index 4daf3df7d..8da62f3de 100644 --- a/python/e2e/test_ui_elicitation_multi_client_e2e.py +++ b/python/e2e/test_ui_elicitation_multi_client_e2e.py @@ -69,6 +69,7 @@ async def setup(self): env=self._get_env(), use_stdio=False, github_token=github_token, + tcp_connection_token="py-tcp-shared-test-token", ) ) @@ -81,7 +82,12 @@ async def setup(self): self._actual_port = self._client1.actual_port assert self._actual_port is not None - self._client2 = CopilotClient(ExternalServerConfig(url=f"localhost:{self._actual_port}")) + self._client2 = CopilotClient( + ExternalServerConfig( + url=f"localhost:{self._actual_port}", + tcp_connection_token="py-tcp-shared-test-token", + ) + ) async def teardown(self, test_failed: bool = False): for c in (self._client2, self._client1): @@ -132,7 +138,12 @@ def _get_env(self) -> dict: def make_external_client(self) -> CopilotClient: """Create a new external client connected to the same CLI server.""" assert self._actual_port is not None - return CopilotClient(ExternalServerConfig(url=f"localhost:{self._actual_port}")) + return CopilotClient( + ExternalServerConfig( + url=f"localhost:{self._actual_port}", + tcp_connection_token="py-tcp-shared-test-token", + ) + ) @property def client1(self) -> CopilotClient: diff --git a/scripts/codegen/csharp.ts b/scripts/codegen/csharp.ts index cf3247e8b..f43d08c89 100644 --- a/scripts/codegen/csharp.ts +++ b/scripts/codegen/csharp.ts @@ -984,6 +984,16 @@ function emitRpcClass( resolveObjectSchema(schema, rpcDefinitions) ?? resolveSchema(schema, rpcDefinitions) ?? schema; + // Visibility is driven by the JSON Schema definition itself (set via + // `.asInternal()` on the originating Zod schema). The runtime schema + // generator enforces that no public method references an internal type, + // so it's safe to upgrade callers' default to internal here. + if ( + (schema as Record).visibility === "internal" || + (effectiveSchema as Record).visibility === "internal" + ) { + visibility = "internal"; + } const schemaKey = stableStringify(effectiveSchema); const existingSchema = emittedRpcClassSchemas.get(className); if (existingSchema) { @@ -1170,13 +1180,15 @@ function emitServerInstanceMethod( groupDeprecated: boolean ): void { const methodName = toPascalCase(name); + const isInternal = method.visibility === "internal"; + const methodVisibility = isInternal ? "internal" : "public"; const resultSchema = getMethodResultSchema(method); let resultClassName = !isVoidSchema(resultSchema) ? resultTypeName(method) : ""; if (!isVoidSchema(resultSchema) && method.stability === "experimental") { experimentalRpcTypes.add(resultClassName); } if (isObjectSchema(resultSchema)) { - const resultClass = emitRpcClass(resultClassName, resultSchema!, "public", classes); + const resultClass = emitRpcClass(resultClassName, resultSchema!, methodVisibility, classes); if (resultClass) classes.push(resultClass); } else if (!isVoidSchema(resultSchema)) { resultClassName = emitNonObjectResultType(resultClassName, resultSchema!, classes); @@ -1228,7 +1240,7 @@ function emitServerInstanceMethod( sigParams.push("CancellationToken cancellationToken = default"); const taskType = !isVoidSchema(resultSchema) ? `Task<${resultClassName}>` : "Task"; - lines.push(`${indent}public async ${taskType} ${methodName}Async(${sigParams.join(", ")})`); + lines.push(`${indent}${methodVisibility} async ${taskType} ${methodName}Async(${sigParams.join(", ")})`); lines.push(`${indent}{`); if (requestClassName && bodyAssignments.length > 0) { lines.push(`${indent} var request = new ${requestClassName} { ${bodyAssignments.join(", ")} };`); @@ -1276,13 +1288,15 @@ function emitSessionRpcClasses(node: Record, classes: string[]) function emitSessionMethod(key: string, method: RpcMethod, lines: string[], classes: string[], indent: string, groupExperimental: boolean, groupDeprecated: boolean): void { const methodName = toPascalCase(key); + const isInternal = method.visibility === "internal"; + const methodVisibility = isInternal ? "internal" : "public"; const resultSchema = getMethodResultSchema(method); let resultClassName = !isVoidSchema(resultSchema) ? resultTypeName(method) : ""; if (!isVoidSchema(resultSchema) && method.stability === "experimental") { experimentalRpcTypes.add(resultClassName); } if (isObjectSchema(resultSchema)) { - const resultClass = emitRpcClass(resultClassName, resultSchema!, "public", classes); + const resultClass = emitRpcClass(resultClassName, resultSchema!, methodVisibility, classes); if (resultClass) classes.push(resultClass); } else if (!isVoidSchema(resultSchema)) { resultClassName = emitNonObjectResultType(resultClassName, resultSchema!, classes); @@ -1328,7 +1342,7 @@ function emitSessionMethod(key: string, method: RpcMethod, lines: string[], clas sigParams.push("CancellationToken cancellationToken = default"); const taskType = !isVoidSchema(resultSchema) ? `Task<${resultClassName}>` : "Task"; - lines.push(`${indent}public async ${taskType} ${methodName}Async(${sigParams.join(", ")})`); + lines.push(`${indent}${methodVisibility} async ${taskType} ${methodName}Async(${sigParams.join(", ")})`); lines.push(`${indent}{`, `${indent} var request = new ${requestClassName} { ${bodyAssignments.join(", ")} };`); if (!isVoidSchema(resultSchema)) { lines.push(`${indent} return await CopilotClient.InvokeRpcAsync<${resultClassName}>(_rpc, "${method.rpcMethod}", [request], cancellationToken);`, `${indent}}`); diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index d488ab0ed..d75c568df 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -13,6 +13,7 @@ import { FetchingJSONSchemaStore, InputData, JSONSchemaInput, quicktype } from " import { promisify } from "util"; import { cloneSchemaForCodegen, + filterNodeByVisibility, fixNullableRequiredRefsInApiSchema, getApiSchemaPath, getRpcSchemaTypeName, @@ -25,6 +26,7 @@ import { getNullableInner, isRpcMethod, postProcessSchema, + stripBooleanLiterals, writeGeneratedFile, collectDefinitionCollections, resolveObjectSchema, @@ -1216,7 +1218,7 @@ async function generateRpc(schemaPath?: string): Promise { const singleSchema: JSONSchema7 = { $schema: "http://json-schema.org/draft-07/schema#", type: "object", - definitions: allDefinitions as Record, + definitions: stripBooleanLiterals(allDefinitions) as Record, properties: Object.fromEntries( Object.keys(allDefinitions).map((name) => [name, { $ref: `#/definitions/${name}` }]) ), @@ -1292,6 +1294,21 @@ async function generateRpc(schemaPath?: string): Promise { `// Deprecated: ${typeName} is deprecated and will be removed in a future version.\n$1` ); } + + // Annotate internal data types (driven by the JSON Schema definition's + // `visibility: "internal"` flag, set via `.asInternal()` on the Zod source). + const internalTypeNames = new Set(); + for (const [name, def] of Object.entries(allDefinitions)) { + if (def && typeof def === "object" && (def as Record).visibility === "internal") { + internalTypeNames.add(name); + } + } + for (const typeName of internalTypeNames) { + qtCode = qtCode.replace( + new RegExp(`^(type ${typeName} struct)`, "m"), + `// Internal: ${typeName} is an internal SDK API and is not part of the public surface.\n$1` + ); + } // Remove trailing blank lines from quicktype output before appending qtCode = qtCode.replace(/\n+$/, ""); // Replace interface{} with any (quicktype emits the pre-1.18 form) @@ -1334,12 +1351,18 @@ async function generateRpc(schemaPath?: string): Promise { // Emit ServerRpc if (schema.server) { - emitRpcWrapper(lines, schema.server, false, resolveType, fieldNames); + const publicNode = filterNodeByVisibility(schema.server, "public"); + if (publicNode) emitRpcWrapper(lines, publicNode, false, resolveType, fieldNames, ""); + const internalNode = filterNodeByVisibility(schema.server, "internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, false, resolveType, fieldNames, "Internal"); } // Emit SessionRpc if (schema.session) { - emitRpcWrapper(lines, schema.session, true, resolveType, fieldNames); + const publicNode = filterNodeByVisibility(schema.session, "public"); + if (publicNode) emitRpcWrapper(lines, publicNode, true, resolveType, fieldNames, ""); + const internalNode = filterNodeByVisibility(schema.session, "internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, true, resolveType, fieldNames, "Internal"); } if (schema.clientSession) { @@ -1395,13 +1418,17 @@ function emitApiGroup( } } -function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string, fieldNames: Map>): void { +function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string, fieldNames: Map>, classPrefix: string = ""): void { const groups = Object.entries(node).filter(([, v]) => typeof v === "object" && v !== null && !isRpcMethod(v)); const topLevelMethods = Object.entries(node).filter(([, v]) => isRpcMethod(v)); - const wrapperName = isSession ? "SessionRpc" : "ServerRpc"; + const wrapperName = classPrefix + (isSession ? "SessionRpc" : "ServerRpc"); const apiSuffix = "Api"; - const serviceName = isSession ? "sessionApi" : "serverApi"; + // Lowercase the prefix so the unexported service struct stays unexported in Go. + const prefixLower = classPrefix ? classPrefix.charAt(0).toLowerCase() + classPrefix.slice(1) : ""; + const serviceName = prefixLower + ? prefixLower + (isSession ? "SessionApi" : "ServerApi") + : (isSession ? "sessionApi" : "serverApi"); // Emit the common service struct (unexported, shared by all API groups via type cast) lines.push(`type ${serviceName} struct {`); @@ -1412,7 +1439,7 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio // Emit API types for groups for (const [groupName, groupNode] of groups) { - const prefix = isSession ? "" : "Server"; + const prefix = classPrefix + (isSession ? "" : "Server"); const apiName = prefix + toPascalCase(groupName) + apiSuffix; const groupExperimental = isNodeFullyExperimental(groupNode as Record); const groupDeprecated = isNodeFullyDeprecated(groupNode as Record); @@ -1426,12 +1453,14 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio const pad = (name: string) => name.padEnd(maxFieldLen); // Emit wrapper struct - lines.push(`// ${wrapperName} provides typed ${isSession ? "session" : "server"}-scoped RPC methods.`); + lines.push(classPrefix === "Internal" + ? `// ${wrapperName} provides internal SDK ${isSession ? "session" : "server"}-scoped RPC methods (handshake helpers etc.). Not part of the public API.` + : `// ${wrapperName} provides typed ${isSession ? "session" : "server"}-scoped RPC methods.`); lines.push(`type ${wrapperName} struct {`); lines.push(`\t${pad("common")} ${serviceName} // Reuse a single struct instead of allocating one for each service on the heap.`); lines.push(``); for (const [groupName] of groups) { - const prefix = isSession ? "" : "Server"; + const prefix = classPrefix + (isSession ? "" : "Server"); lines.push(`\t${pad(toPascalCase(groupName))} *${prefix}${toPascalCase(groupName)}${apiSuffix}`); } lines.push(`}`); @@ -1453,7 +1482,7 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio lines.push(`\tr.common = ${serviceName}{client: client}`); } for (const [groupName] of groups) { - const prefix = isSession ? "" : "Server"; + const prefix = classPrefix + (isSession ? "" : "Server"); lines.push(`\tr.${toPascalCase(groupName)} = (*${prefix}${toPascalCase(groupName)}${apiSuffix})(&r.common)`); } lines.push(`\treturn r`); @@ -1486,6 +1515,9 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc if (method.stability === "experimental" && !groupExperimental) { lines.push(`// Experimental: ${methodName} is an experimental API and may change or be removed in future versions.`); } + if (method.visibility === "internal") { + lines.push(`// Internal: ${methodName} is part of the SDK's internal handshake/plumbing; external callers should not use it.`); + } const sig = hasParams ? `func (a *${receiver}) ${methodName}(ctx context.Context, params *${paramsType}) (*${resultType}, error)` : `func (a *${receiver}) ${methodName}(ctx context.Context) (*${resultType}, error)`; diff --git a/scripts/codegen/python.ts b/scripts/codegen/python.ts index a4cbcd6d4..f9327f9d8 100644 --- a/scripts/codegen/python.ts +++ b/scripts/codegen/python.ts @@ -12,6 +12,7 @@ import type { JSONSchema7 } from "json-schema"; import { fileURLToPath } from "url"; import { cloneSchemaForCodegen, + filterNodeByVisibility, fixNullableRequiredRefsInApiSchema, getApiSchemaPath, getRpcSchemaTypeName, @@ -24,6 +25,7 @@ import { isNodeFullyDeprecated, isSchemaDeprecated, postProcessSchema, + stripBooleanLiterals, writeGeneratedFile, collectDefinitionCollections, hasSchemaPayload, @@ -1677,7 +1679,7 @@ async function generateRpc(schemaPath?: string): Promise { const singleSchema: Record = { $schema: "http://json-schema.org/draft-07/schema#", type: "object", - definitions: allDefinitions, + definitions: stripBooleanLiterals(allDefinitions), properties: Object.fromEntries( Object.keys(allDefinitions).map((name) => [name, { $ref: `#/definitions/${name}` }]) ), @@ -1774,6 +1776,21 @@ async function generateRpc(schemaPath?: string): Promise { ); } + // Annotate internal data types (driven by the JSON Schema definition's + // `visibility: "internal"` flag, set via `.asInternal()` on the Zod source). + const internalTypeNames = new Set(); + for (const [name, def] of Object.entries(allDefinitions)) { + if (def && typeof def === "object" && (def as Record).visibility === "internal") { + internalTypeNames.add(name); + } + } + for (const typeName of internalTypeNames) { + typesCode = typesCode.replace( + new RegExp(`^(@dataclass\\n)?class ${typeName}[:(]`, "m"), + (match) => `# Internal: this type is an internal SDK API and is not part of the public surface.\n${match}` + ); + } + // Extract actual class names generated by quicktype (may differ from toPascalCase, // e.g. quicktype produces "SessionMCPList" not "SessionMcpList") const actualTypeNames = new Map(); @@ -1841,10 +1858,16 @@ def _patch_model_capabilities(data: dict) -> dict: // Emit RPC wrapper classes if (schema.server) { - emitRpcWrapper(lines, schema.server, false, resolveType); + const publicNode = filterNodeByVisibility(schema.server, "public"); + if (publicNode) emitRpcWrapper(lines, publicNode, false, resolveType, ""); + const internalNode = filterNodeByVisibility(schema.server, "internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, false, resolveType, "_Internal"); } if (schema.session) { - emitRpcWrapper(lines, schema.session, true, resolveType); + const publicNode = filterNodeByVisibility(schema.session, "public"); + if (publicNode) emitRpcWrapper(lines, publicNode, true, resolveType, ""); + const internalNode = filterNodeByVisibility(schema.session, "internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, true, resolveType, "_Internal"); } if (schema.clientSession) { emitClientSessionApiRegistration(lines, schema.clientSession, resolveType); @@ -1875,7 +1898,8 @@ function emitPyApiGroup( isSession: boolean, resolveType: (name: string) => string, groupExperimental: boolean, - groupDeprecated: boolean = false + groupDeprecated: boolean = false, + classPrefix: string = "" ): void { const subGroups = Object.entries(node).filter(([, v]) => typeof v === "object" && v !== null && !isRpcMethod(v)); @@ -1884,7 +1908,7 @@ function emitPyApiGroup( const subApiName = apiName.replace(/Api$/, "") + toPascalCase(subGroupName) + "Api"; const subGroupExperimental = isNodeFullyExperimental(subGroupNode as Record); const subGroupDeprecated = isNodeFullyDeprecated(subGroupNode as Record); - emitPyApiGroup(lines, subApiName, subGroupNode as Record, isSession, resolveType, subGroupExperimental, subGroupDeprecated); + emitPyApiGroup(lines, subApiName, subGroupNode as Record, isSession, resolveType, subGroupExperimental, subGroupDeprecated, classPrefix); } // Emit this class @@ -1920,38 +1944,42 @@ function emitPyApiGroup( lines.push(``); } -function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string): void { +function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string, classPrefix: string = ""): void { const groups = Object.entries(node).filter(([, v]) => typeof v === "object" && v !== null && !isRpcMethod(v)); const topLevelMethods = Object.entries(node).filter(([, v]) => isRpcMethod(v)); - const wrapperName = isSession ? "SessionRpc" : "ServerRpc"; + const wrapperName = classPrefix + (isSession ? "SessionRpc" : "ServerRpc"); // Emit API classes for groups (recursively handles sub-groups) for (const [groupName, groupNode] of groups) { - const prefix = isSession ? "" : "Server"; + const prefix = classPrefix + (isSession ? "" : "Server"); const apiName = prefix + toPascalCase(groupName) + "Api"; const groupExperimental = isNodeFullyExperimental(groupNode as Record); const groupDeprecated = isNodeFullyDeprecated(groupNode as Record); - emitPyApiGroup(lines, apiName, groupNode as Record, isSession, resolveType, groupExperimental, groupDeprecated); + emitPyApiGroup(lines, apiName, groupNode as Record, isSession, resolveType, groupExperimental, groupDeprecated, classPrefix); } // Emit wrapper class if (isSession) { lines.push(`class ${wrapperName}:`); - lines.push(` """Typed session-scoped RPC methods."""`); + lines.push(classPrefix === "_Internal" + ? ` """Internal SDK session-scoped RPC methods. Not part of the public API."""` + : ` """Typed session-scoped RPC methods."""`); lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`); lines.push(` self._client = client`); lines.push(` self._session_id = session_id`); for (const [groupName] of groups) { - lines.push(` self.${toSnakeCase(groupName)} = ${toPascalCase(groupName)}Api(client, session_id)`); + lines.push(` self.${toSnakeCase(groupName)} = ${classPrefix}${toPascalCase(groupName)}Api(client, session_id)`); } } else { lines.push(`class ${wrapperName}:`); - lines.push(` """Typed server-scoped RPC methods."""`); + lines.push(classPrefix === "_Internal" + ? ` """Internal SDK server-scoped RPC methods (handshake helpers etc.). Not part of the public API."""` + : ` """Typed server-scoped RPC methods."""`); lines.push(` def __init__(self, client: "JsonRpcClient"):`); lines.push(` self._client = client`); for (const [groupName] of groups) { - lines.push(` self.${toSnakeCase(groupName)} = Server${toPascalCase(groupName)}Api(client)`); + lines.push(` self.${toSnakeCase(groupName)} = ${classPrefix}Server${toPascalCase(groupName)}Api(client)`); } } lines.push(``); @@ -2005,6 +2033,9 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession: if (method.stability === "experimental" && !groupExperimental) { lines.push(` """.. warning:: This API is experimental and may change or be removed in future versions."""`); } + if (method.visibility === "internal") { + lines.push(` """:meta private: Internal SDK API; not part of the public surface."""`); + } // Deserialize helper const innerTypeName = hasNullableResult ? resolveType(pythonResultTypeName(method, nullableInner)) : resultType; diff --git a/scripts/codegen/typescript.ts b/scripts/codegen/typescript.ts index d032c34fd..5fdb829ee 100644 --- a/scripts/codegen/typescript.ts +++ b/scripts/codegen/typescript.ts @@ -338,6 +338,16 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; const experimentalTypes = new Set(); // Track which type names come from deprecated methods for JSDoc annotations. const deprecatedTypes = new Set(); + // Types are tagged @internal directly via `visibility: "internal"` on the JSON Schema + // definition (set by `.asInternal()` on the originating Zod schema). The runtime + // schema generator enforces that no public method references an internal type, so + // there's no transitive propagation to do here. + const internalTypes = new Set(); + for (const [name, def] of Object.entries(combinedSchema.definitions ?? {})) { + if (def && typeof def === "object" && (def as Record).visibility === "internal") { + internalTypes.add(name); + } + } for (const method of [...allMethods, ...clientSessionMethods]) { const resultSchema = getMethodResultSchema(method); @@ -425,29 +435,75 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; `$1/** @deprecated */\n$2` ); } + // Add @internal JSDoc annotations for types from internal methods + for (const intType of internalTypes) { + annotatedTs = annotatedTs.replace( + new RegExp(`(^|\\n)(export (?:interface|type) ${intType}\\b)`, "m"), + `$1/** @internal */\n$2` + ); + } lines.push(annotatedTs); lines.push(""); } // Generate factory functions +function hasInternalMethods(node: Record): boolean { + for (const value of Object.values(node)) { + if (isRpcMethod(value)) { + if ((value as RpcMethod).visibility === "internal") return true; + } else if (typeof value === "object" && value !== null) { + if (hasInternalMethods(value as Record)) return true; + } + } + return false; +} + if (schema.server) { lines.push(`/** Create typed server-scoped RPC methods (no session required). */`); lines.push(`export function createServerRpc(connection: MessageConnection) {`); lines.push(` return {`); - lines.push(...emitGroup(schema.server, " ", false)); + lines.push(...emitGroup(schema.server, " ", false, false, false, "public")); lines.push(` };`); lines.push(`}`); lines.push(""); + + if (hasInternalMethods(schema.server)) { + lines.push(`/**`); + lines.push(` * Create typed server-scoped RPC methods that are part of the SDK's internal`); + lines.push(` * surface (e.g. handshake helpers). Not exported on the public client API.`); + lines.push(` * @internal`); + lines.push(` */`); + lines.push(`export function createInternalServerRpc(connection: MessageConnection) {`); + lines.push(` return {`); + lines.push(...emitGroup(schema.server, " ", false, false, false, "internal")); + lines.push(` };`); + lines.push(`}`); + lines.push(""); + } } if (schema.session) { lines.push(`/** Create typed session-scoped RPC methods. */`); lines.push(`export function createSessionRpc(connection: MessageConnection, sessionId: string) {`); lines.push(` return {`); - lines.push(...emitGroup(schema.session, " ", true)); + lines.push(...emitGroup(schema.session, " ", true, false, false, "public")); lines.push(` };`); lines.push(`}`); lines.push(""); + + if (hasInternalMethods(schema.session)) { + lines.push(`/**`); + lines.push(` * Create typed session-scoped RPC methods that are part of the SDK's internal`); + lines.push(` * surface. Not exported on the public client API.`); + lines.push(` * @internal`); + lines.push(` */`); + lines.push(`export function createInternalSessionRpc(connection: MessageConnection, sessionId: string) {`); + lines.push(` return {`); + lines.push(...emitGroup(schema.session, " ", true, false, false, "internal")); + lines.push(` };`); + lines.push(`}`); + lines.push(""); + } } // Generate client session API handler interfaces and registration function @@ -459,10 +515,20 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; console.log(` ✓ ${outPath}`); } -function emitGroup(node: Record, indent: string, isSession: boolean, parentExperimental = false, parentDeprecated = false): string[] { +function emitGroup( + node: Record, + indent: string, + isSession: boolean, + parentExperimental = false, + parentDeprecated = false, + visibilityFilter?: "public" | "internal", +): string[] { const lines: string[] = []; for (const [key, value] of Object.entries(node)) { if (isRpcMethod(value)) { + const isInternalMethod = (value as RpcMethod).visibility === "internal"; + if (visibilityFilter === "public" && isInternalMethod) continue; + if (visibilityFilter === "internal" && !isInternalMethod) continue; const { rpcMethod, params } = value; const resultType = tsResultType(value); const paramsType = paramsTypeName(value); @@ -508,6 +574,16 @@ function emitGroup(node: Record, indent: string, isSession: boo } else if (typeof value === "object" && value !== null) { const groupExperimental = isNodeFullyExperimental(value as Record); const groupDeprecated = isNodeFullyDeprecated(value as Record); + const childLines = emitGroup( + value as Record, + indent + " ", + isSession, + groupExperimental, + groupDeprecated, + visibilityFilter, + ); + // Skip the wrapper if the visibility filter dropped every method in this subtree. + if (childLines.length === 0) continue; if (groupDeprecated) { lines.push(`${indent}/** @deprecated */`); } @@ -515,7 +591,7 @@ function emitGroup(node: Record, indent: string, isSession: boo lines.push(`${indent}/** @experimental */`); } lines.push(`${indent}${key}: {`); - lines.push(...emitGroup(value as Record, indent + " ", isSession, groupExperimental, groupDeprecated)); + lines.push(...childLines); lines.push(`${indent}},`); } } diff --git a/scripts/codegen/utils.ts b/scripts/codegen/utils.ts index 16aaa0bfe..bbbeb877c 100644 --- a/scripts/codegen/utils.ts +++ b/scripts/codegen/utils.ts @@ -134,6 +134,38 @@ export function postProcessSchema(schema: JSONSchema7): JSONSchema7 { return processed; } +/** + * Strip boolean literal constraints (`const: true/false`, `enum: [true]`, `enum: [false]`) + * from a schema, recursively. quicktype's Python and Go renderers attempt to derive + * identifier names from enum values; deriving a name from a boolean throws inside + * `snakeNameStyle` (TypeError: s.codePointAt is not a function). + * + * The literal narrowing isn't expressible in Python/Go anyway, so we drop it and + * keep just `type: "boolean"`. TypeScript/C# codegen runs on the original schema. + */ +export function stripBooleanLiterals(schema: T): T { + if (typeof schema !== "object" || schema === null) return schema; + if (Array.isArray(schema)) { + return schema.map((item) => stripBooleanLiterals(item)) as unknown as T; + } + const result: Record = {}; + const src = schema as unknown as Record; + const isBooleanType = src.type === "boolean"; + for (const [key, value] of Object.entries(src)) { + if (isBooleanType && key === "const" && typeof value === "boolean") continue; + if ( + isBooleanType && + key === "enum" && + Array.isArray(value) && + value.every((v) => typeof v === "boolean") + ) { + continue; + } + result[key] = stripBooleanLiterals(value); + } + return result as T; +} + /** * Normalize schema defects where a required property with a `$ref` to an object type * has a description explicitly mentioning "null" as a valid value. @@ -222,6 +254,7 @@ export interface RpcMethod { params: JSONSchema7 | null; result: JSONSchema7 | null; stability?: string; + visibility?: string; deprecated?: boolean; } @@ -380,6 +413,33 @@ export function isNodeFullyDeprecated(node: Record): boolean { return methods.length > 0 && methods.every(m => m.deprecated === true); } +/** + * Returns a filtered copy of an API tree containing only methods whose visibility + * matches `keep`. Sub-groups that end up empty are pruned. Returns null if nothing + * survives the filter. + * + * `"public"` keeps methods without `visibility === "internal"`. + * `"internal"` keeps methods with `visibility === "internal"`. + */ +export function filterNodeByVisibility( + node: Record, + keep: "public" | "internal", +): Record | null { + const result: Record = {}; + for (const [key, value] of Object.entries(node)) { + if (isRpcMethod(value)) { + const isInternal = (value as RpcMethod).visibility === "internal"; + if (keep === "public" && isInternal) continue; + if (keep === "internal" && !isInternal) continue; + result[key] = value; + } else if (typeof value === "object" && value !== null) { + const sub = filterNodeByVisibility(value as Record, keep); + if (sub) result[key] = sub; + } + } + return Object.keys(result).length === 0 ? null : result; +} + /** Returns true when a JSON Schema node is marked as deprecated. */ export function isSchemaDeprecated(schema: JSONSchema7 | null | undefined): boolean { return typeof schema === "object" && schema !== null && (schema as Record).deprecated === true;