Skip to content

Commit 925e8c7

Browse files
committed
Added initial write-side API. Added write smoke tests.
1 parent bced7f0 commit 925e8c7

File tree

3 files changed

+292
-8
lines changed

3 files changed

+292
-8
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
using System;
2+
using System.Buffers;
3+
using System.Collections.Generic;
4+
using System.IO.Pipelines;
5+
using System.Text;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
9+
namespace Bedrock.Framework.Protocols.WebSockets
10+
{
11+
/// <summary>
12+
/// A writer-like construct for writing WebSocket messages.
13+
/// </summary>
14+
public class WebSocketMessageWriter
15+
{
16+
/// <summary>
17+
/// True if a message is in progress, false otherwise.
18+
/// </summary>
19+
internal bool _messageInProgress;
20+
21+
/// <summary>
22+
/// The current protocol type for this writer, client or server.
23+
/// </summary>
24+
private WebSocketProtocolType _protocolType;
25+
26+
/// <summary>
27+
/// The transport to write to.
28+
/// </summary>
29+
private PipeWriter _transport;
30+
31+
/// <summary>
32+
/// True if the current message is text, false otherwise.
33+
/// </summary>
34+
public bool _isText;
35+
36+
/// <summary>
37+
/// Creates an instance of a WebSocketMessageWriter.
38+
/// </summary>
39+
/// <param name="transport">The transport to write to.</param>
40+
/// <param name="protocolType">The protocol type for this writer.</param>
41+
public WebSocketMessageWriter(PipeWriter transport, WebSocketProtocolType protocolType)
42+
{
43+
_transport = transport;
44+
_protocolType = protocolType;
45+
}
46+
47+
/// <summary>
48+
/// Starts a message in the writer.
49+
/// </summary>
50+
/// <param name="payload">The payload to write.</param>
51+
/// <param name="isText">Whether the payload is text or not.</param>
52+
/// <param name="cancellationToken">A cancellation token, if any.</param>
53+
internal ValueTask StartMessageAsync(ReadOnlySequence<byte> payload, bool isText, CancellationToken cancellationToken = default)
54+
{
55+
if(_messageInProgress)
56+
{
57+
ThrowMessageAlreadyStarted();
58+
}
59+
60+
_messageInProgress = true;
61+
_isText = isText;
62+
return DoWriteAsync(isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary, false, payload, cancellationToken);
63+
}
64+
65+
/// <summary>
66+
/// Writes a single frame message with the writer.
67+
/// </summary>
68+
/// <param name="payload">The payload to write.</param>
69+
/// <param name="isText">Whether the payload is text or not.</param>
70+
/// <param name="cancellationToken">A cancellation token, if any.</param>
71+
internal ValueTask WriteSingleFrameMessageAsync(ReadOnlySequence<byte> payload, bool isText, CancellationToken cancellationToken = default)
72+
{
73+
if (_messageInProgress)
74+
{
75+
ThrowMessageAlreadyStarted();
76+
}
77+
78+
var result = DoWriteAsync(isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary, true, payload, cancellationToken);
79+
_messageInProgress = false;
80+
81+
return result;
82+
}
83+
84+
/// <summary>
85+
/// Writes a message payload portion with the writer.
86+
/// </summary>
87+
/// <param name="payload">The payload to write.</param>
88+
/// <param name="cancellationToken">A cancellation token, if any.</param>
89+
public ValueTask WriteAsync(ReadOnlySequence<byte> payload, CancellationToken cancellationToken = default)
90+
{
91+
if (!_messageInProgress)
92+
{
93+
ThrowMessageNotStarted();
94+
}
95+
96+
return DoWriteAsync(WebSocketOpcode.Continuation, false, payload, cancellationToken);
97+
}
98+
99+
/// <summary>
100+
/// Ends a message in progress.
101+
/// </summary>
102+
/// <param name="payload">The payload to write.</param>
103+
/// <param name="cancellationToken">A cancellation token, if any.</param>
104+
public ValueTask EndMessageAsync(ReadOnlySequence<byte> payload, CancellationToken cancellationToken = default)
105+
{
106+
if(!_messageInProgress)
107+
{
108+
ThrowMessageNotStarted();
109+
}
110+
111+
var result = DoWriteAsync(WebSocketOpcode.Continuation, true, payload, cancellationToken);
112+
_messageInProgress = false;
113+
114+
return result;
115+
}
116+
117+
/// <summary>
118+
/// Sends a message payload portion.
119+
/// </summary>
120+
/// <param name="opcode">The WebSocket opcode to send.</param>
121+
/// <param name="endOfMessage">Whether or not this payload portion represents the end of the message.</param>
122+
/// <param name="payload">The payload to send.</param>
123+
/// <param name="cancellationToken">A cancellation token, if any.</param>
124+
private ValueTask DoWriteAsync(WebSocketOpcode opcode, bool endOfMessage, ReadOnlySequence<byte> payload, CancellationToken cancellationToken)
125+
{
126+
var masked = _protocolType == WebSocketProtocolType.Client;
127+
var header = new WebSocketHeader(endOfMessage, opcode, masked, (ulong)payload.Length, masked ? WebSocketHeader.GenerateMaskingKey() : 0);
128+
129+
var frame = new WebSocketWriteFrame(header, payload);
130+
var writer = new WebSocketFrameWriter();
131+
132+
writer.WriteMessage(frame, _transport);
133+
var flushTask = _transport.FlushAsync(cancellationToken);
134+
if (flushTask.IsCompletedSuccessfully)
135+
{
136+
var result = flushTask.Result;
137+
if(result.IsCanceled)
138+
{
139+
ThrowMessageCanceled();
140+
}
141+
142+
if (result.IsCompleted && !endOfMessage)
143+
{
144+
ThrowTransportClosed();
145+
}
146+
147+
return new ValueTask();
148+
}
149+
else
150+
{
151+
return PerformFlushAsync(flushTask, endOfMessage);
152+
}
153+
}
154+
155+
/// <summary>
156+
/// Performs a flush of the writer asynchronously.
157+
/// </summary>
158+
/// <param name="flushTask">The active writer flush task.</param>
159+
/// <param name="endOfMessage">Whether or not this flush will send an end-of-message.</param>
160+
/// <returns></returns>
161+
private async ValueTask PerformFlushAsync(ValueTask<FlushResult> flushTask, bool endOfMessage)
162+
{
163+
var result = await flushTask.ConfigureAwait(false);
164+
if (result.IsCanceled)
165+
{
166+
ThrowMessageCanceled();
167+
}
168+
169+
if (result.IsCompleted && !endOfMessage)
170+
{
171+
ThrowTransportClosed();
172+
}
173+
}
174+
175+
/// <summary>
176+
/// Throws that a message was canceled unexpectedly.
177+
/// </summary>
178+
private void ThrowMessageCanceled()
179+
{
180+
throw new OperationCanceledException("Flush was canceled while a write was still in progress.");
181+
}
182+
183+
/// <summary>
184+
/// Throws that the underlying transport closed unexpectedly.
185+
/// </summary>
186+
private void ThrowTransportClosed()
187+
{
188+
throw new InvalidOperationException("Transport closed unexpectedly while a message is still in progress.");
189+
}
190+
191+
/// <summary>
192+
/// Throws if a message has not yet been started.
193+
/// </summary>
194+
private void ThrowMessageNotStarted()
195+
{
196+
throw new InvalidOperationException("Cannot end a message if a message has not been started.");
197+
}
198+
199+
/// <summary>
200+
/// Throws if a message has already been started.
201+
/// </summary>
202+
private void ThrowMessageAlreadyStarted()
203+
{
204+
throw new InvalidOperationException("Cannot start a message when a message is already in progress.");
205+
}
206+
}
207+
}

