From 55bbbbb4bda5d2e58f7210d60fbca75c7d8951b3 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Sun, 27 Jul 2025 00:37:51 -0600 Subject: [PATCH 1/6] feat!: Converts to Swift Concurrency This adopts GraphQL v4, Graphiti v3, and removes the RxSwift and NIO dependencies. --- Package.resolved | 39 +-- Package.swift | 14 +- README.md | 36 +-- Sources/GraphQLTransportWS/Client.swift | 139 ++++---- .../GraphqlTransportWSError.swift | 36 +-- Sources/GraphQLTransportWS/InitPayloads.swift | 4 +- .../GraphQLTransportWS/JsonEncodable.swift | 3 +- Sources/GraphQLTransportWS/Messenger.swift | 19 +- Sources/GraphQLTransportWS/Requests.swift | 4 +- Sources/GraphQLTransportWS/Responses.swift | 34 +- Sources/GraphQLTransportWS/Server.swift | 304 ++++++++---------- .../GraphQLTransportWSTests.swift | 270 +++++++--------- .../Utils/TestAPI.swift | 56 +++- .../Utils/TestMessenger.swift | 28 +- 14 files changed, 473 insertions(+), 513 deletions(-) diff --git a/Package.resolved b/Package.resolved index 7ce8815..6f88d12 100644 --- a/Package.resolved +++ b/Package.resolved @@ -6,8 +6,8 @@ "repositoryURL": "https://github.com/GraphQLSwift/Graphiti.git", "state": { "branch": null, - "revision": "c9bc9d1cc9e62e71a824dc178630bfa8b8a6e2a4", - "version": "1.0.0" + "revision": "a23a3d232df202fc158ad2d698926325b470523c", + "version": "3.0.0" } }, { @@ -15,26 +15,8 @@ "repositoryURL": "https://github.com/GraphQLSwift/GraphQL.git", "state": { "branch": null, - "revision": "283cc4de56b994a00b2724328221b7a1bc846ddc", - "version": "2.2.1" - } - }, - { - "package": "GraphQLRxSwift", - "repositoryURL": "https://github.com/GraphQLSwift/GraphQLRxSwift.git", - "state": { - "branch": null, - "revision": "c7ec6595f92ef5d77c06852e4acc4cd46a753622", - "version": "0.0.4" - } - }, - { - "package": "RxSwift", - "repositoryURL": "https://github.com/ReactiveX/RxSwift.git", - "state": { - "branch": null, - "revision": "b4307ba0b6425c0ba4178e138799946c3da594f8", - "version": "6.5.0" + "revision": "eedec2bbfcfd0c10c2eaee8ac2f91bde5af28b8c", + "version": "4.0.0" } }, { @@ -42,17 +24,8 @@ "repositoryURL": "https://github.com/apple/swift-collections", "state": { "branch": null, - "revision": "48254824bb4248676bf7ce56014ff57b142b77eb", - "version": "1.0.2" - } - }, - { - "package": "swift-nio", - "repositoryURL": "https://github.com/apple/swift-nio.git", - "state": { - "branch": null, - "revision": "6aa9347d9bc5bbfe6a84983aec955c17ffea96ef", - "version": "2.33.0" + "revision": "8c0c0a8b49e080e54e5e328cc552821ff07cd341", + "version": "1.2.1" } } ] diff --git a/Package.swift b/Package.swift index 2d77e15..cdad552 100644 --- a/Package.swift +++ b/Package.swift @@ -4,6 +4,7 @@ import PackageDescription let package = Package( name: "GraphQLTransportWS", + platforms: [.macOS(.v10_15)], products: [ .library( name: "GraphQLTransportWS", @@ -11,22 +12,17 @@ let package = Package( ), ], dependencies: [ - .package(name: "Graphiti", url: "https://github.com/GraphQLSwift/Graphiti.git", from: "1.0.0"), - .package(name: "GraphQL", url: "https://github.com/GraphQLSwift/GraphQL.git", from: "2.2.1"), - .package(name: "GraphQLRxSwift", url: "https://github.com/GraphQLSwift/GraphQLRxSwift.git", from: "0.0.4"), - .package(name: "RxSwift", url: "https://github.com/ReactiveX/RxSwift.git", from: "6.1.0"), - .package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.33.0"), + .package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"), + .package(url: "https://github.com/GraphQLSwift/GraphQL.git", from: "4.0.0"), ], targets: [ .target( name: "GraphQLTransportWS", dependencies: [ .product(name: "Graphiti", package: "Graphiti"), - .product(name: "GraphQLRxSwift", package: "GraphQLRxSwift"), .product(name: "GraphQL", package: "GraphQL"), - .product(name: "NIO", package: "swift-nio"), - .product(name: "RxSwift", package: "RxSwift") - ]), + ] + ), .testTarget( name: "GraphQLTransportWSTests", dependencies: ["GraphQLTransportWS"] diff --git a/README.md b/README.md index 65b6646..f9d2d2b 100644 --- a/README.md +++ b/README.md @@ -27,32 +27,32 @@ import GraphQLTransportWS /// Messenger wrapper for WebSockets class WebSocketMessenger: Messenger { private weak var websocket: WebSocket? - private var onReceive: (String) -> Void = { _ in } - + private var onReceive: (String) async throws -> Void = { _ in } + init(websocket: WebSocket) { self.websocket = websocket websocket.onText { _, message in - self.onReceive(message) + try await self.onReceive(message) } } - - func send(_ message: S) where S: Collection, S.Element == Character { + + func send(_ message: S) where S: Collection, S.Element == Character async throws { guard let websocket = websocket else { return } - websocket.send(message) + try await websocket.send(message) } - - func onReceive(callback: @escaping (String) -> Void) { + + func onReceive(callback: @escaping (String) async throws -> Void) { self.onReceive = callback } - - func error(_ message: String, code: Int) { + + func error(_ message: String, code: Int) async throws { guard let websocket = websocket else { return } - websocket.send("\(code): \(message)") + try await websocket.send("\(code): \(message)") } - - func close() { + + func close() async throws { guard let websocket = websocket else { return } - _ = websocket.close() + try await websocket.close() } } ``` @@ -67,7 +67,7 @@ routes.webSocket( let server = GraphQLTransportWS.Server( messenger: messenger, onExecute: { graphQLRequest in - api.execute( + try await api.execute( request: graphQLRequest.query, context: context, on: self.eventLoop, @@ -76,7 +76,7 @@ routes.webSocket( ) }, onSubscribe: { graphQLRequest in - api.subscribe( + try await api.subscribe( request: graphQLRequest.query, context: context, on: self.eventLoop, @@ -128,8 +128,8 @@ If the `payload` field is not required on your server, you may make Server's gen ## Memory Management -Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket -implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the +Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket +implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the objects required for responses. In order to align cleanly and avoid memory cycles, Server and Client are injected strongly into Messenger callbacks, and only hold weak references to their Messenger. This means that Messenger objects (or their enclosing WebSocket) must be persisted to have the connected Server or Client objects function. That is, if a Server's Messenger falls out of scope and deinitializes, diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index c4eaa62..bcd82f1 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -5,16 +5,16 @@ import GraphQL public class Client { // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? - - var onConnectionAck: (ConnectionAckResponse, Client) -> Void = { _, _ in } - var onNext: (NextResponse, Client) -> Void = { _, _ in } - var onError: (ErrorResponse, Client) -> Void = { _, _ in } - var onComplete: (CompleteResponse, Client) -> Void = { _, _ in } - var onMessage: (String, Client) -> Void = { _, _ in } - + + var onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void = { _, _ in } + var onNext: (NextResponse, Client) async throws -> Void = { _, _ in } + var onError: (ErrorResponse, Client) async throws -> Void = { _, _ in } + var onComplete: (CompleteResponse, Client) async throws -> Void = { _, _ in } + var onMessage: (String, Client) async throws -> Void = { _, _ in } + let encoder = GraphQLJSONEncoder() let decoder = JSONDecoder() - + /// Create a new client. /// /// - Parameters: @@ -24,123 +24,122 @@ public class Client { ) { self.messenger = messenger messenger.onReceive { message in - self.onMessage(message, self) - + try await self.onMessage(message, self) + // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages return } - + guard let json = message.data(using: .utf8) else { - self.error(.invalidEncoding()) + try await self.error(.invalidEncoding()) return } - + let response: Response do { response = try self.decoder.decode(Response.self, from: json) - } - catch { - self.error(.noType()) + } catch { + try await self.error(.noType()) return } - + switch response.type { - case .connectionAck: - guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .connectionAck)) - return - } - self.onConnectionAck(connectionAckResponse, self) - case .next: - guard let nextResponse = try? self.decoder.decode(NextResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .next)) - return - } - self.onNext(nextResponse, self) - case .error: - guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .error)) - return - } - self.onError(errorResponse, self) - case .complete: - guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else { - self.error(.invalidResponseFormat(messageType: .complete)) - return - } - self.onComplete(completeResponse, self) - case .unknown: - self.error(.invalidType()) + case .connectionAck: + guard let connectionAckResponse = try? self.decoder.decode(ConnectionAckResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .connectionAck)) + return + } + try await self.onConnectionAck(connectionAckResponse, self) + case .next: + guard let nextResponse = try? self.decoder.decode(NextResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .next)) + return + } + try await self.onNext(nextResponse, self) + case .error: + guard let errorResponse = try? self.decoder.decode(ErrorResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .error)) + return + } + try await self.onError(errorResponse, self) + case .complete: + guard let completeResponse = try? self.decoder.decode(CompleteResponse.self, from: json) else { + try await self.error(.invalidResponseFormat(messageType: .complete)) + return + } + try await self.onComplete(completeResponse, self) + case .unknown: + try await self.error(.invalidType()) } } } - + /// Define the callback run on receipt of a `connection_ack` message /// - Parameter callback: The callback to assign - public func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) -> Void) { - self.onConnectionAck = callback + public func onConnectionAck(_ callback: @escaping (ConnectionAckResponse, Client) async throws -> Void) { + onConnectionAck = callback } - + /// Define the callback run on receipt of a `next` message /// - Parameter callback: The callback to assign - public func onNext(_ callback: @escaping (NextResponse, Client) -> Void) { - self.onNext = callback + public func onNext(_ callback: @escaping (NextResponse, Client) async throws -> Void) { + onNext = callback } - + /// Define the callback run on receipt of an `error` message /// - Parameter callback: The callback to assign - public func onError(_ callback: @escaping (ErrorResponse, Client) -> Void) { - self.onError = callback + public func onError(_ callback: @escaping (ErrorResponse, Client) async throws -> Void) { + onError = callback } - + /// Define the callback run on receipt of a `complete` message /// - Parameter callback: The callback to assign - public func onComplete(_ callback: @escaping (CompleteResponse, Client) -> Void) { - self.onComplete = callback + public func onComplete(_ callback: @escaping (CompleteResponse, Client) async throws -> Void) { + onComplete = callback } - + /// Define the callback run on receipt of any message /// - Parameter callback: The callback to assign - public func onMessage(_ callback: @escaping (String, Client) -> Void) { - self.onMessage = callback + public func onMessage(_ callback: @escaping (String, Client) async throws -> Void) { + onMessage = callback } - + /// Send a `connection_init` request through the messenger - public func sendConnectionInit(payload: InitPayload) { + public func sendConnectionInit(payload: InitPayload) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ConnectionInitRequest( payload: payload ).toJSON(encoder) ) } - + /// Send a `subscribe` request through the messenger - public func sendStart(payload: GraphQLRequest, id: String) { + public func sendStart(payload: GraphQLRequest, id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( SubscribeRequest( payload: payload, id: id ).toJSON(encoder) ) } - + /// Send a `complete` request through the messenger - public func sendStop(id: String) { + public func sendStop(id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( CompleteRequest( id: id ).toJSON(encoder) ) } - + /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLTransportWSError) { + private func error(_ error: GraphQLTransportWSError) async throws { guard let messenger = messenger else { return } - messenger.error(error.message, code: error.code.rawValue) + try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift index 2ccf56f..5dc1d78 100644 --- a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift +++ b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift @@ -3,82 +3,82 @@ import GraphQL struct GraphQLTransportWSError: Error { let message: String let code: ErrorCode - + init(_ message: String, code: ErrorCode) { self.message = message self.code = code } - + static func unauthorized() -> Self { return self.init( "Unauthorized", code: .unauthorized ) } - + static func notInitialized() -> Self { return self.init( "Connection not initialized", code: .notInitialized ) } - + static func tooManyInitializations() -> Self { return self.init( "Too many initialisation requests", code: .tooManyInitializations ) } - + static func subscriberAlreadyExists(id: String) -> Self { return self.init( "Subscriber for \(id) already exists", code: .subscriberAlreadyExists ) } - + static func invalidEncoding() -> Self { return self.init( "Message was not encoded in UTF8", code: .invalidEncoding ) } - + static func noType() -> Self { return self.init( "Message has no 'type' field", code: .noType ) } - + static func invalidType() -> Self { return self.init( "Message 'type' value does not match supported types", code: .invalidType ) } - + static func invalidRequestFormat(messageType: RequestMessageType) -> Self { return self.init( "Request message doesn't match '\(messageType.rawValue)' JSON format", code: .invalidRequestFormat ) } - + static func invalidResponseFormat(messageType: ResponseMessageType) -> Self { return self.init( "Response message doesn't match '\(messageType.rawValue)' JSON format", code: .invalidResponseFormat ) } - + static func internalAPIStreamIssue(errors: [GraphQLError]) -> Self { return self.init( - "API Response did not result in a stream type, contained errors\n \(errors.map { $0.message}.joined(separator: "\n"))", + "API Response did not result in a stream type, contained errors\n \(errors.map { $0.message }.joined(separator: "\n"))", code: .internalAPIStreamIssue ) } - + static func graphQLError(_ error: Error) -> Self { return self.init( "\(error)", @@ -91,25 +91,25 @@ struct GraphQLTransportWSError: Error { public enum ErrorCode: Int, CustomStringConvertible { // Miscellaneous case miscellaneous = 4400 - + // Internal errors case graphQLError = 4401 case internalAPIStreamIssue = 4402 - + // Message errors case invalidEncoding = 4410 case noType = 4411 case invalidType = 4412 case invalidRequestFormat = 4413 case invalidResponseFormat = 4414 - + // Initialization errors case unauthorized = 4430 case notInitialized = 4431 case tooManyInitializations = 4432 case subscriberAlreadyExists = 4433 - + public var description: String { - return "\(self.rawValue)" + return "\(rawValue)" } } diff --git a/Sources/GraphQLTransportWS/InitPayloads.swift b/Sources/GraphQLTransportWS/InitPayloads.swift index 41a6cc2..966c3a9 100644 --- a/Sources/GraphQLTransportWS/InitPayloads.swift +++ b/Sources/GraphQLTransportWS/InitPayloads.swift @@ -1,12 +1,12 @@ // Contains convenient `connection_init` payloads for users of this package /// `connection_init` `payload` that is empty -public struct EmptyInitPayload: Equatable & Codable { } +public struct EmptyInitPayload: Equatable & Codable {} /// `connection_init` `payload` that includes an `authToken` field public struct TokenInitPayload: Equatable & Codable { public let authToken: String - + public init(authToken: String) { self.authToken = authToken } diff --git a/Sources/GraphQLTransportWS/JsonEncodable.swift b/Sources/GraphQLTransportWS/JsonEncodable.swift index 51c8673..b54f881 100644 --- a/Sources/GraphQLTransportWS/JsonEncodable.swift +++ b/Sources/GraphQLTransportWS/JsonEncodable.swift @@ -12,8 +12,7 @@ extension JsonEncodable { let data: Data do { data = try encoder.encode(self) - } - catch { + } catch { return EncodingErrorResponse("Unable to encode response").toJSON(encoder) } guard let body = String(data: data, encoding: .utf8) else { diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index 7e01402..3a9c157 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -1,23 +1,22 @@ import Foundation -import NIO -/// Protocol for an object that can send and recieve messages. This allows mocking in tests. +/// Protocol for an object that can send and recieve messages. This allows mocking in tests public protocol Messenger: AnyObject { // AnyObject compliance requires that the implementing object is a class and we can reference it weakly - + /// Send a message through this messenger /// - Parameter message: The message to send - func send(_ message: S) -> Void where S: Collection, S.Element == Character - + func send(_ message: S) async throws -> Void where S: Collection, S.Element == Character + /// Set the callback that should be run when a message is recieved - func onReceive(callback: @escaping (String) -> Void) -> Void - + func onReceive(callback: @escaping (String) async throws -> Void) + /// Close the messenger - func close() -> Void - + func close() async throws + /// Indicate that the messenger experienced an error. /// - Parameters: /// - message: The message describing the error /// - code: An error code - func error(_ message: String, code: Int) -> Void + func error(_ message: String, code: Int) async throws } diff --git a/Sources/GraphQLTransportWS/Requests.swift b/Sources/GraphQLTransportWS/Requests.swift index 48d474b..98267ca 100644 --- a/Sources/GraphQLTransportWS/Requests.swift +++ b/Sources/GraphQLTransportWS/Requests.swift @@ -42,8 +42,8 @@ enum RequestMessageType: String, Codable { case subscribe case complete case unknown - - public init(from decoder: Decoder) throws { + + init(from decoder: Decoder) throws { guard let value = try? decoder.singleValueContainer().decode(String.self) else { self = .unknown return diff --git a/Sources/GraphQLTransportWS/Responses.swift b/Sources/GraphQLTransportWS/Responses.swift index d8cab1e..34da4fd 100644 --- a/Sources/GraphQLTransportWS/Responses.swift +++ b/Sources/GraphQLTransportWS/Responses.swift @@ -10,9 +10,9 @@ struct Response: Equatable, JsonEncodable { public struct ConnectionAckResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: [String: Map]? - + init(_ payload: [String: Map]? = nil) { - self.type = .connectionAck + type = .connectionAck self.payload = payload } } @@ -22,9 +22,9 @@ public struct NextResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: GraphQLResult? public let id: String - + init(_ payload: GraphQLResult? = nil, id: String) { - self.type = .next + type = .next self.payload = payload self.id = id } @@ -34,9 +34,9 @@ public struct NextResponse: Equatable, JsonEncodable { public struct CompleteResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let id: String - + init(id: String) { - self.type = .complete + type = .complete self.id = id } } @@ -46,18 +46,18 @@ public struct ErrorResponse: Equatable, JsonEncodable { let type: ResponseMessageType public let payload: [GraphQLError] public let id: String - + init(_ errors: [Error], id: String) { let graphQLErrors = errors.map { error -> GraphQLError in switch error { - case let graphQLError as GraphQLError: - return graphQLError - default: - return GraphQLError(error) + case let graphQLError as GraphQLError: + return graphQLError + default: + return GraphQLError(error) } } - self.type = .error - self.payload = graphQLErrors + type = .error + payload = graphQLErrors self.id = id } } @@ -69,7 +69,7 @@ enum ResponseMessageType: String, Codable { case error case complete case unknown - + init(from decoder: Decoder) throws { guard let value = try? decoder.singleValueContainer().decode(String.self) else { self = .unknown @@ -84,9 +84,9 @@ enum ResponseMessageType: String, Codable { struct EncodingErrorResponse: Equatable, Codable, JsonEncodable { let type: ResponseMessageType let payload: [String: String] - + init(_ errorMessage: String) { - self.type = .error - self.payload = ["error": errorMessage] + type = .error + payload = ["error": errorMessage] } } diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index b18c2ec..605df8f 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -1,8 +1,5 @@ import Foundation import GraphQL -import GraphQLRxSwift -import NIO -import RxSwift /// Server implements the server-side portion of the protocol, allowing a few callbacks for customization. /// @@ -10,267 +7,250 @@ import RxSwift public class Server { // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? - - let onExecute: (GraphQLRequest) -> EventLoopFuture - let onSubscribe: (GraphQLRequest) -> EventLoopFuture - var auth: (InitPayload) throws -> EventLoopFuture - - var onExit: () -> Void = { } - var onOperationComplete: (String) -> Void = { _ in } - var onOperationError: (String) -> Void = { _ in } - var onMessage: (String) -> Void = { _ in } - + + let onExecute: (GraphQLRequest) async throws -> GraphQLResult + let onSubscribe: (GraphQLRequest) async throws -> Result, GraphQLErrors> + var auth: (InitPayload) async throws -> Void + + var onExit: () async throws -> Void = {} + var onMessage: (String) async throws -> Void = { _ in } + var onOperationComplete: (String) async throws -> Void = { _ in } + var onOperationError: (String, [Error]) async throws -> Void = { _, _ in } + var initialized = false - - let disposeBag = DisposeBag() - let encoder = GraphQLJSONEncoder() + let decoder = JSONDecoder() - + let encoder = GraphQLJSONEncoder() + + private var subscriptionTasks = [String: Task]() + /// Create a new server /// /// - Parameters: /// - messenger: The messenger to bind the server to. - /// - onExecute: Callback run during `subscribe` resolution for non-streaming queries. Typically this is `API.execute`. - /// - onSubscribe: Callback run during `subscribe` resolution for streaming queries. Typically this is `API.subscribe`. - /// - eventLoop: EventLoop on which to perform server operations. + /// - onExecute: Callback run during `start` resolution for non-streaming queries. Typically this is `API.execute`. + /// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`. public init( messenger: Messenger, - onExecute: @escaping (GraphQLRequest) -> EventLoopFuture, - onSubscribe: @escaping (GraphQLRequest) -> EventLoopFuture, - eventLoop: EventLoop + onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult, + onSubscribe: @escaping (GraphQLRequest) async throws -> Result, GraphQLErrors> ) { self.messenger = messenger self.onExecute = onExecute self.onSubscribe = onSubscribe - self.auth = { _ in eventLoop.makeSucceededVoidFuture() } - + auth = { _ in } + messenger.onReceive { message in - self.onMessage(message) - + guard let messenger = self.messenger else { return } + + try await self.onMessage(message) + // Detect and ignore error responses. if message.starts(with: "44") { // TODO: Determine what to do with returned error messages return } - - guard let data = message.data(using: .utf8) else { - self.error(.invalidEncoding()) + + guard let json = message.data(using: .utf8) else { + try await self.error(.invalidEncoding()) return } - + let request: Request do { - request = try self.decoder.decode(Request.self, from: data) - } - catch { - self.error(.noType()) + request = try self.decoder.decode(Request.self, from: json) + } catch { + try await self.error(.noType()) return } - + // handle incoming message switch request.type { - case .connectionInit: - guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: data) else { - self.error(.invalidRequestFormat(messageType: .connectionInit)) - return - } - self.onConnectionInit(connectionInitRequest) - case .subscribe: - guard let subscribeRequest = try? self.decoder.decode(SubscribeRequest.self, from: data) else { - self.error(.invalidRequestFormat(messageType: .subscribe)) - return - } - self.onSubscribe(subscribeRequest) - case .complete: - guard let completeRequest = try? self.decoder.decode(CompleteRequest.self, from: data) else { - self.error(.invalidRequestFormat(messageType: .complete)) - return - } - self.onOperationComplete(completeRequest.id) - case .unknown: - self.error(.invalidType()) + case .connectionInit: + guard let connectionInitRequest = try? self.decoder.decode(ConnectionInitRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .connectionInit)) + return + } + try await self.onConnectionInit(connectionInitRequest, messenger) + case .subscribe: + guard let subscribeRequest = try? self.decoder.decode(SubscribeRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .subscribe)) + return + } + try await self.onSubscribe(subscribeRequest) + case .complete: + guard let completeRequest = try? self.decoder.decode(CompleteRequest.self, from: json) else { + try await self.error(.invalidRequestFormat(messageType: .complete)) + return + } + try await self.onOperationComplete(completeRequest.id) + case .unknown: + try await self.error(.invalidType()) } } } - - /// Define the callback run during `connection_init` resolution that allows authorization using the `payload`. - /// Throw or fail the future to indicate that authorization has failed. - /// - Parameter callback: The callback to assign - public func auth(_ callback: @escaping (InitPayload) throws -> EventLoopFuture) { - self.auth = callback + + /// Define a custom callback run during `connection_init` resolution that allows authorization using the `payload`. + /// Throw from this closure to indicate that authorization has failed. + /// - Parameter callback: The callback to assign + public func auth(_ callback: @escaping (InitPayload) async throws -> Void) { + auth = callback } - + /// Define the callback run when the communication is shut down, either by the client or server /// - Parameter callback: The callback to assign public func onExit(_ callback: @escaping () -> Void) { - self.onExit = callback + onExit = callback } - + /// Define the callback run on receipt of any message /// - Parameter callback: The callback to assign public func onMessage(_ callback: @escaping (String) -> Void) { - self.onMessage = callback + onMessage = callback } - + /// Define the callback run on the completion a full operation (query/mutation, end of subscription) - /// - Parameter callback: The callback to assign, taking a string parameter for the ID of the operation + /// - Parameter callback: The callback to assign public func onOperationComplete(_ callback: @escaping (String) -> Void) { - self.onOperationComplete = callback + onOperationComplete = callback } - + /// Define the callback to run on error of any full operation (failed query, interrupted subscription) - /// - Parameter callback: The callback to assign, taking a string parameter for the ID of the operation - public func onOperationError(_ callback: @escaping (String) -> Void) { - self.onOperationError = callback + /// - Parameter callback: The callback to assign + public func onOperationError(_ callback: @escaping (String, [Error]) -> Void) { + onOperationError = callback } - - private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest) { + + private func onConnectionInit(_ connectionInitRequest: ConnectionInitRequest, _: Messenger) async throws { guard !initialized else { - self.error(.tooManyInitializations()) + try await error(.tooManyInitializations()) return } - + do { - let authResult = try self.auth(connectionInitRequest.payload) - authResult.whenSuccess { - self.initialized = true - self.sendConnectionAck() - } - authResult.whenFailure { error in - self.error(.unauthorized()) - return - } - } - catch { - self.error(.unauthorized()) + try await auth(connectionInitRequest.payload) + } catch { + try await self.error(.unauthorized()) return } + initialized = true + try await sendConnectionAck() + // TODO: Should we send the `ka` message? } - - private func onSubscribe(_ subscribeRequest: SubscribeRequest) { + + private func onSubscribe(_ subscribeRequest: SubscribeRequest) async throws { guard initialized else { - self.error(.notInitialized()) + try await error(.notInitialized()) return } - + let id = subscribeRequest.id + if subscriptionTasks[id] != nil { + try await error(.subscriberAlreadyExists(id: id)) + } + let graphQLRequest = subscribeRequest.payload - + var isStreaming = false do { isStreaming = try graphQLRequest.isSubscription() - } - catch { - self.sendError(error, id: id) + } catch { + try await sendError(error, id: id) return } - + if isStreaming { - let subscribeFuture = onSubscribe(graphQLRequest) - subscribeFuture.whenSuccess { [weak self] result in - guard let self = self else { return } - guard let streamOpt = result.stream else { - // API issue - subscribe resolver isn't stream - self.sendError(result.errors, id: id) + do { + let result = try await onSubscribe(graphQLRequest) + let stream: AsyncThrowingStream + do { + stream = try result.get() + } catch { + try await sendError(error, id: id) return } - let stream = streamOpt as! ObservableSubscriptionEventStream - let observable = stream.observable - - observable.subscribe( - onNext: { [weak self] resultFuture in - guard let self = self else { return } - resultFuture.whenSuccess { result in - self.sendNext(result, id: id) - } - resultFuture.whenFailure { error in - self.sendError(error, id: id) + subscriptionTasks[id] = Task { + for try await event in stream { + try Task.checkCancellation() + do { + try await self.sendNext(event, id: id) + } catch { + try await self.sendError(error, id: id) + throw error } - }, - onError: { [weak self] error in - guard let self = self else { return } - self.sendError(error, id: id) - }, - onCompleted: { [weak self] in - guard let self = self else { return } - self.sendComplete(id: id) } - ).disposed(by: self.disposeBag) - } - subscribeFuture.whenFailure { error in - self.sendError(error, id: id) - } - } - else { - let executeFuture = onExecute(graphQLRequest) - executeFuture.whenSuccess { result in - self.sendNext(result, id: id) - self.sendComplete(id: id) - self.messenger?.close() + try await self.sendComplete(id: id) + } + } catch { + try await sendError(error, id: id) } - executeFuture.whenFailure { error in - self.sendError(error, id: id) - self.sendComplete(id: id) - self.messenger?.close() + } else { + do { + let result = try await onExecute(graphQLRequest) + try await sendNext(result, id: id) + try await sendComplete(id: id) + } catch { + try await sendError(error, id: id) } + try await messenger?.close() } } - + /// Send a `connection_ack` response through the messenger - private func sendConnectionAck(_ payload: [String: Map]? = nil) { + private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ConnectionAckResponse(payload).toJSON(encoder) ) } - + /// Send a `next` response through the messenger - private func sendNext(_ payload: GraphQLResult? = nil, id: String) { + private func sendNext(_ payload: GraphQLResult? = nil, id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( NextResponse( payload, id: id ).toJSON(encoder) ) } - + /// Send a `complete` response through the messenger - private func sendComplete(id: String) { + private func sendComplete(id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( CompleteResponse( id: id ).toJSON(encoder) ) - self.onOperationComplete(id) + try await onOperationComplete(id) } - + /// Send an `error` response through the messenger - private func sendError(_ errors: [Error], id: String) { + private func sendError(_ errors: [Error], id: String) async throws { guard let messenger = messenger else { return } - messenger.send( + try await messenger.send( ErrorResponse( errors, id: id ).toJSON(encoder) ) - self.onOperationError(id) + try await onOperationError(id, errors) } - + /// Send an `error` response through the messenger - private func sendError(_ error: Error, id: String) { - self.sendError([error], id: id) + private func sendError(_ error: Error, id: String) async throws { + try await sendError([error], id: id) } - + /// Send an `error` response through the messenger - private func sendError(_ errorMessage: String, id: String) { - self.sendError(GraphQLError(message: errorMessage), id: id) + private func sendError(_ errorMessage: String, id: String) async throws { + try await sendError(GraphQLError(message: errorMessage), id: id) } - + /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLTransportWSError) { + private func error(_ error: GraphQLTransportWSError) async throws { guard let messenger = messenger else { return } - messenger.error(error.message, code: error.code.rawValue) + try await messenger.error(error.message, code: error.code.rawValue) } } diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index 72f067d..3efc67a 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -1,7 +1,6 @@ import Foundation import GraphQL -import NIO import XCTest @testable import GraphQLTransportWS @@ -10,228 +9,209 @@ class GraphqlTransportWSTests: XCTestCase { var clientMessenger: TestMessenger! var serverMessenger: TestMessenger! var server: Server! - var eventLoop: EventLoop! - + var context: TestContext! + override func setUp() { // Point the client and server at each other clientMessenger = TestMessenger() serverMessenger = TestMessenger() clientMessenger.other = serverMessenger serverMessenger.other = clientMessenger - - eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1).next() + let api = TestAPI() let context = TestContext() - + server = Server( messenger: serverMessenger, onExecute: { graphQLRequest in - api.execute( + try await api.execute( request: graphQLRequest.query, - context: context, - on: self.eventLoop + context: context ) }, onSubscribe: { graphQLRequest in - api.subscribe( + try await api.subscribe( request: graphQLRequest.query, - context: context, - on: self.eventLoop + context: context ) - }, - eventLoop: self.eventLoop + } ) + self.context = context } - + /// Tests that trying to run methods before `connection_init` is not allowed - func testInitialize() throws { - var messages = [String]() - let completeExpectation = XCTestExpectation() - + func testInitialize() async throws { let client = Client(messenger: clientMessenger) - client.onMessage { message, _ in - messages.append(message) - completeExpectation.fulfill() + let messageStream = AsyncThrowingStream { continuation in + client.onMessage { message, _ in + continuation.yield(message) + // Expect only one message + continuation.finish() + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } } - - client.sendStart( + + try await client.sendStart( payload: GraphQLRequest( query: """ - query { - hello - } - """ + query { + hello + } + """ ), id: UUID().uuidString ) - - wait(for: [completeExpectation], timeout: 2) + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) + } XCTAssertEqual( messages, ["\(ErrorCode.notInitialized): Connection not initialized"] ) } - + /// Tests that throwing in the authorization callback forces an unauthorized error - func testAuthWithThrow() throws { - server.auth { payload in + func testAuthWithThrow() async throws { + server.auth { _ in throw TestError.couldBeAnything } - - var messages = [String]() - let completeExpectation = XCTestExpectation() - + let client = Client(messenger: clientMessenger) - client.onMessage { message, _ in - messages.append(message) - completeExpectation.fulfill() + let messageStream = AsyncThrowingStream { continuation in + client.onMessage { message, _ in + continuation.yield(message) + // Expect only one message + continuation.finish() + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } } - - client.sendConnectionInit( + + try await client.sendConnectionInit( payload: TokenInitPayload( authToken: "" ) ) - - wait(for: [completeExpectation], timeout: 2) - XCTAssertEqual( - messages, - ["\(ErrorCode.unauthorized): Unauthorized"] - ) - } - - /// Tests that failing a future in the authorization callback forces an unauthorized error - func testAuthWithFailedFuture() throws { - server.auth { payload in - self.eventLoop.makeFailedFuture(TestError.couldBeAnything) - } - - var messages = [String]() - let completeExpectation = XCTestExpectation() - - let client = Client(messenger: clientMessenger) - client.onMessage { message, _ in - messages.append(message) - completeExpectation.fulfill() + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) } - - client.sendConnectionInit( - payload: TokenInitPayload( - authToken: "" - ) - ) - - wait(for: [completeExpectation], timeout: 2) XCTAssertEqual( messages, ["\(ErrorCode.unauthorized): Unauthorized"] ) } - + /// Tests a single-op conversation - func testSingleOp() throws { + func testSingleOp() async throws { let id = UUID().description - - var messages = [String]() - let completeExpectation = XCTestExpectation() - + let client = Client(messenger: clientMessenger) - client.onConnectionAck { _, client in - client.sendStart( - payload: GraphQLRequest( - query: """ - query { - hello - } - """ - ), - id: id - ) - } - client.onError { _, _ in - completeExpectation.fulfill() - } - client.onComplete { _, _ in - completeExpectation.fulfill() - } - client.onMessage { message, _ in - messages.append(message) + let messageStream = AsyncThrowingStream { continuation in + client.onConnectionAck { _, client in + try await client.sendStart( + payload: GraphQLRequest( + query: """ + query { + hello + } + """ + ), + id: id + ) + } + client.onMessage { message, _ in + continuation.yield(message) + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } + client.onComplete { _, _ in + continuation.finish() + } } - - client.sendConnectionInit( + + try await client.sendConnectionInit( payload: TokenInitPayload( authToken: "" ) ) - - wait(for: [completeExpectation], timeout: 2) + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) + } XCTAssertEqual( messages.count, 3, // 1 connection_ack, 1 next, 1 complete "Messages: \(messages.description)" ) } - + /// Tests a streaming conversation - func testStreaming() throws { + func testStreaming() async throws { let id = UUID().description - - var messages = [String]() - let completeExpectation = XCTestExpectation() - + var dataIndex = 1 let dataIndexMax = 3 - + let client = Client(messenger: clientMessenger) - client.onConnectionAck { _, client in - client.sendStart( - payload: GraphQLRequest( - query: """ - subscription { - hello - } - """ - ), - id: id - ) - - // Short sleep to allow for server to register subscription - usleep(3000) - - pubsub.onNext("hello \(dataIndex)") - } - client.onNext { _, _ in - dataIndex = dataIndex + 1 - if dataIndex <= dataIndexMax { - pubsub.onNext("hello \(dataIndex)") - } else { - pubsub.onCompleted() + let messageStream = AsyncThrowingStream { continuation in + client.onConnectionAck { _, client in + try await client.sendStart( + payload: GraphQLRequest( + query: """ + subscription { + hello + } + """ + ), + id: id + ) + + // Short sleep to allow for server to register subscription + usleep(3000) + + self.context.publisher.emit(event: "hello \(dataIndex)") + } + client.onNext { _, _ in + dataIndex = dataIndex + 1 + if dataIndex <= dataIndexMax { + self.context.publisher.emit(event: "hello \(dataIndex)") + } else { + self.context.publisher.cancel() + } + } + client.onMessage { message, _ in + continuation.yield(message) + } + client.onError { message, _ in + continuation.finish(throwing: message.payload[0]) + } + client.onComplete { _, _ in + continuation.finish() } } - client.onError { _, _ in - completeExpectation.fulfill() - } - client.onComplete { _, _ in - completeExpectation.fulfill() - } - client.onMessage { message, _ in - messages.append(message) - } - - client.sendConnectionInit( + + try await client.sendConnectionInit( payload: TokenInitPayload( authToken: "" ) ) - - wait(for: [completeExpectation], timeout: 2) + + let messages = try await messageStream.reduce(into: [String]()) { result, message in + result.append(message) + } XCTAssertEqual( messages.count, 5, // 1 connection_ack, 3 next, 1 complete "Messages: \(messages.description)" ) } - + enum TestError: Error { case couldBeAnything } diff --git a/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift b/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift index 4b7ea03..6acd9d8 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift @@ -1,15 +1,11 @@ import Foundation -import GraphQL import Graphiti -import GraphQLRxSwift -import RxSwift - -let pubsub = PublishSubject() +import GraphQL struct TestAPI: API { let resolver = TestResolver() let context = TestContext() - + let schema = try! Schema { Query { Field("hello", at: TestResolver.hello) @@ -21,6 +17,8 @@ struct TestAPI: API { } final class TestContext { + let publisher = SimplePubSub() + func hello() -> String { "world" } @@ -30,8 +28,48 @@ struct TestResolver { func hello(context: TestContext, arguments _: NoArguments) -> String { context.hello() } - - func subscribeHello(context: TestContext, arguments: NoArguments) -> EventStream { - pubsub.toEventStream() + + func subscribeHello(context: TestContext, arguments _: NoArguments) -> AsyncThrowingStream { + context.publisher.subscribe() } } + +/// A very simple publish/subscriber used for testing +class SimplePubSub { + private var subscribers: [Subscriber] + + init() { + subscribers = [] + } + + func emit(event: T) { + for subscriber in subscribers { + subscriber.callback(event) + } + } + + func cancel() { + for subscriber in subscribers { + subscriber.cancel() + } + } + + func subscribe() -> AsyncThrowingStream { + return AsyncThrowingStream { continuation in + let subscriber = Subscriber( + callback: { newValue in + continuation.yield(newValue) + }, + cancel: { + continuation.finish() + } + ) + subscribers.append(subscriber) + } + } +} + +struct Subscriber { + let callback: (T) -> Void + let cancel: () -> Void +} diff --git a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift index 1e7c274..b600676 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift @@ -9,30 +9,26 @@ import Foundation /// or risk them being deinitialized early class TestMessenger: Messenger { weak var other: TestMessenger? - var onReceive: (String) -> Void = { _ in } + var onReceive: (String) async throws -> Void = { _ in } let queue: DispatchQueue = .init(label: "Test messenger") - + init() {} - - func send(_ message: S) where S: Collection, S.Element == Character { + + func send(_ message: S) async throws where S: Collection, S.Element == Character { guard let other = other else { return } - - // Run the other message asyncronously to avoid nesting issues - queue.async { - other.onReceive(String(message)) - } + try await other.onReceive(String(message)) } - - func onReceive(callback: @escaping (String) -> Void) { - self.onReceive = callback + + func onReceive(callback: @escaping (String) async throws -> Void) { + onReceive = callback } - - func error(_ message: String, code: Int) { - self.send("\(code): \(message)") + + func error(_ message: String, code: Int) async throws { + try await send("\(code): \(message)") } - + func close() { // This is a testing no-op } From 39a3f43e13c4b0adeced8991360ecddaf127f984 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Mon, 25 Aug 2025 13:23:12 -0600 Subject: [PATCH 2/6] ci: Uses centralized CI --- .github/workflows/test.yml | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b67469b..d51838a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,16 +1,14 @@ name: test on: + push: + branches: [ main ] pull_request: - push: { branches: [ main ] } - + branches: [ main ] + workflow_dispatch: jobs: + lint: + uses: graphqlswift/ci/.github/workflows/lint.yaml@main test: - strategy: - matrix: - os: [ubuntu-latest, macos-latest] - runs-on: ${{ matrix.os }} - steps: - - uses: fwal/setup-swift@v1 - - uses: actions/checkout@v2 - - name: Run tests - run: swift test + uses: graphqlswift/ci/.github/workflows/test.yaml@main + with: + include_android: false From 1e5738a9baba817da46343d4fb19e1370cb47e52 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Mon, 1 Sep 2025 22:40:42 -0600 Subject: [PATCH 3/6] feat: Enables strict concurrency --- Package.swift | 3 ++- Sources/GraphQLTransportWS/GraphqlTransportWSError.swift | 2 +- Sources/GraphQLTransportWS/Server.swift | 2 +- Tests/GraphQLTransportWSTests/Utils/TestAPI.swift | 4 ++-- Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift | 4 ++-- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Package.swift b/Package.swift index cdad552..6d41154 100644 --- a/Package.swift +++ b/Package.swift @@ -27,5 +27,6 @@ let package = Package( name: "GraphQLTransportWSTests", dependencies: ["GraphQLTransportWS"] ), - ] + ], + swiftLanguageVersions: [.v5, .version("6")] ) diff --git a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift index 5dc1d78..3fda638 100644 --- a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift +++ b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift @@ -88,7 +88,7 @@ struct GraphQLTransportWSError: Error { } /// Error codes for miscellaneous issues -public enum ErrorCode: Int, CustomStringConvertible { +public enum ErrorCode: Int, CustomStringConvertible, Sendable { // Miscellaneous case miscellaneous = 4400 diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 605df8f..e434269 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -4,7 +4,7 @@ import GraphQL /// Server implements the server-side portion of the protocol, allowing a few callbacks for customization. /// /// By default, there are no authorization checks -public class Server { +public class Server: @unchecked Sendable { // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? diff --git a/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift b/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift index 6acd9d8..8867da1 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestAPI.swift @@ -16,7 +16,7 @@ struct TestAPI: API { } } -final class TestContext { +final class TestContext: Sendable { let publisher = SimplePubSub() func hello() -> String { @@ -35,7 +35,7 @@ struct TestResolver { } /// A very simple publish/subscriber used for testing -class SimplePubSub { +class SimplePubSub: @unchecked Sendable { private var subscribers: [Subscriber] init() { diff --git a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift index b600676..a35aa09 100644 --- a/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift +++ b/Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift @@ -7,14 +7,14 @@ import Foundation /// /// Note that this only retains a weak reference to 'other', so the client should retain references /// or risk them being deinitialized early -class TestMessenger: Messenger { +class TestMessenger: Messenger, @unchecked Sendable { weak var other: TestMessenger? var onReceive: (String) async throws -> Void = { _ in } let queue: DispatchQueue = .init(label: "Test messenger") init() {} - func send(_ message: S) async throws where S: Collection, S.Element == Character { + func send(_ message: S) async throws where S: Collection, S.Element == Character { guard let other = other else { return } From 5016ca8dce166edc23abafb564305ba992a078b2 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Mon, 22 Sep 2025 16:50:16 -0600 Subject: [PATCH 4/6] feat!: allow any AsyncSequence type in subscription This allows intermediate `map`/`filter` requests to occur before binding the GraphQL subscription to the messenger. --- Package.resolved | 61 +++++++++---------- Package.swift | 2 +- Sources/GraphQLTransportWS/InitPayloads.swift | 4 +- Sources/GraphQLTransportWS/Server.swift | 30 ++++----- .../GraphQLTransportWSTests.swift | 6 +- 5 files changed, 51 insertions(+), 52 deletions(-) diff --git a/Package.resolved b/Package.resolved index 6f88d12..3e34fcf 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,34 +1,33 @@ { - "object": { - "pins": [ - { - "package": "Graphiti", - "repositoryURL": "https://github.com/GraphQLSwift/Graphiti.git", - "state": { - "branch": null, - "revision": "a23a3d232df202fc158ad2d698926325b470523c", - "version": "3.0.0" - } - }, - { - "package": "GraphQL", - "repositoryURL": "https://github.com/GraphQLSwift/GraphQL.git", - "state": { - "branch": null, - "revision": "eedec2bbfcfd0c10c2eaee8ac2f91bde5af28b8c", - "version": "4.0.0" - } - }, - { - "package": "swift-collections", - "repositoryURL": "https://github.com/apple/swift-collections", - "state": { - "branch": null, - "revision": "8c0c0a8b49e080e54e5e328cc552821ff07cd341", - "version": "1.2.1" - } + "originHash" : "30951f6d77c03868bb74b0838ce93637391a168c6668a029c8a8a1dd9fb01aa5", + "pins" : [ + { + "identity" : "graphiti", + "kind" : "remoteSourceControl", + "location" : "https://github.com/GraphQLSwift/Graphiti.git", + "state" : { + "revision" : "a23a3d232df202fc158ad2d698926325b470523c", + "version" : "3.0.0" } - ] - }, - "version": 1 + }, + { + "identity" : "graphql", + "kind" : "remoteSourceControl", + "location" : "https://github.com/GraphQLSwift/GraphQL.git", + "state" : { + "revision" : "0fe18bc0bbbc9ab8929c285f419adea7c8fc7da2", + "version" : "4.0.1" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections", + "state" : { + "revision" : "8c0c0a8b49e080e54e5e328cc552821ff07cd341", + "version" : "1.2.1" + } + } + ], + "version" : 3 } diff --git a/Package.swift b/Package.swift index 6d41154..db11fb4 100644 --- a/Package.swift +++ b/Package.swift @@ -13,7 +13,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"), - .package(url: "https://github.com/GraphQLSwift/GraphQL.git", from: "4.0.0"), + .package(url: "https://github.com/GraphQLSwift/GraphQL.git", from: "4.0.1"), ], targets: [ .target( diff --git a/Sources/GraphQLTransportWS/InitPayloads.swift b/Sources/GraphQLTransportWS/InitPayloads.swift index 966c3a9..8a50b36 100644 --- a/Sources/GraphQLTransportWS/InitPayloads.swift +++ b/Sources/GraphQLTransportWS/InitPayloads.swift @@ -1,10 +1,10 @@ // Contains convenient `connection_init` payloads for users of this package /// `connection_init` `payload` that is empty -public struct EmptyInitPayload: Equatable & Codable {} +public struct EmptyInitPayload: Equatable & Codable & Sendable {} /// `connection_init` `payload` that includes an `authToken` field -public struct TokenInitPayload: Equatable & Codable { +public struct TokenInitPayload: Equatable & Codable & Sendable { public let authToken: String public init(authToken: String) { diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index e434269..3ec64ae 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -4,12 +4,18 @@ import GraphQL /// Server implements the server-side portion of the protocol, allowing a few callbacks for customization. /// /// By default, there are no authorization checks -public class Server: @unchecked Sendable { +public class Server< + InitPayload: Equatable & Codable & Sendable, + SubscriptionSequenceType: AsyncSequence & Sendable +>: @unchecked Sendable where + SubscriptionSequenceType.Element == GraphQLResult +{ + // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? let onExecute: (GraphQLRequest) async throws -> GraphQLResult - let onSubscribe: (GraphQLRequest) async throws -> Result, GraphQLErrors> + let onSubscribe: (GraphQLRequest) async throws -> SubscriptionSequenceType var auth: (InitPayload) async throws -> Void var onExit: () async throws -> Void = {} @@ -33,7 +39,7 @@ public class Server: @unchecked Sendable { public init( messenger: Messenger, onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult, - onSubscribe: @escaping (GraphQLRequest) async throws -> Result, GraphQLErrors> + onSubscribe: @escaping (GraphQLRequest) async throws -> SubscriptionSequenceType ) { self.messenger = messenger self.onExecute = onExecute @@ -160,16 +166,9 @@ public class Server: @unchecked Sendable { } if isStreaming { - do { - let result = try await onSubscribe(graphQLRequest) - let stream: AsyncThrowingStream + subscriptionTasks[id] = Task { do { - stream = try result.get() - } catch { - try await sendError(error, id: id) - return - } - subscriptionTasks[id] = Task { + let stream = try await onSubscribe(graphQLRequest) for try await event in stream { try Task.checkCancellation() do { @@ -179,10 +178,11 @@ public class Server: @unchecked Sendable { throw error } } - try await self.sendComplete(id: id) + } catch { + try await sendError(error, id: id) + throw error } - } catch { - try await sendError(error, id: id) + try await self.sendComplete(id: id) } } else { do { diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index 3efc67a..b967d2f 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -8,7 +8,7 @@ import XCTest class GraphqlTransportWSTests: XCTestCase { var clientMessenger: TestMessenger! var serverMessenger: TestMessenger! - var server: Server! + var server: Server>! var context: TestContext! override func setUp() { @@ -21,7 +21,7 @@ class GraphqlTransportWSTests: XCTestCase { let api = TestAPI() let context = TestContext() - server = Server( + server = .init( messenger: serverMessenger, onExecute: { graphQLRequest in try await api.execute( @@ -33,7 +33,7 @@ class GraphqlTransportWSTests: XCTestCase { try await api.subscribe( request: graphQLRequest.query, context: context - ) + ).get() } ) self.context = context From 31451f9a5537a87f719ce3c89429dcf4bdfeb4a0 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Mon, 22 Sep 2025 16:50:31 -0600 Subject: [PATCH 5/6] fix: Improve subscriptionTask cleanup --- Sources/GraphQLTransportWS/Server.swift | 30 ++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 3ec64ae..74db3d5 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -10,7 +10,6 @@ public class Server< >: @unchecked Sendable where SubscriptionSequenceType.Element == GraphQLResult { - // We keep this weak because we strongly inject this object into the messenger callback weak var messenger: Messenger? @@ -89,13 +88,17 @@ public class Server< try await self.error(.invalidRequestFormat(messageType: .complete)) return } - try await self.onOperationComplete(completeRequest.id) + try await self.onOperationComplete(completeRequest) case .unknown: try await self.error(.invalidType()) } } } + deinit { + subscriptionTasks.values.forEach { $0.cancel() } + } + /// Define a custom callback run during `connection_init` resolution that allows authorization using the `payload`. /// Throw from this closure to indicate that authorization has failed. /// - Parameter callback: The callback to assign @@ -171,18 +174,15 @@ public class Server< let stream = try await onSubscribe(graphQLRequest) for try await event in stream { try Task.checkCancellation() - do { - try await self.sendNext(event, id: id) - } catch { - try await self.sendError(error, id: id) - throw error - } + try await self.sendNext(event, id: id) } } catch { try await sendError(error, id: id) + subscriptionTasks.removeValue(forKey: id) throw error } try await self.sendComplete(id: id) + subscriptionTasks.removeValue(forKey: id) } } else { do { @@ -196,6 +196,20 @@ public class Server< } } + private func onOperationComplete(_ completeRequest: CompleteRequest) async throws { + guard initialized else { + try await error(.notInitialized()) + return + } + + let id = completeRequest.id + if let task = subscriptionTasks[id] { + task.cancel() + subscriptionTasks.removeValue(forKey: id) + } + try await onOperationComplete(id) + } + /// Send a `connection_ack` response through the messenger private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws { guard let messenger = messenger else { return } From c470e17fa5d694051c16cc83b5997286a61bd074 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Mon, 22 Sep 2025 23:56:58 -0600 Subject: [PATCH 6/6] test: streaming test uses dynamic wait --- .../GraphQLTransportWSTests.swift | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index b967d2f..e26de4c 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -10,6 +10,7 @@ class GraphqlTransportWSTests: XCTestCase { var serverMessenger: TestMessenger! var server: Server>! var context: TestContext! + var subscribeReady: Bool! = false override func setUp() { // Point the client and server at each other @@ -30,10 +31,12 @@ class GraphqlTransportWSTests: XCTestCase { ) }, onSubscribe: { graphQLRequest in - try await api.subscribe( + let subscription = try await api.subscribe( request: graphQLRequest.query, context: context ).get() + self.subscribeReady = true + return subscription } ) self.context = context @@ -172,8 +175,15 @@ class GraphqlTransportWSTests: XCTestCase { id: id ) - // Short sleep to allow for server to register subscription - usleep(3000) + // Wait until server has registered subscription + var i = 0 + while !self.subscribeReady, i < 50 { + usleep(1000) + i = i + 1 + } + if i == 50 { + XCTFail("Subscription timeout: Took longer than 50ms to set up") + } self.context.publisher.emit(event: "hello \(dataIndex)") }