diff --git a/.gitignore b/.gitignore index 85ce7798..12a4e183 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ BenchmarkDotNet.Artifacts/ .gradle/ src/SignalR/clients/**/dist/ modules/ +.idea # File extensions *.aps diff --git a/samples/ClientApplication/ClientApplication.csproj b/samples/ClientApplication/ClientApplication.csproj index 9f2037a7..b626a73d 100644 --- a/samples/ClientApplication/ClientApplication.csproj +++ b/samples/ClientApplication/ClientApplication.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp3.1;net6.0 + netcoreapp3.1;net6.0;net7.0 diff --git a/samples/ServerApplication/ServerApplication.csproj b/samples/ServerApplication/ServerApplication.csproj index 8ffeecc9..68f20e19 100644 --- a/samples/ServerApplication/ServerApplication.csproj +++ b/samples/ServerApplication/ServerApplication.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp3.1;net6.0 + netcoreapp3.1;net6.0;net7.0 diff --git a/src/Bedrock.Framework/Hosting/ServerHostedService.cs b/src/Bedrock.Framework/Hosting/ServerHostedService.cs index 99e23485..dc5edfee 100644 --- a/src/Bedrock.Framework/Hosting/ServerHostedService.cs +++ b/src/Bedrock.Framework/Hosting/ServerHostedService.cs @@ -1,5 +1,8 @@ -using System.Threading; +using System; +using System.Net; +using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Options; @@ -23,5 +26,15 @@ public Task StopAsync(CancellationToken cancellationToken) { return _server.StopAsync(cancellationToken); } + + public Task AddSocketListenerAsync(EndPoint endpoint, Action configure) + { + return _server.AddSocketListenerAsync(endpoint, configure); + } + + public Task RemoveSocketListenerAsync(EndPoint endpoint) + { + return _server.RemoveSocketListener(endpoint); + } } } diff --git a/src/Bedrock.Framework/Server/Server.cs b/src/Bedrock.Framework/Server/Server.cs index 3253314f..281b6d5e 100644 --- a/src/Bedrock.Framework/Server/Server.cs +++ b/src/Bedrock.Framework/Server/Server.cs @@ -1,11 +1,14 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Linq; using System.Net; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; namespace Bedrock.Framework { @@ -13,10 +16,12 @@ public class Server { private readonly ServerBuilder _builder; private readonly ILogger _logger; - private readonly List _listeners = new List(); + private readonly Dictionary _listeners = new Dictionary(); private readonly TaskCompletionSource _shutdownTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private readonly TimerAwaitable _timerAwaitable; + private readonly SemaphoreSlim _listenerSemaphore = new SemaphoreSlim(initialCount: 1); private Task _timerTask = Task.CompletedTask; + private int _stopping; internal Server(ServerBuilder builder) { @@ -29,7 +34,7 @@ public IEnumerable EndPoints { get { - foreach (var listener in _listeners) + foreach (var listener in _listeners.Values) { yield return listener.Listener.EndPoint; } @@ -42,12 +47,7 @@ public async Task StartAsync(CancellationToken cancellationToken = default) { foreach (var binding in _builder.Bindings) { - await foreach (var listener in binding.BindAsync(cancellationToken).ConfigureAwait(false)) - { - var runningListener = new RunningListener(this, binding, listener); - _listeners.Add(runningListener); - runningListener.Start(); - } + await StartRunningListenersAsync(binding, cancellationToken).ConfigureAwait(false); } } catch @@ -67,7 +67,7 @@ private async Task StartTimerAsync() { while (await _timerAwaitable) { - foreach (var listener in _listeners) + foreach (var listener in _listeners.Values) { listener.TickHeartbeat(); } @@ -77,40 +77,132 @@ private async Task StartTimerAsync() public async Task StopAsync(CancellationToken cancellationToken = default) { - var tasks = new Task[_listeners.Count]; + if (Interlocked.Exchange(ref _stopping, 1) == 1) + { + return; + } + + await _listenerSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + var listeners = _listeners.Values.ToList(); + + var tasks = new Task[listeners.Count]; + + for (int i = 0; i < listeners.Count; i++) + { + tasks[i] = listeners[i].Listener.UnbindAsync(cancellationToken).AsTask(); + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + + // Signal to all of the listeners that it's time to start the shutdown process + // We call this after unbind so that we're not touching the listener anymore (each loop will dispose the listener) + _shutdownTcs.TrySetResult(null); + + for (int i = 0; i < listeners.Count; i++) + { + tasks[i] = listeners[i].ExecutionTask; + } + + var shutdownTask = Task.WhenAll(tasks); + + if (cancellationToken.CanBeCanceled) + { + await shutdownTask.WithCancellation(cancellationToken).ConfigureAwait(false); + } + else + { + await shutdownTask.ConfigureAwait(false); + } + + if (_timerAwaitable != null) + { + _timerAwaitable.Stop(); - for (int i = 0; i < _listeners.Count; i++) + await _timerTask.ConfigureAwait(false); + } + } + finally { - tasks[i] = _listeners[i].Listener.UnbindAsync(cancellationToken).AsTask(); + _listenerSemaphore.Release(); } + } - await Task.WhenAll(tasks).ConfigureAwait(false); + public Task AddSocketListenerAsync(EndPoint endpoint, Action configure, CancellationToken cancellationToken = default) + { + var socketTransportFactory = new SocketTransportFactory(Options.Create(new SocketTransportOptions()), _builder.ApplicationServices.GetLoggerFactory()); + var connectionBuilder = new ConnectionBuilder(_builder.ApplicationServices); - // Signal to all of the listeners that it's time to start the shutdown process - // We call this after unbind so that we're not touching the listener anymore (each loop will dispose the listener) - _shutdownTcs.TrySetResult(null); + configure(connectionBuilder); - for (int i = 0; i < _listeners.Count; i++) + var binding = new EndPointBinding(endpoint, connectionBuilder.Build(), socketTransportFactory); + return StartRunningListenersAsync(binding, cancellationToken); + } + + public async Task RemoveSocketListener(EndPoint endpoint, CancellationToken cancellationToken = default) + { + await _listenerSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + + if (_stopping == 1) { - tasks[i] = _listeners[i].ExecutionTask; + throw new InvalidOperationException("The server has already been stopped."); } - var shutdownTask = Task.WhenAll(tasks); + try + { + if (!_listeners.Remove(endpoint, out var listener)) + { + return; + } + + await listener.Listener.UnbindAsync(cancellationToken).ConfigureAwait(false); - if (cancellationToken.CanBeCanceled) + // Signal to the listener that it's time to start the shutdown process + // We call this after unbind so that we're not touching the listener anymore + listener.ShutdownTcs.TrySetResult(null); + + if (cancellationToken.CanBeCanceled) + { + await listener.ExecutionTask.WithCancellation(cancellationToken).ConfigureAwait(false); + } + else + { + await listener.ExecutionTask.ConfigureAwait(false); + } + } + finally { - await shutdownTask.WithCancellation(cancellationToken).ConfigureAwait(false); + _listenerSemaphore.Release(); } - else + } + + private async Task StartRunningListenersAsync(ServerBinding binding, CancellationToken cancellationToken = default) + { + await _listenerSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + + if (_stopping == 1) { - await shutdownTask.ConfigureAwait(false); + throw new InvalidOperationException("The server has already been stopped."); } - if (_timerAwaitable != null) + try { - _timerAwaitable.Stop(); + await foreach (var listener in binding.BindAsync(cancellationToken).ConfigureAwait(false)) + { + var runningListener = new RunningListener(this, binding, listener); + if (!_listeners.TryAdd(runningListener.Listener.EndPoint, runningListener)) + { + _logger.LogWarning("Will not start RunningListener, EndPoint already exist"); + continue; + } - await _timerTask.ConfigureAwait(false); + runningListener.Start(); + } + } + finally + { + _listenerSemaphore.Release(); } } @@ -130,10 +222,12 @@ public RunningListener(Server server, ServerBinding binding, IConnectionListener public void Start() { ExecutionTask = RunListenerAsync(); + ShutdownTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } public IConnectionListener Listener { get; } public Task ExecutionTask { get; private set; } + public TaskCompletionSource ShutdownTcs { get; private set; } public void TickHeartbeat() { @@ -215,8 +309,11 @@ async Task ExecuteConnectionAsync(ServerConnection serverConnection) id++; } - // Don't shut down connections until entire server is shutting down - await _server._shutdownTcs.Task.ConfigureAwait(false); + // Don't shut down connections until this listener or the entire server is shutting down + await Task.WhenAny( + ShutdownTcs.Task, + _server._shutdownTcs.Task) + .ConfigureAwait(false); // Give connections a chance to close gracefully var tasks = new List(_connections.Count); @@ -241,7 +338,6 @@ async Task ExecuteConnectionAsync(ServerConnection serverConnection) await listener.DisposeAsync().ConfigureAwait(false); } - private IDisposable BeginConnectionScope(ServerConnection connection) { if (_server._logger.IsEnabled(LogLevel.Critical)) @@ -253,4 +349,4 @@ private IDisposable BeginConnectionScope(ServerConnection connection) } } } -} +} \ No newline at end of file diff --git a/tests/Bedrock.Framework.Tests/ServerTests.cs b/tests/Bedrock.Framework.Tests/ServerTests.cs new file mode 100644 index 00000000..ae9d44b2 --- /dev/null +++ b/tests/Bedrock.Framework.Tests/ServerTests.cs @@ -0,0 +1,133 @@ +using System; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Net; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Bedrock.Framework.Tests +{ + public class ServerTests + { + [Fact] + public async Task StartSocketWithServer() + { + var (server, testResult) = await StartServer(); + + var expected = "Hello hello!"; + await StartClient(5000, expected); + + await Task.WhenAny( + testResult.Completion.Task, + Task.Delay(TimeSpan.FromSeconds(5)) + ); + + Assert.True(testResult.Completion.Task.IsCompleted); + Assert.Equal(expected, testResult.Completion.Task.Result); + + await server.StopAsync(); + } + + [Fact] + public async Task StartSocketAfterServer() + { + var (server, testResult) = await StartServer(); + + var endpoint = new IPEndPoint(IPAddress.Loopback, 5001); + await server.AddSocketListenerAsync(endpoint, + builder => builder.UseConnectionHandler()); + + const string expected = "Hello hello!"; + await StartClient(5001, expected); + + await Task.WhenAny( + testResult.Completion.Task, + Task.Delay(TimeSpan.FromSeconds(5)) + ); + + Assert.True(testResult.Completion.Task.IsCompleted); + Assert.Equal(expected, testResult.Completion.Task.Result); + + await server.StopAsync(); + } + + [Fact] + public async Task StopSocketBeforeServer() + { + var (server, _) = await StartServer(); + + Assert.NotNull(server.EndPoints.SingleOrDefault(x => x is IPEndPoint endpoint && endpoint.Address.Equals(IPAddress.Loopback) && endpoint.Port == 5000)); + var endpointToRemove = new IPEndPoint(IPAddress.Loopback, 5000); + await server.RemoveSocketListener(endpointToRemove); + Assert.Null(server.EndPoints.SingleOrDefault(x => x is IPEndPoint endpoint && endpoint.Address.Equals(IPAddress.Loopback) && endpoint.Port == 5000)); + + await server.StopAsync(); + } + + private static async Task StartClient(int port, string input) + { + // Setup Client + var clientServiceProvider = new ServiceCollection().BuildServiceProvider(); + + var client = new ClientBuilder(clientServiceProvider) + .UseSockets() + .Build(); + + var connection = await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port)); + + var reads = new MemoryStream(Encoding.UTF8.GetBytes(input)).CopyToAsync(connection.Transport.Output); + await reads; + } + + private static async Task<(Server server, TestResult testResult)> StartServer() + { + var services = new ServiceCollection().AddScoped(); + var serviceProvider = services.BuildServiceProvider(); + + var server = new ServerBuilder(serviceProvider) + .UseSockets(socketsServerBuilder => + socketsServerBuilder.ListenLocalhost(5000, builder => + builder.UseConnectionHandler())) + .Build(); + + await server.StartAsync(); + + var testResult = serviceProvider.GetRequiredService(); + return (server, testResult); + } + } + + internal class TestResult + { + public TaskCompletionSource Completion { get; } = new TaskCompletionSource(); + } + + internal class TestApplication : ConnectionHandler + { + private readonly TestResult _testResult; + + public TestApplication(TestResult testResult) + { + _testResult = testResult; + } + + public override async Task OnConnectedAsync(ConnectionContext connection) + { + try + { + var requestBodyInBytes = await connection.Transport.Input.ReadAsync(); + connection.Transport.Input.AdvanceTo(requestBodyInBytes.Buffer.Start, requestBodyInBytes.Buffer.End); + var input = Encoding.UTF8.GetString(requestBodyInBytes.Buffer.FirstSpan); + _testResult.Completion.SetResult(input); + } + catch + { + // Connection closed + } + } + } +} \ No newline at end of file