assume arbitrary short messages are Keep Alives

This commit is contained in:
Vadim Brilyantov
2018-05-16 21:26:05 +03:00
parent 30abd783b2
commit e87cb06187
6 changed files with 110 additions and 131 deletions

View File

@@ -7,6 +7,9 @@
package org.jetbrains.kotlin.daemon.common.experimental
import kotlinx.coroutines.experimental.async
import kotlinx.coroutines.experimental.delay
import kotlinx.coroutines.experimental.newSingleThreadContext
import org.jetbrains.kotlin.cli.common.repl.ReplCheckResult
import org.jetbrains.kotlin.cli.common.repl.ReplCodeLine
import org.jetbrains.kotlin.cli.common.repl.ReplCompileResult
@@ -14,6 +17,7 @@ import org.jetbrains.kotlin.daemon.common.*
import org.jetbrains.kotlin.daemon.common.CompileService.CallResult
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.*
import java.io.File
import java.util.concurrent.TimeUnit
import java.util.logging.Logger
interface CompileServiceClientSide : CompileServiceAsync, Client<CompileServiceServerSide> {
@@ -26,7 +30,19 @@ class CompileServiceClientSideImpl(
val serverHost: String,
val serverFile: File
) : CompileServiceClientSide,
Client<CompileServiceServerSide> by object : DefaultAuthorizableClient<CompileServiceServerSide>(serverPort, serverHost) {
Client<CompileServiceServerSide> by object : DefaultAuthorizableClient<CompileServiceServerSide>(
serverPort,
serverHost
) {
private fun nowMillieconds() = System.currentTimeMillis()
@Volatile
private var lastUsedMilliSeconds: Long = nowMillieconds()
private fun deltaTime() = nowMillieconds() - lastUsedMilliSeconds
private fun keepAliveSuccess() = deltaTime() < KEEPALIVE_PERIOD
override suspend fun authorizeOnServer(serverOutputChannel: ByteWriteChannelWrapper): Boolean =
runWithTimeout {
@@ -40,6 +56,33 @@ class CompileServiceClientSideImpl(
return trySendHandshakeMessage(output, log) && tryAcquireHandshakeMessage(input, log)
}
override suspend fun startKeepAlives() {
val keepAliveMessage = Server.KeepAliveMessage<CompileServiceServerSide>()
async(newSingleThreadContext("keepAliveThread")) {
delay(KEEPALIVE_PERIOD * 4)
while (true) {
delay(KEEPALIVE_PERIOD)
// println("[$this] KEEPALIVE_PERIOD")
while (keepAliveSuccess()) {
// println("[$this] remained ${KEEPALIVE_PERIOD - deltaTime()}")
delay(KEEPALIVE_PERIOD - deltaTime())
}
runWithTimeout(timeout = KEEPALIVE_PERIOD / 2) {
// println("[$this] sent keepalive")
val id = sendMessage(keepAliveMessage)
readMessage<Server.KeepAliveAcknowledgement<*>>(id)
} ?: if (!keepAliveSuccess()) readActor.send(StopAllRequests()).also {
// println("[$this] got keepalive")
}
}
}
}
override suspend fun delayKeepAlives() {
// println("[$this] delayKeepAlives")
lastUsedMilliSeconds = nowMillieconds()
}
} {
val log = Logger.getLogger("CompileServiceClientSideImpl")

View File

@@ -1,9 +1,9 @@
package org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure
import io.ktor.network.sockets.Socket
import kotlinx.coroutines.experimental.*
import kotlinx.coroutines.experimental.CompletableDeferred
import kotlinx.coroutines.experimental.channels.*
import org.jetbrains.kotlin.daemon.common.experimental.KEEPALIVE_PERIOD
import kotlinx.coroutines.experimental.runBlocking
import org.jetbrains.kotlin.daemon.common.experimental.LoopbackNetworkInterface
import sun.net.ConnectionResetException
import java.beans.Transient
@@ -11,7 +11,6 @@ import java.io.IOException
import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.io.Serializable
import java.util.concurrent.atomic.AtomicBoolean
import java.util.logging.Logger
@@ -51,38 +50,27 @@ abstract class DefaultAuthorizableClient<ServerType : ServerBase>(
abstract suspend fun authorizeOnServer(serverOutputChannel: ByteWriteChannelWrapper): Boolean
abstract suspend fun clientHandshake(input: ByteReadChannelWrapper, output: ByteWriteChannelWrapper, log: Logger): Boolean
abstract suspend fun startKeepAlives()
abstract suspend fun delayKeepAlives()
override fun close() {
// try {
// runBlockingWithTimeout {
// sendMessage(Server.EndConnectionMessage())
// }
// } catch (e: Throwable) {
// log.info_and_print(e.message)
// } finally {
// socket?.close()
// }
socket?.close()
}
public class MessageReply<T : Any>(val messageId: Int, val reply: T?) : Serializable
class MessageReply<T : Any>(val messageId: Int, val reply: T?) : Serializable
private interface ReadActorQuery
private data class ExpectReplyQuery(val messageId: Int, val result: CompletableDeferred<MessageReply<*>>) : ReadActorQuery
private class ReceiveReplyQuery(val reply: MessageReply<*>) : ReadActorQuery
protected interface ReadActorQuery
protected data class ExpectReplyQuery(val messageId: Int, val result: CompletableDeferred<MessageReply<*>>) : ReadActorQuery
protected class ReceiveReplyQuery(val reply: MessageReply<*>) : ReadActorQuery
private interface WriteActorQuery
private data class SendNoreplyMessageQuery(val message: Server.AnyMessage<*>) : WriteActorQuery
private data class SendMessageQuery(val message: Server.AnyMessage<*>, val messageId: CompletableDeferred<Any>) : WriteActorQuery
protected interface WriteActorQuery
protected data class SendNoreplyMessageQuery(val message: Server.AnyMessage<*>) : WriteActorQuery
protected data class SendMessageQuery(val message: Server.AnyMessage<*>, val messageId: CompletableDeferred<Any>) : WriteActorQuery
private class StopAllRequests : ReadActorQuery, WriteActorQuery
// @kotlin.jvm.Transient
// private lateinit var intermediateActor: SendChannel<ReceiveReplyQuery>
protected class StopAllRequests : ReadActorQuery, WriteActorQuery
@kotlin.jvm.Transient
private lateinit var readActor: SendChannel<ReadActorQuery>
protected lateinit var readActor: SendChannel<ReadActorQuery>
@kotlin.jvm.Transient
private lateinit var writeActor: SendChannel<WriteActorQuery>
@@ -158,7 +146,7 @@ abstract class DefaultAuthorizableClient<ServerType : ServerBase>(
val objectReaderActor = actor<NextObjectQuery>(capacity = Channel.UNLIMITED) {
consumeEach {
try {
val reply = input.nextObject().await()
val reply = input.nextObject()
if (reply is Server.ServerDownMessage<*>) {
throw IOException("connection closed by server")
} else if (reply !is MessageReply<*>) {
@@ -207,10 +195,11 @@ abstract class DefaultAuthorizableClient<ServerType : ServerBase>(
log.info_and_print("[${log.name}] : intermediateActor.send(ReceiveReplyQuery())")
objectReaderActor.send(nextObjectQuery)
}
delayKeepAlives()
}
is StopAllRequests -> {
writeActor.send(StopAllRequests())
broadcastIOException(IOException("KeepAlive failed"))
writeActor.send(StopAllRequests())
}
}
}
@@ -242,65 +231,18 @@ abstract class DefaultAuthorizableClient<ServerType : ServerBase>(
log.info_and_print("failed authorization($serverPort)")
throw ConnectionResetException("failed to establish connection with server (authorization failed)")
}
startKeepAlives()
}
}
startKeepAlives()
private fun startKeepAlives() {
val keepAliveMessage = Server.KeepAliveMessage<ServerType>()
var serverAlive = AtomicBoolean(true)
val stopKeepAlives: suspend () -> Unit = {
sendMessage(Server.EndConnectionMessage())
readActor.send(StopAllRequests())
serverAlive.set(false)
}
async(newSingleThreadContext("keep_alive")) {
// println("[$serverPort] _____START______")
delay(KEEPALIVE_PERIOD * 4)
while (serverAlive.get()) {
runWithTimeout(timeout = KEEPALIVE_PERIOD / 2) {
val id = sendMessage(keepAliveMessage)
try {
readMessage<Server.KeepAliveAcknowledgement<ServerType>>(id).also {
// println("[$serverPort] received KeepAlive => OK")
}
} catch (e: IOException) {
// println("[$serverPort] IOException => StopAllRequests")
stopKeepAlives()
}
} ?: stopKeepAlives().let {
// println("[$serverPort] timeout => StopAllRequests")
}
delay(KEEPALIVE_PERIOD)
}
// println("[$serverPort] _____FINISH______")
}
// timer.schedule(delay = KEEPALIVE_PERIOD * 6, period = KEEPALIVE_PERIOD) {
// if (!checkServerAliveness(
// keepAliveMessage,
// keepAliveContext,
// stopKeepAlives = {
// async(keepAliveContext) {
// sendMessage(Server.EndConnectionMessage())
// readActor.send(StopAllRequests())
// timer.cancel()
// timer.purge()
// }
// }
// )) {
// println("[$serverPort] timer.cancel()!")
// }
// }
}
@Throws(ClassNotFoundException::class, IOException::class)
private fun readObject(aInputStream: ObjectInputStream) {
aInputStream.defaultReadObject()
println("connecting...")
runBlocking { connectToServer() }
println("connectED")
}
@Throws(IOException::class)
@@ -316,9 +258,12 @@ class DefaultClient<ServerType : ServerBase>(
) : DefaultAuthorizableClient<ServerType>(serverPort, serverHost) {
override suspend fun clientHandshake(input: ByteReadChannelWrapper, output: ByteWriteChannelWrapper, log: Logger) = true
override suspend fun authorizeOnServer(output: ByteWriteChannelWrapper): Boolean = true
override suspend fun startKeepAlives() {}
override suspend fun delayKeepAlives() {}
}
class DefaultClientRMIWrapper<ServerType : ServerBase> : Client<ServerType> {
override suspend fun connectToServer() {}
override suspend fun sendMessage(msg: Server.AnyMessage<out ServerType>) =
throw UnsupportedOperationException("sendMessage is not supported for RMI wrappers")

View File

@@ -2,10 +2,7 @@ package org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure
import io.ktor.network.sockets.Socket
import kotlinx.coroutines.experimental.*
import org.jetbrains.kotlin.daemon.common.experimental.AUTH_TIMEOUT_IN_MILLISECONDS
import org.jetbrains.kotlin.daemon.common.experimental.FIRST_HANDSHAKE_BYTE_TOKEN
import org.jetbrains.kotlin.daemon.common.experimental.ServerSocketWrapper
import org.jetbrains.kotlin.daemon.common.experimental.log
import org.jetbrains.kotlin.daemon.common.experimental.*
import java.io.Serializable
import java.util.concurrent.TimeUnit
import java.util.logging.Logger
@@ -50,11 +47,7 @@ interface Server<out T : ServerBase> : ServerBase {
else -> Server.State.ERROR
}
fun attachClient(
client: Socket,
readerPool: ThreadPoolDispatcher,
keepAlivePool: ThreadPoolDispatcher
): Deferred<State> = async(readerPool) {
fun attachClient(client: Socket): Deferred<State> = async {
val (input, output) = client.openIO(log)
if (!serverHandshake(input, output, log)) {
log.info_and_print("failed to establish connection with client (handshake failed)")
@@ -72,21 +65,19 @@ interface Server<out T : ServerBase> : ServerBase {
loop@
while (true) {
log.info_and_print(" reading message from ($client)")
val message = input.nextObject().await()
val message = input.nextObject()
when (message) {
is Server.ServerDownMessage<*> -> {
downClient(client)
break@loop
}
is Server.KeepAliveMessage<*> -> Server.State.WORKING.also {
async(keepAlivePool) {
output.writeObject(
DefaultAuthorizableClient.MessageReply(
message.messageId ?: -1,
keepAliveAcknowledgement
)
output.writeObject(
DefaultAuthorizableClient.MessageReply(
message.messageId!!,
keepAliveAcknowledgement
)
}
)
}
!is Server.AnyMessage<*> -> {
log.info_and_print("contrafact message")
@@ -123,11 +114,11 @@ interface Server<out T : ServerBase> : ServerBase {
}
abstract class Message<ServerType : ServerBase> : AnyMessage<ServerType>() {
fun process(server: ServerType, output: ByteWriteChannelWrapper) = async(CommonPool) {
fun process(server: ServerType, output: ByteWriteChannelWrapper) = async {
log.info("$server starts processing ${this@Message}")
processImpl(server, {
log.info("$server finished processing ${this@Message}, sending output")
async(CommonPool) {
async {
log.info("$server starts sending ${this@Message} to output")
output.writeObject(DefaultAuthorizableClient.MessageReply(messageId ?: -1, it))
log.info("$server finished sending ${this@Message} to output")
@@ -140,10 +131,10 @@ interface Server<out T : ServerBase> : ServerBase {
class EndConnectionMessage<ServerType : ServerBase> : AnyMessage<ServerType>()
class KeepAliveMessage<ServerType : ServerBase> : AnyMessage<ServerType>()
class KeepAliveAcknowledgement<ServerType : ServerBase> : AnyMessage<ServerType>()
class KeepAliveMessage<ServerType : ServerBase> : AnyMessage<ServerType>()
class ServerDownMessage<ServerType : ServerBase> : AnyMessage<ServerType>()
data class ClientInfo(val socket: Socket, val input: ByteReadChannelWrapper, val output: ByteWriteChannelWrapper)
@@ -153,16 +144,14 @@ interface Server<out T : ServerBase> : ServerBase {
fun runServer(): Deferred<Unit> {
log.info_and_print("binding to address(${serverSocketWithPort.port})")
val serverSocket = serverSocketWithPort.socket
val readerPool = newFixedThreadPoolContext(nThreads = 4, name = "readerPool")
val keepAlivePool = newFixedThreadPoolContext(nThreads = 10, name = "keepAlivePool")
return async(CommonPool) {
return async {
serverSocket.use {
log.info_and_print("accepting clientSocket...")
while (true) {
val client = serverSocket.accept()
log.info_and_print("client accepted! (${client.remoteAddress})")
async(CommonPool) {
val state = attachClient(client, readerPool, keepAlivePool).await()
async {
val state = attachClient(client).await()
log.info_and_print("finished ($client) with state : $state")
when (state) {
Server.State.CLOSED, State.UNVERIFIED -> {
@@ -190,6 +179,7 @@ interface Server<out T : ServerBase> : ServerBase {
socket.close()
}
clients.clear()
serverSocketWithPort.socket.close()
}
private fun downClient(client: Socket) {
@@ -199,6 +189,7 @@ interface Server<out T : ServerBase> : ServerBase {
suspend fun securityCheck(clientInputChannel: ByteReadChannelWrapper): Boolean = true
suspend fun serverHandshake(input: ByteReadChannelWrapper, output: ByteWriteChannelWrapper, log: Logger) = true
}
fun <T> runBlockingWithTimeout(timeout: Long = AUTH_TIMEOUT_IN_MILLISECONDS, block: suspend () -> T) =

View File

@@ -115,19 +115,18 @@ class ByteReadChannelWrapper(readChannel: ByteReadChannel, private val log: Logg
}
/** first reads <t>length</t> token (4 bytes), then reads <t>length</t> bytes and returns deserialized object */
suspend fun nextObject() = async {
suspend fun nextObject(): Any? {
val obj = CompletableDeferred<Any?>()
readActor.send(SerObjectQuery(obj))
val result = obj.await()
if (result is Server.ServerDownMessage<*>) {
throw IOException("connection closed by server")
}
result
return result
}
}
class ByteWriteChannelWrapper(writeChannel: ByteWriteChannel, private val log: Logger) {
private interface WriteActorQuery
@@ -180,7 +179,6 @@ class ByteWriteChannelWrapper(writeChannel: ByteWriteChannel, private val log: L
}
suspend fun printBytesAndLength(length: Int, bytes: ByteArray) {
// println("printBytesAndLength : $length $bytes")
writeActor.send(
ObjectWithLength(
getLengthBytes(length),
@@ -192,14 +190,9 @@ class ByteWriteChannelWrapper(writeChannel: ByteWriteChannel, private val log: L
private suspend fun printObjectImpl(obj: Any?) =
ByteArrayOutputStream().use { bos ->
ObjectOutputStream(bos).use { objOut ->
// println("printObjectImpl : $obj")
// println("obj is ser : ${obj is Serializable}")
objOut.writeObject(obj)
// println("objOut.writeObject : $obj")
objOut.flush()
// println("objOut.flush : $obj")
val bytes = bos.toByteArray()
// println("bytes : $bytes")
printBytesAndLength(bytes.size, bytes)
}
}

View File

@@ -100,10 +100,15 @@ class CompileServiceServerSideImpl(
override val clients = hashMapOf<Socket, Server.ClientInfo>()
override suspend fun securityCheck(clientInputChannel: ByteReadChannelWrapper): Boolean =
runWithTimeout {
getSignatureAndVerify(clientInputChannel, securityData.token, securityData.publicKey)
} ?: false
object KeepAliveServer : Server<ServerBase> {
override val serverSocketWithPort = findCallbackServerSocket()
override val clients = hashMapOf<Socket, Server.ClientInfo>()
}
override suspend fun securityCheck(clientInputChannel: ByteReadChannelWrapper): Boolean = runWithTimeout {
getSignatureAndVerify(clientInputChannel, securityData.token, securityData.publicKey)
} ?: false
override suspend fun serverHandshake(input: ByteReadChannelWrapper, output: ByteWriteChannelWrapper, log: Logger): Boolean {
return tryAcquireHandshakeMessage(input, log) && trySendHandshakeMessage(output, log)
@@ -666,7 +671,7 @@ class CompileServiceServerSideImpl(
return compiler.compile(allKotlinFiles, k2jvmArgs, compilerMessageCollector, changedFiles)
}
override suspend fun leaseReplSession(
override suspend fun leaseReplSession(
aliveFlagPath: String?,
compilerArguments: Array<out String>,
compilationOptions: CompilationOptions,
@@ -751,7 +756,9 @@ class CompileServiceServerSideImpl(
System.setProperty(KOTLIN_COMPILER_ENVIRONMENT_KEEPALIVE_PROPERTY, "true")
// TODO UNCOMMENT THIS : this.toRMIServer(daemonOptions, compilerId) // also create RMI server in order to support old clients
// this.toRMIServer(daemonOptions, compilerId)
this.toRMIServer(daemonOptions, compilerId)
KeepAliveServer.runServer()
timer.schedule(10) {
exceptionLoggingTimerThread(info = "initiateElections") {

View File

@@ -745,17 +745,17 @@ class CompilerDaemonTest : KotlinIntegrationTestBase() {
withLogFile("kotlin-daemon-test") { logFile ->
val cfg: String =
"handlers = java.util.logging.FileHandler\n" +
"java.util.logging.FileHandler.level = ALL\n" +
"java.util.logging.FileHandler.formatter = java.util.logging.SimpleFormatter\n" +
"java.util.logging.FileHandler.encoding = UTF-8\n" +
"java.util.logging.FileHandler.limit = 0\n" + // if file is provided - disabled, else - 1Mb
"java.util.logging.FileHandler.count = 1\n" +
"java.util.logging.FileHandler.append = true\n" +
"java.util.logging.FileHandler.pattern = ${logFile.loggerCompatiblePath}\n" +
"java.util.logging.SimpleFormatter.format = %1\$tF %1\$tT.%1\$tL [%3\$s] %4\$s: %5\$s%n\n"
LogManager.getLogManager().readConfiguration(cfg.byteInputStream())
// val cfg: String =
// "handlers = java.util.logging.FileHandler\n" +
// "java.util.logging.FileHandler.level = ALL\n" +
// "java.util.logging.FileHandler.formatter = java.util.logging.SimpleFormatter\n" +
// "java.util.logging.FileHandler.encoding = UTF-8\n" +
// "java.util.logging.FileHandler.limit = 0\n" + // if file is provided - disabled, else - 1Mb
// "java.util.logging.FileHandler.count = 1\n" +
// "java.util.logging.FileHandler.append = true\n" +
// "java.util.logging.FileHandler.pattern = ${logFile.loggerCompatiblePath}\n" +
// "java.util.logging.SimpleFormatter.format = %1\$tF %1\$tT.%1\$tL [%3\$s] %4\$s: %5\$s%n\n"
// LogManager.getLogManager().readConfiguration(cfg.byteInputStream())
val daemonJVMOptions = makeTestDaemonJvmOptions(logFile, xmx = -1)
@@ -806,7 +806,7 @@ class CompilerDaemonTest : KotlinIntegrationTestBase() {
}
private object ParallelStartParams {
const val threads = 32
const val threads = 10
const val performCompilation = false
const val connectionFailedErr = -100
}