diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b67469b..d51838a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,16 +1,14 @@ name: test on: + push: + branches: [ main ] pull_request: - push: { branches: [ main ] } - + branches: [ main ] + workflow_dispatch: jobs: + lint: + uses: graphqlswift/ci/.github/workflows/lint.yaml@main test: - strategy: - matrix: - os: [ubuntu-latest, macos-latest] - runs-on: ${{ matrix.os }} - steps: - - uses: fwal/setup-swift@v1 - - uses: actions/checkout@v2 - - name: Run tests - run: swift test + uses: graphqlswift/ci/.github/workflows/test.yaml@main + with: + include_android: false diff --git a/Package.resolved b/Package.resolved index 58ce4d2..fd423c5 100644 --- a/Package.resolved +++ b/Package.resolved @@ -6,8 +6,8 @@ "repositoryURL": "https://github.com/GraphQLSwift/Graphiti.git", "state": { "branch": null, - "revision": "c9bc9d1cc9e62e71a824dc178630bfa8b8a6e2a4", - "version": "1.0.0" + "revision": "a23a3d232df202fc158ad2d698926325b470523c", + "version": "3.0.0" } }, { @@ -15,26 +15,8 @@ "repositoryURL": "https://github.com/GraphQLSwift/GraphQL.git", "state": { "branch": null, - "revision": "283cc4de56b994a00b2724328221b7a1bc846ddc", - "version": "2.2.1" - } - }, - { - "package": "GraphQLRxSwift", - "repositoryURL": "https://github.com/GraphQLSwift/GraphQLRxSwift.git", - "state": { - "branch": null, - "revision": "c7ec6595f92ef5d77c06852e4acc4cd46a753622", - "version": "0.0.4" - } - }, - { - "package": "RxSwift", - "repositoryURL": "https://github.com/ReactiveX/RxSwift.git", - "state": { - "branch": null, - "revision": "b4307ba0b6425c0ba4178e138799946c3da594f8", - "version": "6.5.0" + "revision": "0fe18bc0bbbc9ab8929c285f419adea7c8fc7da2", + "version": "4.0.1" } }, { @@ -42,17 +24,8 @@ "repositoryURL": "https://github.com/apple/swift-collections", "state": { "branch": null, - "revision": "48254824bb4248676bf7ce56014ff57b142b77eb", - "version": "1.0.2" - } - }, - { - "package": "swift-nio", - "repositoryURL": "https://github.com/apple/swift-nio.git", - "state": { - "branch": null, - "revision": "154f1d32366449dcccf6375a173adf4ed2a74429", - "version": "2.38.0" + "revision": "8c0c0a8b49e080e54e5e328cc552821ff07cd341", + "version": "1.2.1" } } ] diff --git a/Package.swift b/Package.swift index 07e7f47..c8c29db 100644 --- a/Package.swift +++ b/Package.swift @@ -4,6 +4,7 @@ import PackageDescription let package = Package( name: "GraphQLWS", + platforms: [.macOS(.v10_15)], products: [ .library( name: "GraphQLWS", @@ -11,25 +12,21 @@ let package = Package( ), ], dependencies: [ - .package(name: "Graphiti", url: "https://github.com/GraphQLSwift/Graphiti.git", from: "1.0.0"), - .package(name: "GraphQL", url: "https://github.com/GraphQLSwift/GraphQL.git", from: "2.2.1"), - .package(name: "GraphQLRxSwift", url: "https://github.com/GraphQLSwift/GraphQLRxSwift.git", from: "0.0.4"), - .package(name: "RxSwift", url: "https://github.com/ReactiveX/RxSwift.git", from: "6.1.0"), - .package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.33.0") + .package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"), + .package(url: "https://github.com/GraphQLSwift/GraphQL.git", from: "4.0.1"), ], targets: [ .target( name: "GraphQLWS", dependencies: [ .product(name: "Graphiti", package: "Graphiti"), - .product(name: "GraphQLRxSwift", package: "GraphQLRxSwift"), .product(name: "GraphQL", package: "GraphQL"), - .product(name: "NIO", package: "swift-nio"), - .product(name: "RxSwift", package: "RxSwift") - ]), + ] + ), .testTarget( name: "GraphQLWSTests", dependencies: ["GraphQLWS"] ), - ] + ], + swiftLanguageVersions: [.v5, .version("6")] ) diff --git a/README.md b/README.md index eebda12..b97e2c4 100644 --- a/README.md +++ b/README.md @@ -28,31 +28,31 @@ import GraphQLWS class WebSocketMessenger: Messenger { private weak var websocket: WebSocket? private var onReceive: (String) -> Void = { _ in } - + init(websocket: WebSocket) { self.websocket = websocket websocket.onText { _, message in - self.onReceive(message) + try await self.onReceive(message) } } - - func send(_ message: S) where S: Collection, S.Element == Character { + + func send(_ message: S) async throws where S: Collection, S.Element == Character async throws { guard let websocket = websocket else { return } - websocket.send(message) + try await websocket.send(message) } - - func onReceive(callback: @escaping (String) -> Void) { + + func onReceive(callback: @escaping (String) async throws -> Void) { self.onReceive = callback } - - func error(_ message: String, code: Int) { + + func error(_ message: String, code: Int) async throws { guard let websocket = websocket else { return } - websocket.send("\(code): \(message)") + try await websocket.send("\(code): \(message)") } - - func close() { + + func close() async throws { guard let websocket = websocket else { return } - _ = websocket.close() + try await websocket.close() } } ``` @@ -67,7 +67,7 @@ routes.webSocket( let server = GraphQLWS.Server( messenger: messenger, onExecute: { graphQLRequest in - api.execute( + try await api.execute( request: graphQLRequest.query, context: context, on: self.eventLoop, @@ -76,7 +76,7 @@ routes.webSocket( ) }, onSubscribe: { graphQLRequest in - api.subscribe( + try await api.subscribe( request: graphQLRequest.query, context: context, on: self.eventLoop, @@ -128,8 +128,8 @@ If the `payload` field is not required on your server, you may make Server's gen ## Memory Management -Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket -implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the +Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket +implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the objects required for responses. In order to align cleanly and avoid memory cycles, Server and Client are injected strongly into Messenger callbacks, and only hold weak references to their Messenger. This means that Messenger objects (or their enclosing WebSocket) must be persisted to have the connected Server or Client objects function. That is, if a Server's Messenger falls out of scope and deinitializes, diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index 42a7fd8..269f267 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -5,18 +5,18 @@ import GraphQL public class Client { // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? - - var onConnectionError: (ConnectionErrorResponse, Client) -> Void = { _, _ in } - var onConnectionAck: (ConnectionAckResponse, Client) -> Void = { _, _ in } - var onConnectionKeepAlive: (ConnectionKeepAliveResponse, Client) -> Void = { _, _ in } - var onData: (DataResponse, Client) -> Void = { _, _ in } - var onError: (ErrorResponse, Client) -> Void = { _, _ in } - var onComplete: (CompleteResponse, Client) -> Void = { _, _ in } - var onMessage: (String, Client) -> Void = { _, _ in } - + + var onConnectionError: (ConnectionErrorResponse, Client) async throws -> Void = { _, _ in } + var onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void = { _, _ in } + var onConnectionKeepAlive: (ConnectionKeepAliveResponse, Client) async throws -> Void = { _, _ in } + var onData: (DataResponse, Client) async throws -> Void = { _, _ in } + var onError: (ErrorResponse, Client) async throws -> Void = { _, _ in } + var onComplete: (CompleteResponse, Client) async throws -> Void = { _, _ in } + var onMessage: (String, Client) async throws -> Void = { _, _ in } + let encoder = GraphQLJSONEncoder() let decoder = JSONDecoder() - + /// Create a new client. /// /// - Parameters: @@ -26,155 +26,154 @@ public class Client { ) { self.messenger = messenger messenger.onReceive { message in - self.onMessage(message, self) - + try await self.onMessage(message, self) + // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages return } - + guard let json = message.data(using: .utf8) else { - self.error(.invalidEncoding()) + try await self.error(.invalidEncoding()) return } - + let response: Response do { response = try self.decoder.decode(Response.self, from: json) - } - catch { - self.error(.noType()) + } catch { + try await self.error(.noType()) return } - + switch response.type { - case .GQL_CONNECTION_ERROR: - guard let connectionErrorResponse = try? self.decoder.decode(ConnectionErrorResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) - return - } - self.onConnectionError(connectionErrorResponse, self) - case .GQL_CONNECTION_ACK: - guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) - return - } - self.onConnectionAck(connectionAckResponse, self) - case .GQL_CONNECTION_KEEP_ALIVE: - guard let connectionKeepAliveResponse = try? self.decoder.decode(ConnectionKeepAliveResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE)) - return - } - self.onConnectionKeepAlive(connectionKeepAliveResponse, self) - case .GQL_DATA: - guard let nextResponse = try? self.decoder.decode(DataResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .GQL_DATA)) - return - } - self.onData(nextResponse, self) - case .GQL_ERROR: - guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .GQL_ERROR)) - return - } - self.onError(errorResponse, self) - case .GQL_COMPLETE: - guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .GQL_COMPLETE)) - return - } - self.onComplete(completeResponse, self) - case .unknown: - self.error(.invalidType()) + case .GQL_CONNECTION_ERROR: + guard let connectionErrorResponse = try? self.decoder.decode(ConnectionErrorResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) + return + } + try await self.onConnectionError(connectionErrorResponse, self) + case .GQL_CONNECTION_ACK: + guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) + return + } + try await self.onConnectionAck(connectionAckResponse, self) + case .GQL_CONNECTION_KEEP_ALIVE: + guard let connectionKeepAliveResponse = try? self.decoder.decode(ConnectionKeepAliveResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE)) + return + } + try await self.onConnectionKeepAlive(connectionKeepAliveResponse, self) + case .GQL_DATA: + guard let nextResponse = try? self.decoder.decode(DataResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .GQL_DATA)) + return + } + try await self.onData(nextResponse, self) + case .GQL_ERROR: + guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .GQL_ERROR)) + return + } + try await self.onError(errorResponse, self) + case .GQL_COMPLETE: + guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .GQL_COMPLETE)) + return + } + try await self.onComplete(completeResponse, self) + case .unknown: + try await self.error(.invalidType()) } } } - + /// Define the callback run on receipt of a `connection_error` message /// - Parameter callback: The callback to assign - public func onConnectionError(_ callback: @escaping (ConnectionErrorResponse, Client) -> Void) { - self.onConnectionError = callback + public func onConnectionError(_ callback: @escaping (ConnectionErrorResponse, Client) async throws -> Void) { + onConnectionError = callback } - + /// Define the callback run on receipt of a `connection_ack` message /// - Parameter callback: The callback to assign - public func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) -> Void) { - self.onConnectionAck = callback + public func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) async throws -> Void) { + onConnectionAck = callback } - + /// Define the callback run on receipt of a `connection_ka` message /// - Parameter callback: The callback to assign - public func onConnectionKeepAlive(_ callback: @escaping (ConnectionKeepAliveResponse, Client) -> Void) { - self.onConnectionKeepAlive = callback + public func onConnectionKeepAlive(_ callback: @escaping (ConnectionKeepAliveResponse, Client) async throws -> Void) { + onConnectionKeepAlive = callback } - + /// Define the callback run on receipt of a `data` message /// - Parameter callback: The callback to assign - public func onData(_ callback: @escaping (DataResponse, Client) -> Void) { - self.onData = callback + public func onData(_ callback: @escaping (DataResponse, Client) async throws -> Void) { + onData = callback } - + /// Define the callback run on receipt of an `error` message /// - Parameter callback: The callback to assign - public func onError(_ callback: @escaping (ErrorResponse, Client) -> Void) { - self.onError = callback + public func onError(_ callback: @escaping (ErrorResponse, Client) async throws -> Void) { + onError = callback } - + /// Define the callback run on receipt of any message /// - Parameter callback: The callback to assign - public func onComplete(_ callback: @escaping (CompleteResponse, Client) -> Void) { - self.onComplete = callback + public func onComplete(_ callback: @escaping (CompleteResponse, Client) async throws -> Void) { + onComplete = callback } - + /// Define the callback run on receipt of a `complete` message /// - Parameter callback: The callback to assign - public func onMessage(_ callback: @escaping (String, Client) -> Void) { - self.onMessage = callback + public func onMessage(_ callback: @escaping (String, Client) async throws -> Void) { + onMessage = callback } - + /// Send a `connection_init` request through the messenger - public func sendConnectionInit(payload: InitPayload) { + public func sendConnectionInit(payload: InitPayload) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ConnectionInitRequest( payload: payload ).toJSON(encoder) ) } - + /// Send a `start` request through the messenger - public func sendStart(payload: GraphQLRequest, id: String) { + public func sendStart(payload: GraphQLRequest, id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( StartRequest( payload: payload, id: id ).toJSON(encoder) ) } - + /// Send a `stop` request through the messenger - public func sendStop(id: String) { + public func sendStop(id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( StopRequest( id: id ).toJSON(encoder) ) } - + /// Send a `connection_terminate` request through the messenger - public func sendConnectionTerminate() { + public func sendConnectionTerminate() async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ConnectionTerminateRequest().toJSON(encoder) ) } - + /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLWSError) { + private func error(_ error: GraphQLWSError) async throws { guard let messenger = messenger else { return } - messenger.error(error.message, code: error.code.rawValue) + try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Sources/GraphQLWS/GraphQLWSError.swift b/Sources/GraphQLWS/GraphQLWSError.swift index d54d883..b5fce64 100644 --- a/Sources/GraphQLWS/GraphQLWSError.swift +++ b/Sources/GraphQLWS/GraphQLWSError.swift @@ -3,82 +3,82 @@ import GraphQL struct GraphQLWSError: Error { let message: String let code: ErrorCode - + init(_ message: String, code: ErrorCode) { self.message = message self.code = code } - + static func unauthorized() -> Self { return self.init( "Unauthorized", code: .unauthorized ) } - + static func notInitialized() -> Self { return self.init( "Connection not initialized", code: .notInitialized ) } - + static func tooManyInitializations() -> Self { return self.init( "Too many initialisation requests", code: .tooManyInitializations ) } - + static func subscriberAlreadyExists(id: String) -> Self { return self.init( "Subscriber for \(id) already exists", code: .subscriberAlreadyExists ) } - + static func invalidEncoding() -> Self { return self.init( "Message was not encoded in UTF8", code: .invalidEncoding ) } - + static func noType() -> Self { return self.init( "Message has no 'type' field", code: .noType ) } - + static func invalidType() -> Self { return self.init( "Message 'type' value does not match supported types", code: .invalidType ) } - + static func invalidRequestFormat(messageType: RequestMessageType) -> Self { return self.init( "Request message doesn't match '\(messageType.rawValue)' JSON format", code: .invalidRequestFormat ) } - + static func invalidResponseFormat(messageType: ResponseMessageType) -> Self { return self.init( "Response message doesn't match '\(messageType.rawValue)' JSON format", code: .invalidResponseFormat ) } - + static func internalAPIStreamIssue(errors: [GraphQLError]) -> Self { return self.init( - "API Response did not result in a stream type, contained errors\n\(errors.map { $0.message}.joined(separator: "\n"))", + "API Response did not result in a stream type, contained errors\n\(errors.map { $0.message }.joined(separator: "\n"))", code: .internalAPIStreamIssue ) } - + static func graphQLError(_ error: Error) -> Self { return self.init( "\(error)", @@ -88,28 +88,28 @@ struct GraphQLWSError: Error { } /// Error codes for miscellaneous issues -public enum ErrorCode: Int, CustomStringConvertible { +public enum ErrorCode: Int, CustomStringConvertible, Sendable { // Miscellaneous case miscellaneous = 4400 - + // Internal errors case graphQLError = 4401 case internalAPIStreamIssue = 4402 - + // Message errors case invalidEncoding = 4410 case noType = 4411 case invalidType = 4412 case invalidRequestFormat = 4413 case invalidResponseFormat = 4414 - + // Initialization errors case unauthorized = 4430 case notInitialized = 4431 case tooManyInitializations = 4432 case subscriberAlreadyExists = 4433 - + public var description: String { - return "\(self.rawValue)" + return "\(rawValue)" } } diff --git a/Sources/GraphQLWS/InitPayloads.swift b/Sources/GraphQLWS/InitPayloads.swift index 41a6cc2..8a50b36 100644 --- a/Sources/GraphQLWS/InitPayloads.swift +++ b/Sources/GraphQLWS/InitPayloads.swift @@ -1,12 +1,12 @@ // Contains convenient `connection_init` payloads for users of this package /// `connection_init` `payload` that is empty -public struct EmptyInitPayload: Equatable & Codable { } +public struct EmptyInitPayload: Equatable & Codable & Sendable {} /// `connection_init` `payload` that includes an `authToken` field -public struct TokenInitPayload: Equatable & Codable { +public struct TokenInitPayload: Equatable & Codable & Sendable { public let authToken: String - + public init(authToken: String) { self.authToken = authToken } diff --git a/Sources/GraphQLWS/JsonEncodable.swift b/Sources/GraphQLWS/JsonEncodable.swift index 81a3d52..911d14f 100644 --- a/Sources/GraphQLWS/JsonEncodable.swift +++ b/Sources/GraphQLWS/JsonEncodable.swift @@ -12,8 +12,7 @@ extension JsonEncodable { let data: Data do { data = try encoder.encode(self) - } - catch { + } catch { return EncodingErrorResponse("Unable to encode response").toJSON(encoder) } guard let body = String(data: data, encoding: .utf8) else { diff --git a/Sources/GraphQLWS/Messenger.swift b/Sources/GraphQLWS/Messenger.swift index 9ccd33c..3a9c157 100644 --- a/Sources/GraphQLWS/Messenger.swift +++ b/Sources/GraphQLWS/Messenger.swift @@ -1,23 +1,22 @@ import Foundation -import NIO /// Protocol for an object that can send and recieve messages. This allows mocking in tests public protocol Messenger: AnyObject { // AnyObject compliance requires that the implementing object is a class and we can reference it weakly - + /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) -> Void where S: Collection, S.Element == Character - + func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character + /// Set the callback that should be run when a message is recieved - func onReceive(callback: @escaping (String) -> Void) -> Void - + func onReceive(callback: @escaping (String) async throws -> Void) + /// Close the messenger - func close() -> Void - + func close() async throws + /// Indicate that the messenger experienced an error. /// - Parameters: /// - message: The message describing the error /// - code: An error code - func error(_ message: String, code: Int) -> Void + func error(_ message: String, code: Int) async throws } diff --git a/Sources/GraphQLWS/Requests.swift b/Sources/GraphQLWS/Requests.swift index b1533e7..21f5abd 100644 --- a/Sources/GraphQLWS/Requests.swift +++ b/Sources/GraphQLWS/Requests.swift @@ -48,7 +48,7 @@ enum RequestMessageType: String, Codable { case GQL_STOP = "stop" case GQL_CONNECTION_TERMINATE = "connection_terminate" case unknown - + init(from decoder: Decoder) throws { guard let value = try? decoder.singleValueContainer().decode(String.self) else { self = .unknown diff --git a/Sources/GraphQLWS/Responses.swift b/Sources/GraphQLWS/Responses.swift index 51bfe3b..525fa17 100644 --- a/Sources/GraphQLWS/Responses.swift +++ b/Sources/GraphQLWS/Responses.swift @@ -10,9 +10,9 @@ public struct Response: Equatable, JsonEncodable { public struct ConnectionAckResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: [String: Map]? - + init(_ payload: [String: Map]? = nil) { - self.type = .GQL_CONNECTION_ACK + type = .GQL_CONNECTION_ACK self.payload = payload } } @@ -21,9 +21,9 @@ public struct ConnectionAckResponse: Equatable, JsonEncodable { public struct ConnectionErrorResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: [String: Map]? - + init(_ payload: [String: Map]? = nil) { - self.type = .GQL_CONNECTION_ERROR + type = .GQL_CONNECTION_ERROR self.payload = payload } } @@ -32,9 +32,9 @@ public struct ConnectionErrorResponse: Equatable, JsonEncodable { public struct ConnectionKeepAliveResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: [String: Map]? - + init(_ payload: [String: Map]? = nil) { - self.type = .GQL_CONNECTION_KEEP_ALIVE + type = .GQL_CONNECTION_KEEP_ALIVE self.payload = payload } } @@ -44,9 +44,9 @@ public struct DataResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: GraphQLResult? public let id: String - + init(_ payload: GraphQLResult? = nil, id: String) { - self.type = .GQL_DATA + type = .GQL_DATA self.payload = payload self.id = id } @@ -56,9 +56,9 @@ public struct DataResponse: Equatable, JsonEncodable { public struct CompleteResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let id: String - + init(id: String) { - self.type = .GQL_COMPLETE + type = .GQL_COMPLETE self.id = id } } @@ -68,18 +68,18 @@ public struct ErrorResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: [GraphQLError] public let id: String - + init(_ errors: [Error], id: String) { let graphQLErrors = errors.map { error -> GraphQLError in switch error { - case let graphQLError as GraphQLError: - return graphQLError - default: - return GraphQLError(error) + case let graphQLError as GraphQLError: + return graphQLError + default: + return GraphQLError(error) } } - self.type = .GQL_ERROR - self.payload = graphQLErrors + type = .GQL_ERROR + payload = graphQLErrors self.id = id } } @@ -93,7 +93,7 @@ enum ResponseMessageType: String, Codable { case GQL_ERROR = "error" case GQL_COMPLETE = "complete" case unknown - + init(from decoder: Decoder) throws { guard let value = try? decoder.singleValueContainer().decode(String.self) else { self = .unknown @@ -108,9 +108,9 @@ enum ResponseMessageType: String, Codable { struct EncodingErrorResponse: Equatable, Codable, JsonEncodable { let type: ResponseMessageType let payload: [String: String] - + init(_ errorMessage: String) { - self.type = .GQL_ERROR - self.payload = ["error": errorMessage] + type = .GQL_ERROR + payload = ["error": errorMessage] } } diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index 2af4bff..24e5590 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -1,304 +1,301 @@ import Foundation import GraphQL -import GraphQLRxSwift -import NIO -import RxSwift /// Server implements the server-side portion of the protocol, allowing a few callbacks for customization. /// /// By default, there are no authorization checks -public class Server { +public class Server< + InitPayload: Equatable & Codable & Sendable, + SubscriptionSequenceType: AsyncSequence & Sendable +>: @unchecked Sendable where + SubscriptionSequenceType.Element == GraphQLResult +{ // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? - - let onExecute: (GraphQLRequest) -> EventLoopFuture - let onSubscribe: (GraphQLRequest) -> EventLoopFuture - var auth: (InitPayload) throws -> EventLoopFuture - - var onExit: () -> Void = { } - var onMessage: (String) -> Void = { _ in } - var onOperationComplete: (String) -> Void = { _ in } - var onOperationError: (String, [Error]) -> Void = { _, _ in } - + + let onExecute: (GraphQLRequest) async throws -> GraphQLResult + let onSubscribe: (GraphQLRequest) async throws -> SubscriptionSequenceType + var auth: (InitPayload) async throws -> Void + + var onExit: () async throws -> Void = {} + var onMessage: (String) async throws -> Void = { _ in } + var onOperationComplete: (String) async throws -> Void = { _ in } + var onOperationError: (String, [Error]) async throws -> Void = { _, _ in } + var initialized = false - - let disposeBag = DisposeBag() + let decoder = JSONDecoder() let encoder = GraphQLJSONEncoder() - + + private var subscriptionTasks = [String: Task]() + /// Create a new server /// /// - Parameters: /// - messenger: The messenger to bind the server to. /// - onExecute: Callback run during `start` resolution for non-streaming queries. Typically this is `API.execute`. /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. - /// - eventLoop: EventLoop on which to perform server operations. public init( messenger: Messenger, - onExecute: @escaping (GraphQLRequest) -> EventLoopFuture, - onSubscribe: @escaping (GraphQLRequest) -> EventLoopFuture, - eventLoop: EventLoop + onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult, + onSubscribe: @escaping (GraphQLRequest) async throws -> SubscriptionSequenceType ) { self.messenger = messenger self.onExecute = onExecute self.onSubscribe = onSubscribe - self.auth = { _ in eventLoop.makeSucceededVoidFuture() } - + auth = { _ in } + messenger.onReceive { message in guard let messenger = self.messenger else { return } - - self.onMessage(message) - + + try await self.onMessage(message) + // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages return } - + guard let json = message.data(using: .utf8) else { - self.error(.invalidEncoding()) + try await self.error(.invalidEncoding()) return } - + let request: Request do { request = try self.decoder.decode(Request.self, from: json) - } - catch { - self.error(.noType()) + } catch { + try await self.error(.noType()) return } - + // handle incoming message switch request.type { - case .GQL_CONNECTION_INIT: - guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: json) else { - self.error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT)) - return - } - self.onConnectionInit(connectionInitRequest, messenger) - case .GQL_START: - guard let startRequest = try? self.decoder.decode(StartRequest.self, from: json) else { - self.error(.invalidRequestFormat(messageType: .GQL_START)) - return - } - self.onStart(startRequest, messenger) - case .GQL_STOP: - guard let stopRequest = try? self.decoder.decode(StopRequest.self, from: json) else { - self.error(.invalidRequestFormat(messageType: .GQL_STOP)) - return - } - self.onOperationComplete(stopRequest.id) - case .GQL_CONNECTION_TERMINATE: - guard let connectionTerminateRequest = try? self.decoder.decode(ConnectionTerminateRequest.self, from: json) else { - self.error(.invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE)) - return - } - self.onConnectionTerminate(connectionTerminateRequest, messenger) - case .unknown: - self.error(.invalidType()) + case .GQL_CONNECTION_INIT: + guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT)) + return + } + try await self.onConnectionInit(connectionInitRequest, messenger) + case .GQL_START: + guard let startRequest = try? self.decoder.decode(StartRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .GQL_START)) + return + } + try await self.onStart(startRequest, messenger) + case .GQL_STOP: + guard let stopRequest = try? self.decoder.decode(StopRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .GQL_STOP)) + return + } + try await self.onStop(stopRequest) + case .GQL_CONNECTION_TERMINATE: + guard let connectionTerminateRequest = try? self.decoder.decode(ConnectionTerminateRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE)) + return + } + try await self.onConnectionTerminate(connectionTerminateRequest, messenger) + case .unknown: + try await self.error(.invalidType()) } } } - + + deinit { + subscriptionTasks.values.forEach { $0.cancel() } + } + /// Define a custom callback run during `connection_init` resolution that allows authorization using the `payload`. - /// Throw or fail the future from this closure to indicate that authorization has failed. + /// Throw from this closure to indicate that authorization has failed. /// - Parameter callback: The callback to assign - public func auth(_ callback: @escaping (InitPayload) throws -> EventLoopFuture) { - self.auth = callback + public func auth(_ callback: @escaping (InitPayload) async throws -> Void) { + auth = callback } - + /// Define the callback run when the communication is shut down, either by the client or server /// - Parameter callback: The callback to assign public func onExit(_ callback: @escaping () -> Void) { - self.onExit = callback + onExit = callback } - + /// Define the callback run on receipt of any message /// - Parameter callback: The callback to assign public func onMessage(_ callback: @escaping (String) -> Void) { - self.onMessage = callback + onMessage = callback } - + /// Define the callback run on the completion a full operation (query/mutation, end of subscription) /// - Parameter callback: The callback to assign public func onOperationComplete(_ callback: @escaping (String) -> Void) { - self.onOperationComplete = callback + onOperationComplete = callback } - + /// Define the callback to run on error of any full operation (failed query, interrupted subscription) /// - Parameter callback: The callback to assign public func onOperationError(_ callback: @escaping (String, [Error]) -> Void) { - self.onOperationError = callback + onOperationError = callback } - - private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _ messenger: Messenger) { + + private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _: Messenger) async throws { guard !initialized else { - self.error(.tooManyInitializations()) + try await error(.tooManyInitializations()) return } - + do { - let authResult = try self.auth(connectionInitRequest.payload) - authResult.whenSuccess { - self.initialized = true - self.sendConnectionAck() - } - authResult.whenFailure { error in - self.error(.unauthorized()) - return - } - } - catch { - self.error(.unauthorized()) + try await auth(connectionInitRequest.payload) + } catch { + try await self.error(.unauthorized()) return } + initialized = true + try await sendConnectionAck() // TODO: Should we send the `ka` message? } - - private func onStart(_ startRequest: StartRequest, _ messenger: Messenger) { + + private func onStart(_ startRequest: StartRequest, _ messenger: Messenger) async throws { guard initialized else { - self.error(.notInitialized()) + try await error(.notInitialized()) return } - + let id = startRequest.id + if subscriptionTasks[id] != nil { + try await error(.subscriberAlreadyExists(id: id)) + } + let graphQLRequest = startRequest.payload - + var isStreaming = false do { isStreaming = try graphQLRequest.isSubscription() - } - catch { - self.sendError(error, id: id) + } catch { + try await sendError(error, id: id) return } - + if isStreaming { - let subscribeFuture = onSubscribe(graphQLRequest) - subscribeFuture.whenSuccess { result in - guard let streamOpt = result.stream else { - // API issue - subscribe resolver isn't stream - self.sendError(result.errors, id: id) - return - } - let stream = streamOpt as! ObservableSubscriptionEventStream - let observable = stream.observable - observable.subscribe( - onNext: { [weak self] resultFuture in - guard let self = self else { return } - resultFuture.whenSuccess { result in - self.sendData(result, id: id) - } - resultFuture.whenFailure { error in - self.sendError(error, id: id) - } - }, - onError: { [weak self] error in - guard let self = self else { return } - self.sendError(error, id: id) - }, - onCompleted: { [weak self] in - guard let self = self else { return } - self.sendComplete(id: id) + subscriptionTasks[id] = Task { + do { + let stream = try await onSubscribe(graphQLRequest) + for try await event in stream { + try Task.checkCancellation() + try await self.sendData(event, id: id) } - ).disposed(by: self.disposeBag) + } catch { + try await sendError(error, id: id) + subscriptionTasks.removeValue(forKey: id) + throw error + } + try await self.sendComplete(id: id) + subscriptionTasks.removeValue(forKey: id) } - subscribeFuture.whenFailure { error in - self.sendError(error, id: id) + } else { + do { + let result = try await onExecute(graphQLRequest) + try await sendData(result, id: id) + try await sendComplete(id: id) + } catch { + try await sendError(error, id: id) } + try await messenger.close() } - else { - let executeFuture = onExecute(graphQLRequest) - executeFuture.whenSuccess { result in - self.sendData(result, id: id) - self.sendComplete(id: id) - messenger.close() - } - executeFuture.whenFailure { error in - self.sendError(error, id: id) - self.sendComplete(id: id) - messenger.close() - } + } + + private func onStop(_ stopRequest: StopRequest) async throws { + guard initialized else { + try await error(.notInitialized()) + return } + + let id = stopRequest.id + if let task = subscriptionTasks[id] { + task.cancel() + subscriptionTasks.removeValue(forKey: id) + } + try await onOperationComplete(id) } - - private func onConnectionTerminate(_: ConnectionTerminateRequest, _ messenger: Messenger) { - onExit() - _ = messenger.close() + + private func onConnectionTerminate(_: ConnectionTerminateRequest, _ messenger: Messenger) async throws { + for (_, subscriptionTask) in subscriptionTasks { + subscriptionTask.cancel() + } + subscriptionTasks.removeAll() + try await onExit() + try await messenger.close() } - + /// Send a `connection_ack` response through the messenger - private func sendConnectionAck(_ payload: [String: Map]? = nil) { + private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ConnectionAckResponse(payload).toJSON(encoder) ) } - + /// Send a `connection_error` response through the messenger - private func sendConnectionError(_ payload: [String: Map]? = nil) { + private func sendConnectionError(_ payload: [String: Map]? = nil) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ConnectionErrorResponse(payload).toJSON(encoder) ) } - + /// Send a `ka` response through the messenger - private func sendConnectionKeepAlive(_ payload: [String: Map]? = nil) { + private func sendConnectionKeepAlive(_ payload: [String: Map]? = nil) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ConnectionKeepAliveResponse(payload).toJSON(encoder) ) } - + /// Send a `data` response through the messenger - private func sendData(_ payload: GraphQLResult? = nil, id: String) { + private func sendData(_ payload: GraphQLResult? = nil, id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( DataResponse( payload, id: id ).toJSON(encoder) ) } - + /// Send a `complete` response through the messenger - private func sendComplete(id: String) { + private func sendComplete(id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( CompleteResponse( id: id ).toJSON(encoder) ) - onOperationComplete(id) + try await onOperationComplete(id) } - + /// Send an `error` response through the messenger - private func sendError(_ errors: [Error], id: String) { + private func sendError(_ errors: [Error], id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ErrorResponse( errors, id: id ).toJSON(encoder) ) - onOperationError(id, errors) + try await onOperationError(id, errors) } - + /// Send an `error` response through the messenger - private func sendError(_ error: Error, id: String) { - self.sendError([error], id: id) + private func sendError(_ error: Error, id: String) async throws { + try await sendError([error], id: id) } - + /// Send an `error` response through the messenger - private func sendError(_ errorMessage: String, id: String) { - self.sendError(GraphQLError(message: errorMessage), id: id) + private func sendError(_ errorMessage: String, id: String) async throws { + try await sendError(GraphQLError(message: errorMessage), id: id) } - + /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLWSError) { + private func error(_ error: GraphQLWSError) async throws { guard let messenger = messenger else { return } - messenger.error(error.message, code: error.code.rawValue) + try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index 0d95ccb..0ab46da 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -1,7 +1,6 @@ import Foundation import GraphQL -import NIO import XCTest @testable import GraphQLWS @@ -9,224 +8,212 @@ import XCTest class GraphqlWsTests: XCTestCase { var clientMessenger: TestMessenger! var serverMessenger: TestMessenger! - var server: Server! - var eventLoop: EventLoop! - + var server: Server>! + var context: TestContext! + var subscribeReady: Bool! = false + override func setUp() { // Point the client and server at each other clientMessenger = TestMessenger() serverMessenger = TestMessenger() clientMessenger.other = serverMessenger serverMessenger.other = clientMessenger - - eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1).next() + let api = TestAPI() let context = TestContext() - - server = Server( + + server = .init( messenger: serverMessenger, onExecute: { graphQLRequest in - api.execute( + try await api.execute( request: graphQLRequest.query, - context: context, - on: self.eventLoop + context: context ) }, onSubscribe: { graphQLRequest in - api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, - context: context, - on: self.eventLoop - ) - }, - eventLoop: self.eventLoop + context: context + ).get() + self.subscribeReady = true + return subscription + } ) + self.context = context } - + /// Tests that trying to run methods before `connection_init` is not allowed - func testInitialize() throws { - var messages = [String]() - let completeExpectation = XCTestExpectation() - + func testInitialize() async throws { let client = Client(messenger: clientMessenger) - client.onMessage { message, _ in - messages.append(message) - completeExpectation.fulfill() + let messageStream = AsyncThrowingStream { continuation in + client.onMessage { message, _ in + continuation.yield(message) + // Expect only one message + continuation.finish() + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } } - - client.sendStart( + + try await client.sendStart( payload: GraphQLRequest( query: """ - query { - hello - } - """ + query { + hello + } + """ ), id: UUID().uuidString ) - - wait(for: [completeExpectation], timeout: 2) + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) + } XCTAssertEqual( messages, ["\(ErrorCode.notInitialized): Connection not initialized"] ) } - + /// Tests that throwing in the authorization callback forces an unauthorized error - func testAuthWithThrow() throws { - server.auth { payload in + func testAuthWithThrow() async throws { + server.auth { _ in throw TestError.couldBeAnything } - - var messages = [String]() - let completeExpectation = XCTestExpectation() - + let client = Client(messenger: clientMessenger) - client.onMessage { message, _ in - messages.append(message) - completeExpectation.fulfill() + let messageStream = AsyncThrowingStream { continuation in + client.onMessage { message, _ in + continuation.yield(message) + // Expect only one message + continuation.finish() + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } } - - client.sendConnectionInit( + + try await client.sendConnectionInit( payload: TokenInitPayload( authToken: "" ) ) - - wait(for: [completeExpectation], timeout: 2) - XCTAssertEqual( - messages, - ["\(ErrorCode.unauthorized): Unauthorized"] - ) - } - - /// Tests that failing a future in the authorization callback forces an unauthorized error - func testAuthWithFailedFuture() throws { - server.auth { payload in - self.eventLoop.makeFailedFuture(TestError.couldBeAnything) - } - - var messages = [String]() - let completeExpectation = XCTestExpectation() - - let client = Client(messenger: clientMessenger) - client.onMessage { message, _ in - messages.append(message) - completeExpectation.fulfill() + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) } - - client.sendConnectionInit( - payload: TokenInitPayload( - authToken: "" - ) - ) - - wait(for: [completeExpectation], timeout: 2) XCTAssertEqual( messages, ["\(ErrorCode.unauthorized): Unauthorized"] ) } - + /// Test single op message flow works as expected - func testSingleOp() throws { + func testSingleOp() async throws { let id = UUID().description - - // Test single-op conversation - var messages = [String]() - let completeExpectation = XCTestExpectation() - + let client = Client(messenger: clientMessenger) - - client.onConnectionAck { _, client in - client.sendStart( - payload: GraphQLRequest( - query: """ - query { - hello - } - """ - ), - id: id - ) - } - client.onError { _, _ in - completeExpectation.fulfill() - } - client.onComplete { _, _ in - completeExpectation.fulfill() + let messageStream = AsyncThrowingStream { continuation in + client.onConnectionAck { _, client in + try await client.sendStart( + payload: GraphQLRequest( + query: """ + query { + hello + } + """ + ), + id: id + ) + } + client.onMessage { message, _ in + continuation.yield(message) + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } + client.onComplete { _, _ in + continuation.finish() + } } - client.onMessage { message, _ in - messages.append(message) + + try await client.sendConnectionInit(payload: TokenInitPayload(authToken: "")) + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) } - - client.sendConnectionInit(payload: TokenInitPayload(authToken: "")) - - wait(for: [completeExpectation], timeout: 2) XCTAssertEqual( messages.count, 3, // 1 connection_ack, 1 data, 1 complete "Messages: \(messages.description)" ) } - + /// Test streaming message flow works as expected - func testStreaming() throws { + func testStreaming() async throws { let id = UUID().description - - // Test streaming conversation - var messages = [String]() - let completeExpectation = XCTestExpectation() - + var dataIndex = 1 let dataIndexMax = 3 - + let client = Client(messenger: clientMessenger) - client.onConnectionAck { _, client in - client.sendStart( - payload: GraphQLRequest( - query: """ - subscription { - hello - } - """ - ), - id: id - ) - - // Short sleep to allow for server to register subscription - usleep(3000) - - pubsub.onNext("hello \(dataIndex)") - } - client.onData { _, _ in - dataIndex = dataIndex + 1 - if dataIndex <= dataIndexMax { - pubsub.onNext("hello \(dataIndex)") - } else { - pubsub.onCompleted() + let messageStream = AsyncThrowingStream { continuation in + client.onConnectionAck { _, client in + try await client.sendStart( + payload: GraphQLRequest( + query: """ + subscription { + hello + } + """ + ), + id: id + ) + + // Wait until server has registered subscription + var i = 0 + while !self.subscribeReady, i < 50 { + usleep(1000) + i = i + 1 + } + if i == 50 { + XCTFail("Subscription timeout: Took longer than 50ms to set up") + } + + self.context.publisher.emit(event: "hello \(dataIndex)") + } + client.onData { _, _ in + dataIndex = dataIndex + 1 + if dataIndex <= dataIndexMax { + self.context.publisher.emit(event: "hello \(dataIndex)") + } else { + self.context.publisher.cancel() + } + } + client.onMessage { message, _ in + continuation.yield(message) + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } + client.onComplete { _, _ in + continuation.finish() } } - client.onError { _, _ in - completeExpectation.fulfill() - } - client.onComplete { _, _ in - completeExpectation.fulfill() - } - client.onMessage { message, _ in - messages.append(message) + + try await client.sendConnectionInit(payload: TokenInitPayload(authToken: "")) + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) } - - client.sendConnectionInit(payload: TokenInitPayload(authToken: "")) - - wait(for: [completeExpectation], timeout: 2) XCTAssertEqual( messages.count, 5, // 1 connection_ack, 3 data, 1 complete "Messages: \(messages.description)" ) } - + enum TestError: Error { case couldBeAnything } diff --git a/Tests/GraphQLWSTests/Utils/TestAPI.swift b/Tests/GraphQLWSTests/Utils/TestAPI.swift index 4b7ea03..8867da1 100644 --- a/Tests/GraphQLWSTests/Utils/TestAPI.swift +++ b/Tests/GraphQLWSTests/Utils/TestAPI.swift @@ -1,15 +1,11 @@ import Foundation -import GraphQL import Graphiti -import GraphQLRxSwift -import RxSwift - -let pubsub = PublishSubject() +import GraphQL struct TestAPI: API { let resolver = TestResolver() let context = TestContext() - + let schema = try! Schema { Query { Field("hello", at: TestResolver.hello) @@ -20,7 +16,9 @@ struct TestAPI: API { } } -final class TestContext { +final class TestContext: Sendable { + let publisher = SimplePubSub() + func hello() -> String { "world" } @@ -30,8 +28,48 @@ struct TestResolver { func hello(context: TestContext, arguments _: NoArguments) -> String { context.hello() } - - func subscribeHello(context: TestContext, arguments: NoArguments) -> EventStream { - pubsub.toEventStream() + + func subscribeHello(context: TestContext, arguments _: NoArguments) -> AsyncThrowingStream { + context.publisher.subscribe() } } + +/// A very simple publish/subscriber used for testing +class SimplePubSub: @unchecked Sendable { + private var subscribers: [Subscriber] + + init() { + subscribers = [] + } + + func emit(event: T) { + for subscriber in subscribers { + subscriber.callback(event) + } + } + + func cancel() { + for subscriber in subscribers { + subscriber.cancel() + } + } + + func subscribe() -> AsyncThrowingStream { + return AsyncThrowingStream { continuation in + let subscriber = Subscriber( + callback: { newValue in + continuation.yield(newValue) + }, + cancel: { + continuation.finish() + } + ) + subscribers.append(subscriber) + } + } +} + +struct Subscriber { + let callback: (T) -> Void + let cancel: () -> Void +} diff --git a/Tests/GraphQLWSTests/Utils/TestMessenger.swift b/Tests/GraphQLWSTests/Utils/TestMessenger.swift index 5d3fde3..803f080 100644 --- a/Tests/GraphQLWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLWSTests/Utils/TestMessenger.swift @@ -7,32 +7,28 @@ import Foundation /// /// Note that this only retains a weak reference to 'other', so the client should retain references /// or risk them being deinitialized early -class TestMessenger: Messenger { +class TestMessenger: Messenger, @unchecked Sendable { weak var other: TestMessenger? - var onReceive: (String) -> Void = { _ in } + var onReceive: (String) async throws -> Void = { _ in } let queue: DispatchQueue = .init(label: "Test messenger") - + init() {} - - func send(_ message: S) where S: Collection, S.Element == Character { + + func send(_ message: S) async throws where S: Collection, S.Element == Character { guard let other = other else { return } - - // Run the other message asyncronously to avoid nesting issues - queue.async { - other.onReceive(String(message)) - } + try await other.onReceive(String(message)) } - - func onReceive(callback: @escaping (String) -> Void) { - self.onReceive = callback + + func onReceive(callback: @escaping (String) async throws -> Void) { + onReceive = callback } - - func error(_ message: String, code: Int) { - self.send("\(code): \(message)") + + func error(_ message: String, code: Int) async throws { + try await send("\(code): \(message)") } - + func close() { // This is a testing no-op }