|
| 1 | +// |
| 2 | +// File.swift |
| 3 | +// |
| 4 | +// |
| 5 | +// Created by Måns Severin on 2024-05-19. |
| 6 | +// |
| 7 | + |
| 8 | +import Foundation |
| 9 | + |
| 10 | +import NIOCore |
| 11 | +import NIOPosix |
| 12 | +import NIOHTTP1 |
| 13 | +import NIOSSL |
| 14 | +import NIOWebSocket |
| 15 | + |
| 16 | +public final actor WebSocketTransport { |
| 17 | + public enum WebSocketTransportError: Error { |
| 18 | + case invalidURI(String) |
| 19 | + } |
| 20 | + |
| 21 | + public enum WebSocketEvent { |
| 22 | + case close |
| 23 | + case disconnected |
| 24 | + case failed |
| 25 | + case message |
| 26 | + case open |
| 27 | + } |
| 28 | + |
| 29 | + public typealias WebSocketEventHandler = (WebSocketEvent, WebSocketTransport) -> Void |
| 30 | + |
| 31 | + private static let DEFAULT_RETRY_OPTIONS = [ |
| 32 | + "retries" : 10, |
| 33 | + "factor" : 2, |
| 34 | + "minTimeout" : 1 * 1000, |
| 35 | + "maxTimeout" : 8 * 1000 |
| 36 | + ] |
| 37 | + |
| 38 | + private enum State { |
| 39 | + case idle |
| 40 | + case connecting(Task<Void, Error>) |
| 41 | + case connected(Channel) |
| 42 | + } |
| 43 | + |
| 44 | + private let uri: URL |
| 45 | + private let group: MultiThreadedEventLoopGroup |
| 46 | + |
| 47 | + private var state: State = .idle |
| 48 | + private var eventHandlers: [WebSocketEvent: [WebSocketEventHandler]] = [:] |
| 49 | + |
| 50 | + public init(uri: URL, group: MultiThreadedEventLoopGroup) { |
| 51 | + self.uri = uri |
| 52 | + self.group = group |
| 53 | + } |
| 54 | + |
| 55 | + public func connect() { |
| 56 | + guard case .idle = state else { |
| 57 | + return |
| 58 | + } |
| 59 | + |
| 60 | + state = .connecting(.init { |
| 61 | + |
| 62 | + }) |
| 63 | + } |
| 64 | + |
| 65 | + public func disconnect() { |
| 66 | + |
| 67 | + } |
| 68 | + |
| 69 | + public func on(_ event: WebSocketEvent, handler: @escaping WebSocketEventHandler) -> Self { |
| 70 | + return self |
| 71 | + } |
| 72 | + |
| 73 | + private func connect(retriesLeft: Int = 10) async throws { |
| 74 | + guard let components = URLComponents(url: uri, resolvingAgainstBaseURL: true) else { |
| 75 | + throw WebSocketTransportError.invalidURI("invalid uri") |
| 76 | + } |
| 77 | + |
| 78 | + guard let host = components.host else { |
| 79 | + throw WebSocketTransportError.invalidURI("invalid uri: no host") |
| 80 | + } |
| 81 | + |
| 82 | + guard let scheme = components.scheme else { |
| 83 | + throw WebSocketTransportError.invalidURI("invalid uri: no scheme") |
| 84 | + } |
| 85 | + |
| 86 | + guard scheme == "ws" || scheme == "wss" else { |
| 87 | + throw WebSocketTransportError.invalidURI("invalid uri: scheme is not ws or wss") |
| 88 | + } |
| 89 | + |
| 90 | + let port = components.port ?? (scheme == "wss" ? 443 : 80) |
| 91 | + let path = components.path |
| 92 | + let bootstrap = WebSocketBootstrap(scheme: scheme, host: host, port: port, path: path, group: group) |
| 93 | + let channel = try await bootstrap.connect() |
| 94 | + } |
| 95 | + |
| 96 | +// /// This method handles the upgrade result. |
| 97 | +// private func handleUpgradeResult(_ upgradeResult: EventLoopFuture<UpgradeResult>) async throws { |
| 98 | +// switch try await upgradeResult.get() { |
| 99 | +// case .websocket(let websocketChannel): |
| 100 | +// print("Handling websocket connection") |
| 101 | +// try await self.handleWebsocketChannel(websocketChannel) |
| 102 | +// print("Done handling websocket connection") |
| 103 | +// case .failed: |
| 104 | +// // The upgrade to websocket did not succeed. We are just exiting in this case. |
| 105 | +// print("Upgrade declined") |
| 106 | +// } |
| 107 | +// } |
| 108 | +// |
| 109 | +// private func handleWebsocketChannel(_ channel: NIOAsyncChannel<WebSocketFrame, WebSocketFrame>) async throws { |
| 110 | +// // We are sending a ping frame and then |
| 111 | +// // start to handle all inbound frames. |
| 112 | +// |
| 113 | +// let pingFrame = WebSocketFrame(fin: true, opcode: .ping, data: ByteBuffer(string: "Hello!")) |
| 114 | +// try await channel.executeThenClose { inbound, outbound in |
| 115 | +// try await outbound.write(pingFrame) |
| 116 | +// |
| 117 | +// for try await frame in inbound { |
| 118 | +// switch frame.opcode { |
| 119 | +// case .pong: |
| 120 | +// print("Received pong: \(String(buffer: frame.data))") |
| 121 | +// |
| 122 | +// case .text: |
| 123 | +// print("Received: \(String(buffer: frame.data))") |
| 124 | +// |
| 125 | +// case .connectionClose: |
| 126 | +// // Handle a received close frame. We're just going to close by returning from this method. |
| 127 | +// print("Received Close instruction from server") |
| 128 | +// return |
| 129 | +// case .binary, .continuation, .ping: |
| 130 | +// // We ignore these frames. |
| 131 | +// break |
| 132 | +// default: |
| 133 | +// // Unknown frames are errors. |
| 134 | +// return |
| 135 | +// } |
| 136 | +// } |
| 137 | +// } |
| 138 | +// } |
| 139 | +} |
| 140 | + |
| 141 | +struct WebSocketBootstrap { |
| 142 | + enum WebSocketBootstrapError: Error { |
| 143 | + case upgradeFailed |
| 144 | + } |
| 145 | + |
| 146 | + let scheme: String |
| 147 | + let host: String |
| 148 | + let port: Int |
| 149 | + let path: String |
| 150 | + let group: EventLoopGroup |
| 151 | + |
| 152 | + private enum UpgradeResult { |
| 153 | + case websocket(NIOAsyncChannel<WebSocketFrame, WebSocketFrame>) |
| 154 | + case failed |
| 155 | + } |
| 156 | + |
| 157 | + func connect() async throws -> NIOAsyncChannel<WebSocketFrame, WebSocketFrame> { |
| 158 | + let initializer: @Sendable (Channel) -> EventLoopFuture<EventLoopFuture<UpgradeResult>> = { channel in |
| 159 | + channel.eventLoop.makeCompletedFuture { |
| 160 | + let upgrader = NIOTypedWebSocketClientUpgrader<UpgradeResult>(upgradePipelineHandler: { (channel, _) in |
| 161 | + channel.eventLoop.makeCompletedFuture { |
| 162 | + return UpgradeResult.websocket(try NIOAsyncChannel<WebSocketFrame, WebSocketFrame>(wrappingChannelSynchronously: channel)) |
| 163 | + } |
| 164 | + }) |
| 165 | + |
| 166 | + var headers = HTTPHeaders() |
| 167 | + |
| 168 | + headers.add(name: "Host", value: host) |
| 169 | + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") |
| 170 | + headers.add(name: "Content-Length", value: "0") |
| 171 | + headers.add(name: "Sec-WebSocket-Protocol", value: "protoo") |
| 172 | + |
| 173 | + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: path, headers: headers) |
| 174 | + |
| 175 | + let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( |
| 176 | + upgradeRequestHead: requestHead, |
| 177 | + upgraders: [upgrader], |
| 178 | + notUpgradingCompletionHandler: { channel in |
| 179 | + channel.eventLoop.makeCompletedFuture { |
| 180 | + return UpgradeResult.failed |
| 181 | + } |
| 182 | + } |
| 183 | + ) |
| 184 | + |
| 185 | + let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( |
| 186 | + configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) |
| 187 | + ) |
| 188 | + |
| 189 | + return negotiationResultFuture |
| 190 | + } |
| 191 | + } |
| 192 | + |
| 193 | + let upgradeResult: EventLoopFuture<UpgradeResult> |
| 194 | + |
| 195 | + if scheme == "wss" { |
| 196 | + let configuration = TLSConfiguration.makeClientConfiguration() |
| 197 | + let sslContext = try NIOSSLContext(configuration: configuration) |
| 198 | + let bootstrap = try NIOClientTCPBootstrap(ClientBootstrap(group: group), tls: NIOSSLClientTLSProvider(context: sslContext, serverHostname: host)) |
| 199 | + |
| 200 | + upgradeResult = try await bootstrap |
| 201 | + .enableTLS() |
| 202 | + .connect(host: host, port: port, channelInitializer: initializer) |
| 203 | + } else { |
| 204 | + let bootstrap = ClientBootstrap(group: group) |
| 205 | + |
| 206 | + upgradeResult = try await bootstrap.connect(host: host, port: port, channelInitializer: initializer) |
| 207 | + } |
| 208 | + |
| 209 | + guard case let .websocket(channel) = try await upgradeResult.get() else { |
| 210 | + throw WebSocketBootstrapError.upgradeFailed |
| 211 | + } |
| 212 | + |
| 213 | + let aggregator = NIOWebSocketFrameAggregator(minNonFinalFragmentSize: 256, |
| 214 | + maxAccumulatedFrameCount: 100, |
| 215 | + maxAccumulatedFrameSize: 10 * 1024 * 1024) |
| 216 | + |
| 217 | + try await channel.channel.pipeline.addHandler(aggregator).get() |
| 218 | + |
| 219 | + return channel |
| 220 | + } |
| 221 | +} |
| 222 | + |
| 223 | +extension NIOClientTCPBootstrap { |
| 224 | + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) |
| 225 | + public func connect<Output: Sendable>( |
| 226 | + host: String, |
| 227 | + port: Int, |
| 228 | + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output> |
| 229 | + ) async throws -> Output { |
| 230 | + return try await (self.underlyingBootstrap as! ClientBootstrap).connect(host: host, port: port, channelInitializer: channelInitializer) |
| 231 | + } |
| 232 | +} |
0 commit comments