Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions cs/src/Connections/SshSessionExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System;
using Microsoft.DevTunnels.Ssh;

namespace Microsoft.DevTunnels.Connections;

internal static class SshSessionExtensions
{
public static string GetShortSessionId(this SshSession session)
{
if (session.SessionId == null || session.SessionId.Length < 16)
{
return string.Empty;
}

return new Guid(session.SessionId.AsSpan(0, 16)).ToString();
}
}
3 changes: 2 additions & 1 deletion cs/src/Connections/TunnelClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public override async Task ConnectAsync(
Requires.NotNull(tunnel, nameof(tunnel));
Requires.NotNull(tunnel.Endpoints!, nameof(Tunnel.Endpoints));

if (this.SshSession != null)
if (this.SshSession?.IsConnected == true)
{
throw new InvalidOperationException(
"Already connected. Use separate instances to connect to multiple tunnels.");
Expand Down Expand Up @@ -240,6 +240,7 @@ protected async Task StartSshSessionAsync(Stream stream, TunnelConnectionOptions

private void OnSshSessionDisconnected(object? sender, EventArgs e) =>
MaybeStartReconnecting(
(SshSession)sender!,
SshDisconnectReason.ConnectionLost,
exception: new SshConnectionException("Connection lost.", SshDisconnectReason.ConnectionLost));

Expand Down
20 changes: 0 additions & 20 deletions cs/src/Connections/TunnelConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ public abstract class TunnelConnection : IAsyncDisposable
{
private readonly CancellationTokenSource disposeCts = new();
private ConnectionStatus connectionStatus;
private Stopwatch connectionTimer = new();
private Tunnel? tunnel;

/// <summary>
Expand Down Expand Up @@ -423,25 +422,6 @@ protected virtual void OnConnectionStatusChanged(
ConnectionStatus previousConnectionStatus,
ConnectionStatus connectionStatus)
{
TimeSpan duration = this.connectionTimer.Elapsed;
this.connectionTimer.Restart();

if (Tunnel != null)
{
var statusEvent = new TunnelEvent($"{ConnectionRole}_connection_status");
statusEvent.Properties = new Dictionary<string, string>
{
[nameof(ConnectionStatus)] = connectionStatus.ToString(),
[$"Previous{nameof(ConnectionStatus)}"] = previousConnectionStatus.ToString(),
};
if (previousConnectionStatus != ConnectionStatus.None)
{
statusEvent.Properties[$"{previousConnectionStatus}Duration"] = duration.ToString();
}

ManagementClient?.ReportEvent(Tunnel, statusEvent);
}

var handler = ConnectionStatusChanged;
if (handler != null)
{
Expand Down
78 changes: 74 additions & 4 deletions cs/src/Connections/TunnelRelayConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// </copyright>

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
Expand Down Expand Up @@ -55,8 +56,10 @@ public abstract class TunnelRelayConnection : TunnelConnection, IRelayClient, IP
/// </summary>
public const int RetryMaxDelayMs = 12_800;

private string? websocketRequestId = null;
private TunnelConnectionOptions? connectionOptions;
private Task? reconnectTask;
private Stopwatch connectionTimer = new();

/// <summary>
/// Create a new instance of <see cref="TunnelRelayConnection"/> class.
Expand All @@ -66,6 +69,12 @@ protected TunnelRelayConnection(ITunnelManagementClient? managementClient, Trace
{
}

/// <summary>
/// Gets an ID that is unique to this instance of <see cref="TunnelRelayConnection"/>,
/// useful for correlating connection events over time.
/// </summary>
protected virtual string ConnectionId { get; } = Guid.NewGuid().ToString();

/// <summary>
/// Connection protocol used to connect to Relay.
/// </summary>
Expand Down Expand Up @@ -170,11 +179,60 @@ protected async Task ConnectTunnelSessionAsync(
}
}

/// <summary>
/// Event fired when the connection status has changed.
/// </summary>
protected override void OnConnectionStatusChanged(
ConnectionStatus previousConnectionStatus,
ConnectionStatus connectionStatus)
{
TimeSpan duration = this.connectionTimer.Elapsed;
this.connectionTimer.Restart();

if (Tunnel != null && ManagementClient != null)
{
var statusEvent = new TunnelEvent($"{ConnectionRole}_connection_status");
statusEvent.Properties = new Dictionary<string, string>
{
[nameof(ConnectionStatus)] = connectionStatus.ToString(),
[$"Previous{nameof(ConnectionStatus)}"] = previousConnectionStatus.ToString(),
};

if (previousConnectionStatus != ConnectionStatus.None)
{
statusEvent.Properties[$"{previousConnectionStatus}Duration"] = duration.ToString();
}

if (IsClientConnection)
{
// For client sessions, report the SSH session ID property, which is derived from
// the SSH key-exchange such that both host and client have the same ID.
statusEvent.Properties["ClientSessionId"] = SshSession?.GetShortSessionId() ?? string.Empty;
}
else
{
// For host sessions, there is no SSH encryption or key-exchange.
// Just use a locally-generated GUID that is unique to this session.
statusEvent.Properties["HostSessionId"] = ConnectionId;
}

if (this.websocketRequestId != null)
{
statusEvent.Properties["WebsocketRequestId"] = this.websocketRequestId;
}

ManagementClient.ReportEvent(Tunnel, statusEvent);
}

base.OnConnectionStatusChanged(previousConnectionStatus, connectionStatus);
}

/// <summary>
/// Start reconnecting if connected, not reconnecting already,
/// and <paramref name="reason"/> is <see cref="SshDisconnectReason.ConnectionLost"/>.
/// </summary>
protected void MaybeStartReconnecting(
SshSession session,
SshDisconnectReason reason,
string? message = null,
Exception? exception = null)
Expand Down Expand Up @@ -214,6 +272,10 @@ protected void MaybeStartReconnecting(
var reconnectEvent = new TunnelEvent($"{ConnectionRole}_reconnect");
reconnectEvent.Severity = TunnelEvent.Warning;
reconnectEvent.Details = exception?.ToString() ?? traceMessage;
reconnectEvent.Properties = new Dictionary<string, string>
{
["ClientSessionId"] = session.GetShortSessionId(),
};
ManagementClient?.ReportEvent(Tunnel, reconnectEvent);
}

Expand All @@ -228,6 +290,10 @@ protected void MaybeStartReconnecting(
var disconnectEvent = new TunnelEvent($"{ConnectionRole}_disconnect");
disconnectEvent.Severity = TunnelEvent.Warning;
disconnectEvent.Details = exception?.ToString() ?? traceMessage;
disconnectEvent.Properties = new Dictionary<string, string>
{
["ClientSessionId"] = session.GetShortSessionId(),
};
ManagementClient?.ReportEvent(Tunnel, disconnectEvent);
}

Expand Down Expand Up @@ -271,6 +337,8 @@ protected virtual async Task<Stream> CreateSessionStreamAsync(CancellationToken
cancellation);
Trace.TraceEvent(TraceEventType.Verbose, 0, "Connected with subprotocol '{0}'", subprotocol);

this.websocketRequestId = (stream as WebSocketStream)?.RequestId;

if (this.IsClientConnection)
{
this.OnReportProgress(Progress.OpenedClientConnectionToRelay);
Expand Down Expand Up @@ -329,7 +397,6 @@ protected virtual async Task CloseSessionAsync(
return;
}

SshSession = null;
UnsubscribeSessionEvents(session);
if (!session.IsClosed && session.IsConnected)
{
Expand All @@ -343,6 +410,11 @@ protected virtual async Task CloseSessionAsync(
}
}

// Set the connection status to disconnected before setting SshSession to null,
// so the session ID can be reported in the disconnect event.
ConnectionStatus = ConnectionStatus.Disconnected;
SshSession = null;

// Closing the SSH session does nothing if the session is in disconnected state,
// which may happen for a reconnectable session when the connection drops.
// Disposing of the session forces closing and frees up the resources.
Expand All @@ -355,9 +427,7 @@ protected virtual async Task CloseSessionAsync(
protected virtual void OnSshSessionClosed(object? sender, SshSessionClosedEventArgs e)
{
var session = (SshClientSession)sender!;
UnsubscribeSessionEvents(session);
SshSession = null;
MaybeStartReconnecting(e.Reason, e.Message, e.Exception);
MaybeStartReconnecting(session, e.Reason, e.Message, e.Exception);
}

/// <summary>
Expand Down
13 changes: 12 additions & 1 deletion cs/src/Connections/TunnelRelayTunnelHost.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ public TunnelRelayTunnelHost(ITunnelManagementClient managementClient, TraceSour
this.hostId = MultiModeTunnelHost.HostId;
}

/// <inheritdoc/>
protected override string ConnectionId => this.hostId;

/// <summary>
/// Get or set synthetic endpoint signature for the endpoint created for the host
/// when connecting.
Expand Down Expand Up @@ -426,6 +429,8 @@ private async Task ConnectAndRunClientSessionAsync(SshStream stream, Cancellatio
connectedEvent.Properties = new Dictionary<string, string>
{
["ClientChannelId"] = channelId.ToString(),
["ClientSessionId"] = session.GetShortSessionId(),
["HostSessionId"] = ConnectionId,
};
ManagementClient?.ReportEvent(Tunnel, connectedEvent);
}
Expand Down Expand Up @@ -456,12 +461,15 @@ private async Task ConnectAndRunClientSessionAsync(SshStream stream, Cancellatio

async void OnSshClientReconnected(object? sender, EventArgs e)
{
var session = (SshSession)sender!;
if (Tunnel != null)
{
var reconnectedEvent = new TunnelEvent($"host_client_reconnect");
reconnectedEvent.Properties = new Dictionary<string, string>
{
["ClientChannelId"] = channelId.ToString(),
["ClientSessionId"] = session.GetShortSessionId(),
["HostSessionId"] = ConnectionId,
};
ManagementClient?.ReportEvent(Tunnel, reconnectedEvent);
}
Expand All @@ -473,7 +481,8 @@ await StartForwardingExistingPortsAsync(

void OnClientSessionClosed(object? sender, SshSessionClosedEventArgs e)
{
TraceSource trace = ((SshSession)sender!).Trace;
var session = (SshSession)sender!;
var trace = session.Trace;
string? details = null;
string? severity = null;

Expand Down Expand Up @@ -510,6 +519,8 @@ void OnClientSessionClosed(object? sender, SshSessionClosedEventArgs e)
disconnectedEvent.Properties = new Dictionary<string, string>
{
["ClientChannelId"] = channelId.ToString(),
["ClientSessionId"] = session.GetShortSessionId(),
["HostSessionId"] = ConnectionId,
};
ManagementClient?.ReportEvent(Tunnel, disconnectedEvent);
}
Expand Down
41 changes: 38 additions & 3 deletions cs/src/Connections/WebSocketStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.WebSockets;
using System.Threading;
Expand Down Expand Up @@ -97,12 +98,46 @@ public string? CloseStatusDescription
set => this.closeStatusDescription = value;
}

/// <summary>
/// Gets the HTTP request ID from the web socket connection, if available.
/// </summary>
/// <remarks>
/// The request ID is returned as a response HTTP header when the websocket connection
/// is established. The value can then be added to client-reported events to support
/// correlation with service events.
/// </remarks>
public string? RequestId
{
get
{
#if NET8_0_OR_GREATER
var responseHeaders = (this.socket as ClientWebSocket)?.HttpResponseHeaders;
if (responseHeaders?.TryGetValue("VsSaaS-Request-ID", out var requestIdValues) == true)
{
var requestId = requestIdValues.FirstOrDefault();
if (!string.IsNullOrEmpty(requestId))
{
return requestId;
}
}
#endif

return null;
}
}

/// <summary>
/// Connect to web socket.
/// </summary>
public static async Task<WebSocketStream> ConnectToWebSocketAsync(Uri uri, Action<ClientWebSocketOptions>? configure = default, TraceSource? trace = default, CancellationToken cancellation = default)
{
var socket = new ClientWebSocket();

#if NET8_0_OR_GREATER
// Enable access to HTTP response headers.
socket.Options.CollectHttpResponseDetails = true;
#endif

try
{
configure?.Invoke(socket.Options);
Expand All @@ -125,11 +160,11 @@ public static async Task<WebSocketStream> ConnectToWebSocketAsync(Uri uri, Actio
if (i >= 0)
{
int j = wse.Message.IndexOf('\'', i + 1);
if (j > i + 1 &&
if (j > i + 1 &&
int.TryParse(
wse.Message.Substring(i + 1, j - i - 1),
wse.Message.Substring(i + 1, j - i - 1),
NumberStyles.None,
CultureInfo.InvariantCulture,
CultureInfo.InvariantCulture,
out var statusCode) &&
statusCode != 101)
{
Expand Down
Loading