fix new security mechanisms

This commit is contained in:
Vadim Brilyantov
2018-03-28 19:12:56 +03:00
parent 8ebe8c2daa
commit 19b60ed4ad
8 changed files with 206 additions and 33 deletions

View File

@@ -22,5 +22,5 @@ val RESULTS_SERVER_PORTS_RANGE_END: Int = 17000
val COMPILER_DAEMON_CLASS_FQN_EXPERIMENTAL: String = "org.jetbrains.kotlin.daemon.experimental.KotlinCompileDaemon"
val FIRST_HANDSHAKE_BYTE_TOKEN = byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8)
val FIRST_HANDSHAKE_BYTE_TOKEN = byteArrayOf(1, 2, 3, 4)
val AUTH_TIMEOUT_IN_MILLISECONDS = 1000L

View File

@@ -58,7 +58,7 @@ suspend fun walkDaemonsAsync(
"found daemon on socketPort $port ($relativeAge ms old), trying to connect"
)
log.info("found daemon on socketPort $port ($relativeAge ms old), trying to connect")
val daemon = tryConnectToDaemonAsync(port, report, useRMI, useSockets)
val daemon = tryConnectToDaemonAsync(port, report, file, useRMI, useSockets)
log.info("daemon = $daemon (port= $port)")
// cleaning orphaned file; note: daemon should shut itself down if it detects that the runServer file is deleted
if (daemon == null) {
@@ -121,12 +121,17 @@ private inline fun tryConnectToDaemonByRMI(port: Int, report: (DaemonReportCateg
return null
}
private inline fun tryConnectToDaemonBySockets(port: Int, report: (DaemonReportCategory, String) -> Unit): CompileServiceClientSide? {
private inline fun tryConnectToDaemonBySockets(
port: Int,
file: File,
report: (DaemonReportCategory, String) -> Unit
): CompileServiceClientSide? {
try {
log.info("tryConnectToDaemonBySockets(port = $port)")
val daemon = CompileServiceClientSideImpl(
port,
LoopbackNetworkInterface.loopbackInetAddressName
LoopbackNetworkInterface.loopbackInetAddressName,
file
)
log.info("daemon($port) = $daemon")
log.info("daemon($port) connecting to server...")
@@ -142,10 +147,11 @@ private inline fun tryConnectToDaemonBySockets(port: Int, report: (DaemonReportC
private fun tryConnectToDaemonAsync(
port: Int,
report: (DaemonReportCategory, String) -> Unit,
file: File,
useRMI: Boolean = true,
useSockets: Boolean = true
): CompileServiceClientSide? =
useSockets.takeIf { it }?.let { tryConnectToDaemonBySockets(port, report) }
useSockets.takeIf { it }?.let { tryConnectToDaemonBySockets(port, file, report) }
?: (useRMI.takeIf { it }?.let { tryConnectToDaemonByRMI(port, report) })
private const val validFlagFileKeywordChars = "abcdefghijklmnopqrstuvwxyz0123456789-_"

View File

@@ -7,6 +7,7 @@
package org.jetbrains.kotlin.daemon.common.experimental
import kotlinx.coroutines.experimental.runBlocking
import org.jetbrains.kotlin.cli.common.repl.ReplCheckResult
import org.jetbrains.kotlin.cli.common.repl.ReplCodeLine
import org.jetbrains.kotlin.cli.common.repl.ReplCompileResult
@@ -19,16 +20,24 @@ import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.Serv
import java.io.File
import java.util.logging.Logger
interface CompileServiceClientSide: CompileServiceAsync, Client {
interface CompileServiceClientSide : CompileServiceAsync, Client {
val serverPort: Int
}
class CompileServiceClientSideImpl(
override val serverPort: Int,
val serverHost: String
) : CompileServiceClientSide, Client by DefaultClient(serverPort, serverHost) {
val serverHost: String,
val serverFile: File
) : CompileServiceClientSide,
Client by DefaultClient(serverPort, serverHost, { serverOutputChannel ->
runBlocking {
println("in authoriseOnServer(serverFile=$serverFile)")
val signature = serverFile.inputStream().use(::readTokenKeyPairAndSign)
sendSignature(serverOutputChannel, signature)
}
}) {
val log = Logger.getLogger("CompileServiceClientSideImpl")
override suspend fun compile(
@@ -182,7 +191,7 @@ class CompileServiceClientSideImpl(
).await()
return readMessage<CallResult<ReplCompileResult>>().await()
}
// Query messages:
class CheckCompilerIdMessage(val expectedCompilerId: CompilerId) : Server.Message<CompileServiceServerSide> {

View File

@@ -0,0 +1,90 @@
/*
* Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license
* that can be found in the license/LICENSE.txt file.
*/
@file:Suppress("CAST_NEVER_SUCCEEDS")
package org.jetbrains.kotlin.daemon.common.experimental
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.ByteReadChannelWrapper
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.ByteWriteChannelWrapper
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.security.*
import java.security.spec.InvalidKeySpecException
import java.security.spec.X509EncodedKeySpec
const val SECURITY_TOKEN_SIZE = 128
private val secureRandom = SecureRandom.getInstance("SHA1PRNG", "SUN");
private val pairGenerator = KeyPairGenerator.getInstance("DSA", "SUN")
private val keyFactory = KeyFactory.getInstance("DSA", "SUN")
private fun generateSecurityToken(): ByteArray {
val tokenBuffer = ByteArray(SECURITY_TOKEN_SIZE)
secureRandom.nextBytes(tokenBuffer)
return tokenBuffer
}
data class SecurityData(val privateKey: PrivateKey, val publicKey: PublicKey, val token: ByteArray)
fun generateKeysAndToken() = pairGenerator.generateKeyPair().let {
SecurityData(it.private, it.public, generateSecurityToken())
}
private fun FileInputStream.readAllBytes(): ByteArray {
val bytes = arrayListOf<Byte>()
val buffer = ByteArray(1024)
var bytesRead = 0
while (bytesRead != -1) {
bytesRead = this.read(buffer, 0, 1024)
bytes.addAll(buffer.toList())
}
return ByteArray(bytes.size, bytes::get)
}
private fun FileInputStream.readBytesFixedLength(n: Int): ByteArray {
val buffer = ByteArray(n)
var bytesRead = 0
while (bytesRead != n) {
bytesRead += this.read(buffer, bytesRead, n - bytesRead)
}
return buffer
}
// server part :
fun sendTokenKeyPair(output: FileOutputStream, token: ByteArray, privateKey: PrivateKey) {
output.write(token)
ObjectOutputStream(output).use {
it.writeObject(privateKey)
}
}
suspend fun getSignatureAndVerify(input: ByteReadChannelWrapper, expectedToken: ByteArray, publicKey: PublicKey): Boolean {
val signature = input.readBytes(input.getLength())
val dsa = Signature.getInstance("SHA1withDSA", "SUN")
dsa.initVerify(publicKey)
dsa.update(expectedToken, 0, SECURITY_TOKEN_SIZE)
val verified = dsa.verify(signature)
log.info("verified : $verified")
return verified
}
// client part :
fun readTokenKeyPairAndSign(input: FileInputStream): ByteArray {
val token = input.readBytesFixedLength(SECURITY_TOKEN_SIZE)
val privateKey = ObjectInputStream(input).use(ObjectInputStream::readObject) as PrivateKey
val dsa = Signature.getInstance("SHA1withDSA", "SUN")
dsa.initSign(privateKey)
dsa.update(token, 0, SECURITY_TOKEN_SIZE)
return dsa.sign()
}
suspend fun sendSignature(output: ByteWriteChannelWrapper, signature: ByteArray) {
output.printLength(signature.size)
output.printBytes(signature)
}

View File

@@ -2,8 +2,8 @@ package org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure
import io.ktor.network.sockets.Socket
import kotlinx.coroutines.experimental.*
import org.jetbrains.kotlin.daemon.common.experimental.FIRST_HANDSHAKE_BYTE_TOKEN
import org.jetbrains.kotlin.daemon.common.experimental.LoopbackNetworkInterface
import sun.net.ConnectionResetException
import java.beans.Transient
import java.io.Serializable
import java.util.logging.Logger
@@ -15,13 +15,13 @@ interface Client : Serializable, AutoCloseable {
fun sendMessage(msg: Any): Deferred<Unit>
fun <T> readMessage(): Deferred<T>
fun f() {}
}
@Suppress("UNCHECKED_CAST")
class DefaultClient(
val serverPort: Int,
val serverHost: String = LoopbackNetworkInterface.loopbackInetAddressName
val serverHost: String = LoopbackNetworkInterface.loopbackInetAddressName,
val authorizeOnServer: (ByteWriteChannelWrapper) -> Unit = {}
) : Client {
val log: Logger
@@ -38,7 +38,6 @@ class DefaultClient(
private var socket: Socket? = null
@Transient get
@Transient set
override fun close() {
socket?.close()
}
@@ -47,6 +46,7 @@ class DefaultClient(
override fun <T> readMessage() = async { input.nextObject() as T }
@Throws(Exception::class)
override fun connectToServer() {
runBlocking(Unconfined) {
log.info("connectToServer (port = $serverPort | host = $serverHost)")
@@ -64,8 +64,10 @@ class DefaultClient(
log.info("OK serv.openIO() |port=$serverPort|")
input = it.input
output = it.output
sendHandshakeMessage(output)
tryAcquireHandshakeMessage(input, log)
if (!tryAcquireHandshakeMessage(input, log) || !trySendHandshakeMessage(output)) {
throw ConnectionResetException("failed to establish connection with server (handshake failed)")
}
authorizeOnServer(output)
}
}
}

View File

@@ -7,8 +7,11 @@ import kotlinx.coroutines.experimental.delay
import org.jetbrains.kotlin.daemon.common.experimental.AUTH_TIMEOUT_IN_MILLISECONDS
import org.jetbrains.kotlin.daemon.common.experimental.FIRST_HANDSHAKE_BYTE_TOKEN
import org.jetbrains.kotlin.daemon.common.experimental.LoopbackNetworkInterface
import org.jetbrains.kotlin.daemon.common.experimental.log
import sun.net.ConnectionResetException
import java.io.Serializable
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import java.util.logging.Logger
/*
@@ -42,11 +45,12 @@ interface Server<out T : ServerBase> : ServerBase {
suspend fun attachClient(client: Socket): Deferred<State> = async {
val (input, output) = client.openIO(log)
try {
tryAcquireHandshakeMessage(input, log)
sendHandshakeMessage(output)
} catch (e: Throwable) {
log.info("NO TOKEN")
if (!trySendHandshakeMessage(output) || !tryAcquireHandshakeMessage(input, log)) {
log.info("failed to establish connection with client (handshake failed)")
return@async Server.State.CLOSED
}
if (!securityCheck(input)) {
log.info("failed to check securitay")
return@async Server.State.CLOSED
}
var finalState = Server.State.WORKING
@@ -100,21 +104,55 @@ interface Server<out T : ServerBase> : ServerBase {
}
}
fun securityCheck(clientInputChannel: ByteReadChannelWrapper): Boolean = true
}
@Throws(Exception::class)
suspend fun tryAcquireHandshakeMessage(input: ByteReadChannelWrapper, log: Logger): Boolean {
val bytesAsync = async { input.readBytes(FIRST_HANDSHAKE_BYTE_TOKEN.size) }
delay(AUTH_TIMEOUT_IN_MILLISECONDS, TimeUnit.MILLISECONDS)
val bytes = bytesAsync.getCompleted()
log.info("bytes : ${bytes.toList()}")
if (bytes.zip(FIRST_HANDSHAKE_BYTE_TOKEN).any { it.first != it.second }) {
log.info("BAD TOKEN")
@Throws(TimeoutException::class)
suspend fun <T> runWithTimeout(
timeout: Long = AUTH_TIMEOUT_IN_MILLISECONDS,
unit: TimeUnit = TimeUnit.MILLISECONDS,
block: suspend () -> T
): T {
val asyncRes = async { block() }
delay(timeout, unit)
return try {
asyncRes.getCompleted()
} catch (e: IllegalStateException) {
throw TimeoutException("failed to get coroutine's value after given timeout")
}
}
@Throws(ConnectionResetException::class)
suspend fun tryAcquireHandshakeMessage(input: ByteReadChannelWrapper, log: Logger) : Boolean {
log.info("tryAcquireHandshakeMessage")
val bytes: ByteArray = try {
runWithTimeout {
input.readBytes(FIRST_HANDSHAKE_BYTE_TOKEN.size)
}
} catch (e: TimeoutException) {
log.info("no token received")
return false
}
log.info("bytes : ${bytes.toList()}")
if (bytes.zip(FIRST_HANDSHAKE_BYTE_TOKEN).any { it.first != it.second }) {
log.info("invalid token received")
return false
}
log.info("tryAcquireHandshakeMessage - SUCCESS")
return true
}
suspend fun sendHandshakeMessage(output: ByteWriteChannelWrapper) {
output.printBytes(FIRST_HANDSHAKE_BYTE_TOKEN)
@Throws(ConnectionResetException::class)
suspend fun trySendHandshakeMessage(output: ByteWriteChannelWrapper) : Boolean {
log.info("trySendHandshakeMessage")
try {
runWithTimeout {
output.printBytes(FIRST_HANDSHAKE_BYTE_TOKEN)
}
} catch (e: TimeoutException) {
log.info("trySendHandshakeMessage - FAIL")
return false
}
log.info("trySendHandshakeMessage - SUCCESS")
return true
}

View File

@@ -37,7 +37,7 @@ class ByteReadChannelWrapper(private val readChannel: ByteReadChannel, private v
log.info("object : ${if (it is CompileService.CallResult<*> && it.isGood) it.get() else it}")
}
private suspend fun getLength(): Int {
suspend fun getLength(): Int {
log.info("length : ")
val packet = readBytes(4)
log.info("length : ${packet.toList()}")
@@ -78,7 +78,7 @@ class ByteWriteChannelWrapper(private val writeChannel: ByteWriteChannel, privat
}
private suspend fun printLength(length: Int) = printBytes(
suspend fun printLength(length: Int) = printBytes(
ByteBuffer
.allocate(4)
.putInt(length)

View File

@@ -34,6 +34,7 @@ import org.jetbrains.kotlin.daemon.KotlinJvmReplService
import org.jetbrains.kotlin.daemon.LazyClasspathWatcher
import org.jetbrains.kotlin.daemon.common.*
import org.jetbrains.kotlin.daemon.common.experimental.*
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.ByteReadChannelWrapper
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.Server
import org.jetbrains.kotlin.daemon.incremental.experimental.RemoteAnnotationsFileUpdaterAsync
import org.jetbrains.kotlin.daemon.incremental.experimental.RemoteArtifactChangesProviderAsync
@@ -52,6 +53,7 @@ import java.io.PrintStream
import java.rmi.RemoteException
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantReadWriteLock
@@ -60,6 +62,9 @@ import java.util.logging.Logger
import kotlin.concurrent.read
import kotlin.concurrent.schedule
import kotlin.concurrent.write
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.runWithTimeout
import java.security.PrivateKey
import java.security.PublicKey
fun nowSeconds() = TimeUnit.NANOSECONDS.toSeconds(System.nanoTime())
@@ -96,6 +101,18 @@ class CompileServiceServerSideImpl(
val onShutdown: () -> Unit
) : CompileServiceServerSide {
override fun securityCheck(clientInputChannel: ByteReadChannelWrapper): Boolean =
runBlocking {
try {
val verified = runWithTimeout {
getSignatureAndVerify(clientInputChannel, token, publicKey)
}
verified
} catch (e: TimeoutException) {
false
}
}
constructor(
serverPort: Int,
compilerId: CompilerId,
@@ -263,6 +280,8 @@ class CompileServiceServerSideImpl(
private val rwlock = ReentrantReadWriteLock()
private var runFile: File
private var token: ByteArray
private var publicKey: PublicKey
init {
val runFileDir = File(daemonOptions.runFilesPathOrDefault)
@@ -281,6 +300,15 @@ class CompileServiceServerSideImpl(
} catch (e: Throwable) {
throw IllegalStateException("Unable to create runServer file '${runFile.absolutePath}'", e)
}
var privateKey: PrivateKey?
generateKeysAndToken().let {
privateKey = it.privateKey
publicKey = it.publicKey
token = it.token
}
runFile.outputStream().use {
sendTokenKeyPair(it, token, privateKey!!)
}
runFile.deleteOnExit()
log.info("last_init_end")
}