diff --git a/build.gradle.kts b/build.gradle.kts index 23b99c58..74bba141 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ plugins { allprojects { group = "in.dragonbra" - version = "1.6.0-SNAPSHOT" + version = "1.6.0" } repositories { @@ -122,6 +122,7 @@ dependencies { implementation(libs.okHttp) implementation(libs.xz) implementation(libs.protobuf.java) + implementation(libs.bundles.ktor) testImplementation(libs.bundles.testing) } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index e9b0bae4..6310548f 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -44,6 +44,9 @@ protobuf-java = { module = "com.google.protobuf:protobuf-java", version.ref = "p protobuf-protoc = { module = "com.google.protobuf:protoc", version.ref = "protobuf" } qrCode = { module = "pro.leaco.qrcode:console-qrcode", version.ref = "qrCode" } xz = { module = "org.tukaani:xz", version.ref = "xz" } +ktor-client-core = { module = "io.ktor:ktor-client-core", version = "3.0.3" } +ktor-client-cio = { module = "io.ktor:ktor-client-cio", version = "3.0.3" } +ktor-client-websocket = { module = "io.ktor:ktor-client-websockets", version = "3.0.3" } test-commons-codec = { module = "commons-codec:commons-codec", version.ref = "commonsCodec" } test-jupiter-api = { module = "org.junit.jupiter:junit-jupiter-api", version.ref = "junit5" } @@ -71,3 +74,9 @@ testing = [ "test-mockito-core", "test-mockito-jupiter", ] + +ktor = [ + "ktor-client-core", + "ktor-client-cio", + "ktor-client-websocket", +] diff --git a/src/main/java/in/dragonbra/javasteam/networking/steam3/WebSocketCMClient.kt b/src/main/java/in/dragonbra/javasteam/networking/steam3/WebSocketCMClient.kt deleted file mode 100644 index c30cf2fb..00000000 --- a/src/main/java/in/dragonbra/javasteam/networking/steam3/WebSocketCMClient.kt +++ /dev/null @@ -1,100 +0,0 @@ -package `in`.dragonbra.javasteam.networking.steam3 - -import okhttp3.OkHttpClient -import okhttp3.Request -import okhttp3.Response -import okhttp3.WebSocket -import okhttp3.WebSocketListener -import okio.ByteString -import java.net.URI -import java.util.concurrent.TimeUnit - -class WebSocketCMClient( - timeout: Int, - private val serverUrl: URI, - private val listener: WSListener, -) : WebSocketListener() { - - companion object { - // private val logger = LogManager.getLogger(WebSocketCMClient::class.java) - } - - private val client = OkHttpClient.Builder() - .readTimeout(timeout.toLong(), TimeUnit.MILLISECONDS) - .build() - - private var webSocket: WebSocket? = null - - /** - * Invoked when a web socket has been accepted by the remote peer and may begin transmitting - * messages. - */ - override fun onOpen(webSocket: WebSocket, response: Response) { - listener.onOpen(response) - response.close() - } - - /** Invoked when a text (type `0x1`) message has been received. */ - override fun onMessage(webSocket: WebSocket, text: String) { - listener.onTextData(text) - } - - /** Invoked when a binary (type `0x2`) message has been received. */ - override fun onMessage(webSocket: WebSocket, bytes: ByteString) { - listener.onData(bytes.toByteArray()) - } - - /** - * Invoked when the remote peer has indicated that no more incoming messages will be transmitted. - */ - override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { - listener.onClosing(code, reason) - } - - /** - * Invoked when both peers have indicated that no more messages will be transmitted and the - * connection has been successfully released. No further calls to this listener will be made. - */ - override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { - listener.onClose(code, reason) - this.webSocket = null - } - - /** - * Invoked when a web socket has been closed due to an error reading from or writing to the - * network. Both outgoing and incoming messages may have been lost. No further calls to this - * listener will be made. - */ - override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { - listener.onError(t) - response?.close() - this.webSocket = null - } - - fun connect() { - val request = Request.Builder().url(serverUrl.toString()).build() - webSocket = client.newWebSocket(request, this) - } - - fun send(data: ByteArray) { - webSocket?.send(ByteString.of(*data)) - } - - fun close() { - webSocket?.close(1000, null) - - // Shutdown the okhttp client to prevent hanging. - client.dispatcher.executorService.shutdown() - client.connectionPool.evictAll() - client.cache?.close() - } - - interface WSListener { - fun onTextData(data: String) - fun onData(data: ByteArray) - fun onClose(code: Int, reason: String) - fun onClosing(code: Int, reason: String) - fun onError(t: Throwable) - fun onOpen(response: Response) - } -} diff --git a/src/main/java/in/dragonbra/javasteam/networking/steam3/WebSocketConnection.kt b/src/main/java/in/dragonbra/javasteam/networking/steam3/WebSocketConnection.kt index 9dce1967..0c6f0b01 100644 --- a/src/main/java/in/dragonbra/javasteam/networking/steam3/WebSocketConnection.kt +++ b/src/main/java/in/dragonbra/javasteam/networking/steam3/WebSocketConnection.kt @@ -1,105 +1,175 @@ package `in`.dragonbra.javasteam.networking.steam3 import `in`.dragonbra.javasteam.util.log.LogManager -import `in`.dragonbra.javasteam.util.log.Logger -import okhttp3.Response +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.websocket.WebSockets +import io.ktor.client.plugins.websocket.pingInterval +import io.ktor.client.plugins.websocket.webSocketSession +import io.ktor.http.URLProtocol +import io.ktor.http.path +import io.ktor.websocket.Frame +import io.ktor.websocket.WebSocketSession +import io.ktor.websocket.close +import io.ktor.websocket.readBytes +import io.ktor.websocket.readText +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancelChildren +import kotlinx.coroutines.channels.consumeEach +import kotlinx.coroutines.delay +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch import java.net.InetAddress import java.net.InetSocketAddress -import java.net.URI -import java.util.concurrent.atomic.AtomicReference +import kotlin.coroutines.CoroutineContext +import kotlin.time.DurationUnit +import kotlin.time.toDuration class WebSocketConnection : Connection(), - WebSocketCMClient.WSListener { + CoroutineScope { companion object { - private val logger: Logger = LogManager.getLogger(WebSocketConnection::class.java) - - private fun constructUri(address: InetSocketAddress): URI = - URI.create("wss://${address.hostString}:${address.port}/cmsocket/") + private val logger = LogManager.getLogger(WebSocketConnection::class.java) } - private val client = AtomicReference(null) + private val job: Job = SupervisorJob() - private var socketEndPoint: InetSocketAddress? = null + private var client: HttpClient? = null - override fun connect(endPoint: InetSocketAddress, timeout: Int) { - logger.debug("Connecting to $endPoint...") + private var session: WebSocketSession? = null - val serverUri = constructUri(endPoint) - val newClient = WebSocketCMClient(timeout, serverUri, this) - val oldClient = client.getAndSet(newClient) + private var endpoint: InetSocketAddress? = null - oldClient?.let { oldClient -> - logger.debug("Attempted to connect while already connected. Closing old connection...") - oldClient.close() - onDisconnected(false) - } + private var lastFrameTime = System.currentTimeMillis() - socketEndPoint = endPoint + override val coroutineContext: CoroutineContext = Dispatchers.IO + job - newClient.connect() + override fun connect(endPoint: InetSocketAddress, timeout: Int) { + launch { + logger.debug("Trying connection to ${endPoint.hostName}:${endPoint.port}") + + try { + endpoint = endPoint + + client = HttpClient(CIO) { + install(WebSockets) { + pingInterval = timeout.toDuration(DurationUnit.SECONDS) + } + } + + val session = client?.webSocketSession { + url { + host = endPoint.hostName + port = endPoint.port + protocol = URLProtocol.WSS + path("cmsocket/") + } + } + + this@WebSocketConnection.session = session + + startConnectionMonitoring() + + launch { + try { + session?.incoming?.consumeEach { frame -> + when (frame) { + is Frame.Binary -> { + // logger.debug("on Binary ${frame.data.size}") + lastFrameTime = System.currentTimeMillis() + onNetMsgReceived(NetMsgEventArgs(frame.readBytes(), currentEndPoint)) + } + + is Frame.Close -> disconnect(false) + is Frame.Ping -> logger.debug("Received pong") + is Frame.Pong -> logger.debug("Received pong") + is Frame.Text -> logger.debug("Received plain text ${frame.readText()}") + } + } + } catch (e: Exception) { + logger.error("An error occurred while receiving data", e) + disconnect(false) + } + } + + logger.debug("Connected to ${endPoint.hostName}:${endPoint.port}") + onConnected() + } catch (e: Exception) { + logger.error("An error occurred setting up the web socket client", e) + disconnect(false) + } + } } override fun disconnect(userInitiated: Boolean) { - disconnectCore(userInitiated) + logger.debug("Disconnect called: $userInitiated") + launch { + try { + session?.close() + client?.close() + } finally { + session = null + client = null + + job.cancelChildren() + } + } + + onDisconnected(userInitiated) } override fun send(data: ByteArray) { - try { - client.get()?.send(data) - } catch (e: Exception) { - logger.debug("Exception while sending data", e) - disconnectCore(false) + launch { + try { + val frame = Frame.Binary(true, data) + session?.send(frame) + } catch (e: Exception) { + logger.error("An error occurred while sending data", e) + disconnect(false) + } } } - override fun getLocalIP(): InetAddress? = InetAddress.getByAddress(byteArrayOf(0, 0, 0, 0)) + override fun getLocalIP(): InetAddress = InetAddress.getLocalHost() - override fun getCurrentEndPoint(): InetSocketAddress? = socketEndPoint + override fun getCurrentEndPoint(): InetSocketAddress? = endpoint override fun getProtocolTypes(): ProtocolTypes = ProtocolTypes.WEB_SOCKET - private fun disconnectCore(userInitiated: Boolean) { - logger.debug("User initiated disconnection: $userInitiated") - - val oldClient = client.getAndSet(null) - oldClient?.close() - - onDisconnected(userInitiated) + /** + * Rudimentary watchdog + */ + private fun startConnectionMonitoring() { + launch { + while (isActive) { + if (client?.isActive == false || session?.isActive == false) { + logger.error("Client or Session is no longer active") + disconnect(userInitiated = false) + } - socketEndPoint = null - } + val timeSinceLastFrame = System.currentTimeMillis() - lastFrameTime - override fun onTextData(data: String) { - // Ignore string messages - logger.debug("Got string message: $data") - } - - override fun onData(data: ByteArray) { - if (data.isNotEmpty()) { - onNetMsgReceived(NetMsgEventArgs(data, getCurrentEndPoint())) - } - } + // logger.debug("Watchdog status: $timeSinceLastFrame") + when { + timeSinceLastFrame > 30000 -> { + logger.error("Watchdog: No response for 30 seconds. Disconnecting from steam") + disconnect(userInitiated = false) + break + } - override fun onClose(code: Int, reason: String) { - logger.debug("Connection closed") - } + timeSinceLastFrame > 25000 -> logger.debug("Watchdog: No response for 25 seconds") - override fun onClosing(code: Int, reason: String) { - logger.debug("Closing connection: $code, reason: ${reason.ifEmpty { "No reason given" }}") - // Steam can close a connection if there is nothing else it wants to send. - // For example: AccountLoginDeniedNeedTwoFactor, InvalidPassword, etc. - disconnectCore(code == 1000) - } + timeSinceLastFrame > 20000 -> logger.debug("Watchdog: No response for 20 seconds") - override fun onError(t: Throwable) { - logger.error("Error in websocket", t) - disconnectCore(false) - } + timeSinceLastFrame > 15000 -> logger.debug("Watchdog: No response for 15 seconds") + } - override fun onOpen(response: Response) { - logger.debug("WebSocket connected to $socketEndPoint using TLS: ${response.handshake?.tlsVersion}") - onConnected() + delay(5000) + } + } } } diff --git a/src/main/java/in/dragonbra/javasteam/steam/CMClient.java b/src/main/java/in/dragonbra/javasteam/steam/CMClient.java index 1491de56..aa778935 100644 --- a/src/main/java/in/dragonbra/javasteam/steam/CMClient.java +++ b/src/main/java/in/dragonbra/javasteam/steam/CMClient.java @@ -117,7 +117,12 @@ public CMClient(SteamConfiguration configuration) { this.configuration = configuration; - heartBeatFunc = new ScheduledFunction(() -> send(new ClientMsgProtobuf(CMsgClientHeartBeat.class, EMsg.ClientHeartBeat)), 5000); + heartBeatFunc = new ScheduledFunction(() -> { + var heartbeat = new ClientMsgProtobuf( + CMsgClientHeartBeat.class, EMsg.ClientHeartBeat); + heartbeat.getBody().setSendReply(true); // Ping Pong + send(heartbeat); + }, 5000); } /** diff --git a/src/main/java/in/dragonbra/javasteam/steam/discovery/SmartCMServerList.kt b/src/main/java/in/dragonbra/javasteam/steam/discovery/SmartCMServerList.kt index b2e1129a..9c8df9d7 100644 --- a/src/main/java/in/dragonbra/javasteam/steam/discovery/SmartCMServerList.kt +++ b/src/main/java/in/dragonbra/javasteam/steam/discovery/SmartCMServerList.kt @@ -193,10 +193,15 @@ class SmartCMServerList(private val configuration: SteamConfiguration) { } } - fun tryMark(endPoint: InetSocketAddress, protocolTypes: ProtocolTypes, quality: ServerQuality): Boolean = + fun tryMark(endPoint: InetSocketAddress?, protocolTypes: ProtocolTypes, quality: ServerQuality): Boolean = tryMark(endPoint, EnumSet.of(protocolTypes), quality) - fun tryMark(endPoint: InetSocketAddress, protocolTypes: EnumSet, quality: ServerQuality): Boolean { + fun tryMark(endPoint: InetSocketAddress?, protocolTypes: EnumSet, quality: ServerQuality): Boolean { + if (endPoint == null) { + logger.error("Couldn't mark an endpoint ${quality.name}, skipping it") + return false + } + val serverInfos: List if (quality == ServerQuality.GOOD) { diff --git a/src/test/java/in/dragonbra/javasteam/steam/discovery/SmartCMServerListTest.java b/src/test/java/in/dragonbra/javasteam/steam/discovery/SmartCMServerListTest.java index fd83daca..ef0e04e3 100644 --- a/src/test/java/in/dragonbra/javasteam/steam/discovery/SmartCMServerListTest.java +++ b/src/test/java/in/dragonbra/javasteam/steam/discovery/SmartCMServerListTest.java @@ -285,4 +285,11 @@ var record = ServerRecord.createSocketServer(new InetSocketAddress(InetAddress.g var marked = serverList.tryMark(new InetSocketAddress(InetAddress.getLoopbackAddress(), 27016), record.getProtocolTypes(), ServerQuality.GOOD); Assertions.assertFalse(marked); } + + @Test + public void testNullConnection_ShouldReturnFalse() { + var result = serverList.tryMark(null, ProtocolTypes.WEB_SOCKET, ServerQuality.BAD); + + Assertions.assertFalse(result); + } }