src/Bedrock.Framework/Protocols/WebSockets/WebSocketProtocol.cs

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ public class WebSocketProtocol : IControlFrameHandler
3232
/// </summary>
3333
private WebSocketMessageReader _messageReader;
3434

35+
/// <summary>
36+
/// The shared WebSocket message writer.
37+
/// </summary>
38+
private WebSocketMessageWriter _messageWriter;
39+
3540
/// <summary>
3641
/// The type of WebSocket protocol, server or client.
3742
/// </summary>
@@ -46,6 +51,7 @@ public WebSocketProtocol(ConnectionContext connection, WebSocketProtocolType pro
4651
{
4752
_transport = connection.Transport;
4853
_messageReader = new WebSocketMessageReader(_transport.Input, this);
54+
_messageWriter = new WebSocketMessageWriter(_transport.Output, protocolType);
4955
_protocolType = protocolType;
5056
}
5157

@@ -85,23 +91,61 @@ private async ValueTask<WebSocketReadResult> DoReadAsync(ValueTask<bool> readTas
8591
/// <param name="message">The message to write.</param>
8692
/// <param name="isText">True if the message is a text type message, false otherwise.</param>
8793
/// <param name="cancellationToken">A cancellation token, if any.</param>
88-
public async ValueTask WriteMessageAsync(ReadOnlySequence<byte> message, bool isText, CancellationToken cancellationToken = default)
94+
/// <returns>The WebSocket message writer.</returns>
95+
public ValueTask<WebSocketMessageWriter> StartMessageAsync(ReadOnlySequence<byte> message, bool isText, CancellationToken cancellationToken = default)
8996
{
9097
if (IsClosed)
9198
{
9299
throw new InvalidOperationException("A close message was already received from the remote endpoint.");
93100
}
94101

95-
var opcode = isText ? WebSocketOpcode.Text : WebSocketOpcode.Binary;
96-
var masked = _protocolType == WebSocketProtocolType.Client;
102+
var startMessageTask = _messageWriter.StartMessageAsync(message, isText, cancellationToken);
103+
if(startMessageTask.IsCompletedSuccessfully)
104+
{
105+
return new ValueTask<WebSocketMessageWriter>(_messageWriter);
106+
}
107+
else
108+
{
109+
return DoStartMessageAsync(startMessageTask);
110+
}
111+
}
97112

98-
var header = new WebSocketHeader(true, opcode, masked, (ulong)message.Length, WebSocketHeader.GenerateMaskingKey());
113+
public ValueTask WriteSingleFrameMessageAsync(ReadOnlySequence<byte> message, bool isText, CancellationToken cancellationToken = default)
114+
{
115+
if (IsClosed)
116+
{
117+
throw new InvalidOperationException("A close message was already received from the remote endpoint.");
118+
}
99119

100-
var frame = new WebSocketWriteFrame(header, message);
101-
var writer = new WebSocketFrameWriter();
120+
var writerTask = _messageWriter.WriteSingleFrameMessageAsync(message, isText, cancellationToken);
121+
if (writerTask.IsCompletedSuccessfully)
122+
{
123+
return writerTask;
124+
}
125+
else
126+
{
127+
return DoWriteSingleFrameAsync(writerTask);
128+
}
129+
}
130+
131+
/// <summary>
132+
/// Completes a start message task.
133+
/// </summary>
134+
/// <param name="startMessageTask">The active start message task.</param>
135+
/// <returns>The WebSocket message writer.</returns>
136+
private async ValueTask<WebSocketMessageWriter> DoStartMessageAsync(ValueTask startMessageTask)
137+
{
138+
await startMessageTask;
139+
return _messageWriter;
140+
}
102141

103-
writer.WriteMessage(frame, _transport.Output);
104-
await _transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false);
142+
/// <summary>
143+
/// Completes a single frame writer task.
144+
/// </summary>
145+
/// <param name="writeFrameTask">The active writer task.</param>
146+
private async ValueTask DoWriteSingleFrameAsync(ValueTask writeFrameTask)
147+
{
148+
await writeFrameTask;
105149
}
106150

107151
/// <summary>

tests/Bedrock.Framework.Tests/Protocols/WebSocketProtocolTests.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ namespace Bedrock.Framework.Tests.Protocols
1515
{
1616
public class WebSocketProtocolTests
1717
{
18+
private byte[] _buffer = new byte[4096];
19+
1820
[Fact]
1921
public async Task SingleMessageWorks()
2022
{
@@ -91,5 +93,36 @@ public async Task MessageWithMultipleFramesWorks()
9193

9294
Assert.Equal(payloadString, Encoding.UTF8.GetString(buffer.WrittenSpan));
9395
}
96+
97+
[Fact]
98+
public async Task WriteSingleMessageWorks()
99+
{
100+
var context = new InMemoryConnectionContext(new PipeOptions(useSynchronizationContext: false));
101+
var protocol = new WebSocketProtocol(context, WebSocketProtocolType.Server);
102+
103+
var webSocket = WebSocket.CreateFromStream(new DuplexPipeStream(context.Application.Input, context.Application.Output), false, null, TimeSpan.FromSeconds(30));
104+
var payloadString = "This is a test payload.";
105+
await protocol.WriteSingleFrameMessageAsync(new ReadOnlySequence<byte>(new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes(payloadString))), false, default);
106+
107+
var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(_buffer), default);
108+
Assert.Equal(payloadString, Encoding.UTF8.GetString(_buffer, 0, result.Count));
109+
}
110+
111+
[Fact]
112+
public async Task WriteMultipleFramesWorks()
113+
{
114+
var context = new InMemoryConnectionContext(new PipeOptions(useSynchronizationContext: false));
115+
var protocol = new WebSocketProtocol(context, WebSocketProtocolType.Server);
116+
117+
var webSocket = WebSocket.CreateFromStream(new DuplexPipeStream(context.Application.Input, context.Application.Output), false, null, TimeSpan.FromSeconds(30));
118+
var payloadString = "This is a test payload.";
119+
var writer = await protocol.StartMessageAsync(new ReadOnlySequence<byte>(new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes(payloadString))), false, default);
120+
await writer.EndMessageAsync(new ReadOnlySequence<byte>(new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes(payloadString))));
121+
122+
var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(_buffer, 0, _buffer.Length), default);
123+
result = await webSocket.ReceiveAsync(new ArraySegment<byte>(_buffer, result.Count, _buffer.Length - result.Count), default);
124+
125+
Assert.Equal($"{payloadString}{payloadString}", Encoding.UTF8.GetString(_buffer, 0, result.Count * 2));
126+
}
94127
}
95128
}

0 commit comments

Comments
 (0)