mirror of
https://github.com/jlengrand/kotlin.git
synced 2026-03-22 08:31:31 +00:00
fix new security mechanisms
This commit is contained in:
@@ -22,5 +22,5 @@ val RESULTS_SERVER_PORTS_RANGE_END: Int = 17000
|
||||
|
||||
val COMPILER_DAEMON_CLASS_FQN_EXPERIMENTAL: String = "org.jetbrains.kotlin.daemon.experimental.KotlinCompileDaemon"
|
||||
|
||||
val FIRST_HANDSHAKE_BYTE_TOKEN = byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8)
|
||||
val FIRST_HANDSHAKE_BYTE_TOKEN = byteArrayOf(1, 2, 3, 4)
|
||||
val AUTH_TIMEOUT_IN_MILLISECONDS = 1000L
|
||||
@@ -58,7 +58,7 @@ suspend fun walkDaemonsAsync(
|
||||
"found daemon on socketPort $port ($relativeAge ms old), trying to connect"
|
||||
)
|
||||
log.info("found daemon on socketPort $port ($relativeAge ms old), trying to connect")
|
||||
val daemon = tryConnectToDaemonAsync(port, report, useRMI, useSockets)
|
||||
val daemon = tryConnectToDaemonAsync(port, report, file, useRMI, useSockets)
|
||||
log.info("daemon = $daemon (port= $port)")
|
||||
// cleaning orphaned file; note: daemon should shut itself down if it detects that the runServer file is deleted
|
||||
if (daemon == null) {
|
||||
@@ -121,12 +121,17 @@ private inline fun tryConnectToDaemonByRMI(port: Int, report: (DaemonReportCateg
|
||||
return null
|
||||
}
|
||||
|
||||
private inline fun tryConnectToDaemonBySockets(port: Int, report: (DaemonReportCategory, String) -> Unit): CompileServiceClientSide? {
|
||||
private inline fun tryConnectToDaemonBySockets(
|
||||
port: Int,
|
||||
file: File,
|
||||
report: (DaemonReportCategory, String) -> Unit
|
||||
): CompileServiceClientSide? {
|
||||
try {
|
||||
log.info("tryConnectToDaemonBySockets(port = $port)")
|
||||
val daemon = CompileServiceClientSideImpl(
|
||||
port,
|
||||
LoopbackNetworkInterface.loopbackInetAddressName
|
||||
LoopbackNetworkInterface.loopbackInetAddressName,
|
||||
file
|
||||
)
|
||||
log.info("daemon($port) = $daemon")
|
||||
log.info("daemon($port) connecting to server...")
|
||||
@@ -142,10 +147,11 @@ private inline fun tryConnectToDaemonBySockets(port: Int, report: (DaemonReportC
|
||||
private fun tryConnectToDaemonAsync(
|
||||
port: Int,
|
||||
report: (DaemonReportCategory, String) -> Unit,
|
||||
file: File,
|
||||
useRMI: Boolean = true,
|
||||
useSockets: Boolean = true
|
||||
): CompileServiceClientSide? =
|
||||
useSockets.takeIf { it }?.let { tryConnectToDaemonBySockets(port, report) }
|
||||
useSockets.takeIf { it }?.let { tryConnectToDaemonBySockets(port, file, report) }
|
||||
?: (useRMI.takeIf { it }?.let { tryConnectToDaemonByRMI(port, report) })
|
||||
|
||||
private const val validFlagFileKeywordChars = "abcdefghijklmnopqrstuvwxyz0123456789-_"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
package org.jetbrains.kotlin.daemon.common.experimental
|
||||
|
||||
import kotlinx.coroutines.experimental.runBlocking
|
||||
import org.jetbrains.kotlin.cli.common.repl.ReplCheckResult
|
||||
import org.jetbrains.kotlin.cli.common.repl.ReplCodeLine
|
||||
import org.jetbrains.kotlin.cli.common.repl.ReplCompileResult
|
||||
@@ -19,16 +20,24 @@ import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.Serv
|
||||
import java.io.File
|
||||
import java.util.logging.Logger
|
||||
|
||||
interface CompileServiceClientSide: CompileServiceAsync, Client {
|
||||
interface CompileServiceClientSide : CompileServiceAsync, Client {
|
||||
val serverPort: Int
|
||||
}
|
||||
|
||||
|
||||
class CompileServiceClientSideImpl(
|
||||
override val serverPort: Int,
|
||||
val serverHost: String
|
||||
) : CompileServiceClientSide, Client by DefaultClient(serverPort, serverHost) {
|
||||
|
||||
val serverHost: String,
|
||||
val serverFile: File
|
||||
) : CompileServiceClientSide,
|
||||
Client by DefaultClient(serverPort, serverHost, { serverOutputChannel ->
|
||||
runBlocking {
|
||||
println("in authoriseOnServer(serverFile=$serverFile)")
|
||||
val signature = serverFile.inputStream().use(::readTokenKeyPairAndSign)
|
||||
sendSignature(serverOutputChannel, signature)
|
||||
}
|
||||
}) {
|
||||
|
||||
val log = Logger.getLogger("CompileServiceClientSideImpl")
|
||||
|
||||
override suspend fun compile(
|
||||
@@ -182,7 +191,7 @@ class CompileServiceClientSideImpl(
|
||||
).await()
|
||||
return readMessage<CallResult<ReplCompileResult>>().await()
|
||||
}
|
||||
|
||||
|
||||
// Query messages:
|
||||
|
||||
class CheckCompilerIdMessage(val expectedCompilerId: CompilerId) : Server.Message<CompileServiceServerSide> {
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
/*
|
||||
* Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license
|
||||
* that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
@file:Suppress("CAST_NEVER_SUCCEEDS")
|
||||
|
||||
package org.jetbrains.kotlin.daemon.common.experimental
|
||||
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.ByteReadChannelWrapper
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.ByteWriteChannelWrapper
|
||||
import java.io.FileInputStream
|
||||
import java.io.FileOutputStream
|
||||
import java.io.ObjectInputStream
|
||||
import java.io.ObjectOutputStream
|
||||
import java.security.*
|
||||
import java.security.spec.InvalidKeySpecException
|
||||
import java.security.spec.X509EncodedKeySpec
|
||||
|
||||
|
||||
const val SECURITY_TOKEN_SIZE = 128
|
||||
|
||||
private val secureRandom = SecureRandom.getInstance("SHA1PRNG", "SUN");
|
||||
private val pairGenerator = KeyPairGenerator.getInstance("DSA", "SUN")
|
||||
private val keyFactory = KeyFactory.getInstance("DSA", "SUN")
|
||||
|
||||
private fun generateSecurityToken(): ByteArray {
|
||||
val tokenBuffer = ByteArray(SECURITY_TOKEN_SIZE)
|
||||
secureRandom.nextBytes(tokenBuffer)
|
||||
return tokenBuffer
|
||||
}
|
||||
|
||||
data class SecurityData(val privateKey: PrivateKey, val publicKey: PublicKey, val token: ByteArray)
|
||||
fun generateKeysAndToken() = pairGenerator.generateKeyPair().let {
|
||||
SecurityData(it.private, it.public, generateSecurityToken())
|
||||
}
|
||||
|
||||
private fun FileInputStream.readAllBytes(): ByteArray {
|
||||
val bytes = arrayListOf<Byte>()
|
||||
val buffer = ByteArray(1024)
|
||||
var bytesRead = 0
|
||||
while (bytesRead != -1) {
|
||||
bytesRead = this.read(buffer, 0, 1024)
|
||||
bytes.addAll(buffer.toList())
|
||||
}
|
||||
return ByteArray(bytes.size, bytes::get)
|
||||
}
|
||||
|
||||
private fun FileInputStream.readBytesFixedLength(n: Int): ByteArray {
|
||||
val buffer = ByteArray(n)
|
||||
var bytesRead = 0
|
||||
while (bytesRead != n) {
|
||||
bytesRead += this.read(buffer, bytesRead, n - bytesRead)
|
||||
}
|
||||
return buffer
|
||||
}
|
||||
|
||||
// server part :
|
||||
fun sendTokenKeyPair(output: FileOutputStream, token: ByteArray, privateKey: PrivateKey) {
|
||||
output.write(token)
|
||||
ObjectOutputStream(output).use {
|
||||
it.writeObject(privateKey)
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun getSignatureAndVerify(input: ByteReadChannelWrapper, expectedToken: ByteArray, publicKey: PublicKey): Boolean {
|
||||
val signature = input.readBytes(input.getLength())
|
||||
val dsa = Signature.getInstance("SHA1withDSA", "SUN")
|
||||
dsa.initVerify(publicKey)
|
||||
dsa.update(expectedToken, 0, SECURITY_TOKEN_SIZE)
|
||||
val verified = dsa.verify(signature)
|
||||
log.info("verified : $verified")
|
||||
return verified
|
||||
}
|
||||
|
||||
|
||||
// client part :
|
||||
fun readTokenKeyPairAndSign(input: FileInputStream): ByteArray {
|
||||
val token = input.readBytesFixedLength(SECURITY_TOKEN_SIZE)
|
||||
val privateKey = ObjectInputStream(input).use(ObjectInputStream::readObject) as PrivateKey
|
||||
val dsa = Signature.getInstance("SHA1withDSA", "SUN")
|
||||
dsa.initSign(privateKey)
|
||||
dsa.update(token, 0, SECURITY_TOKEN_SIZE)
|
||||
return dsa.sign()
|
||||
}
|
||||
|
||||
suspend fun sendSignature(output: ByteWriteChannelWrapper, signature: ByteArray) {
|
||||
output.printLength(signature.size)
|
||||
output.printBytes(signature)
|
||||
}
|
||||
@@ -2,8 +2,8 @@ package org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure
|
||||
|
||||
import io.ktor.network.sockets.Socket
|
||||
import kotlinx.coroutines.experimental.*
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.FIRST_HANDSHAKE_BYTE_TOKEN
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.LoopbackNetworkInterface
|
||||
import sun.net.ConnectionResetException
|
||||
import java.beans.Transient
|
||||
import java.io.Serializable
|
||||
import java.util.logging.Logger
|
||||
@@ -15,13 +15,13 @@ interface Client : Serializable, AutoCloseable {
|
||||
fun sendMessage(msg: Any): Deferred<Unit>
|
||||
fun <T> readMessage(): Deferred<T>
|
||||
|
||||
fun f() {}
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
class DefaultClient(
|
||||
val serverPort: Int,
|
||||
val serverHost: String = LoopbackNetworkInterface.loopbackInetAddressName
|
||||
val serverHost: String = LoopbackNetworkInterface.loopbackInetAddressName,
|
||||
val authorizeOnServer: (ByteWriteChannelWrapper) -> Unit = {}
|
||||
) : Client {
|
||||
|
||||
val log: Logger
|
||||
@@ -38,7 +38,6 @@ class DefaultClient(
|
||||
private var socket: Socket? = null
|
||||
@Transient get
|
||||
@Transient set
|
||||
|
||||
override fun close() {
|
||||
socket?.close()
|
||||
}
|
||||
@@ -47,6 +46,7 @@ class DefaultClient(
|
||||
|
||||
override fun <T> readMessage() = async { input.nextObject() as T }
|
||||
|
||||
@Throws(Exception::class)
|
||||
override fun connectToServer() {
|
||||
runBlocking(Unconfined) {
|
||||
log.info("connectToServer (port = $serverPort | host = $serverHost)")
|
||||
@@ -64,8 +64,10 @@ class DefaultClient(
|
||||
log.info("OK serv.openIO() |port=$serverPort|")
|
||||
input = it.input
|
||||
output = it.output
|
||||
sendHandshakeMessage(output)
|
||||
tryAcquireHandshakeMessage(input, log)
|
||||
if (!tryAcquireHandshakeMessage(input, log) || !trySendHandshakeMessage(output)) {
|
||||
throw ConnectionResetException("failed to establish connection with server (handshake failed)")
|
||||
}
|
||||
authorizeOnServer(output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,11 @@ import kotlinx.coroutines.experimental.delay
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.AUTH_TIMEOUT_IN_MILLISECONDS
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.FIRST_HANDSHAKE_BYTE_TOKEN
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.LoopbackNetworkInterface
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.log
|
||||
import sun.net.ConnectionResetException
|
||||
import java.io.Serializable
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.TimeoutException
|
||||
import java.util.logging.Logger
|
||||
|
||||
/*
|
||||
@@ -42,11 +45,12 @@ interface Server<out T : ServerBase> : ServerBase {
|
||||
|
||||
suspend fun attachClient(client: Socket): Deferred<State> = async {
|
||||
val (input, output) = client.openIO(log)
|
||||
try {
|
||||
tryAcquireHandshakeMessage(input, log)
|
||||
sendHandshakeMessage(output)
|
||||
} catch (e: Throwable) {
|
||||
log.info("NO TOKEN")
|
||||
if (!trySendHandshakeMessage(output) || !tryAcquireHandshakeMessage(input, log)) {
|
||||
log.info("failed to establish connection with client (handshake failed)")
|
||||
return@async Server.State.CLOSED
|
||||
}
|
||||
if (!securityCheck(input)) {
|
||||
log.info("failed to check securitay")
|
||||
return@async Server.State.CLOSED
|
||||
}
|
||||
var finalState = Server.State.WORKING
|
||||
@@ -100,21 +104,55 @@ interface Server<out T : ServerBase> : ServerBase {
|
||||
}
|
||||
}
|
||||
|
||||
fun securityCheck(clientInputChannel: ByteReadChannelWrapper): Boolean = true
|
||||
}
|
||||
|
||||
@Throws(Exception::class)
|
||||
suspend fun tryAcquireHandshakeMessage(input: ByteReadChannelWrapper, log: Logger): Boolean {
|
||||
val bytesAsync = async { input.readBytes(FIRST_HANDSHAKE_BYTE_TOKEN.size) }
|
||||
delay(AUTH_TIMEOUT_IN_MILLISECONDS, TimeUnit.MILLISECONDS)
|
||||
val bytes = bytesAsync.getCompleted()
|
||||
log.info("bytes : ${bytes.toList()}")
|
||||
if (bytes.zip(FIRST_HANDSHAKE_BYTE_TOKEN).any { it.first != it.second }) {
|
||||
log.info("BAD TOKEN")
|
||||
@Throws(TimeoutException::class)
|
||||
suspend fun <T> runWithTimeout(
|
||||
timeout: Long = AUTH_TIMEOUT_IN_MILLISECONDS,
|
||||
unit: TimeUnit = TimeUnit.MILLISECONDS,
|
||||
block: suspend () -> T
|
||||
): T {
|
||||
val asyncRes = async { block() }
|
||||
delay(timeout, unit)
|
||||
return try {
|
||||
asyncRes.getCompleted()
|
||||
} catch (e: IllegalStateException) {
|
||||
throw TimeoutException("failed to get coroutine's value after given timeout")
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(ConnectionResetException::class)
|
||||
suspend fun tryAcquireHandshakeMessage(input: ByteReadChannelWrapper, log: Logger) : Boolean {
|
||||
log.info("tryAcquireHandshakeMessage")
|
||||
val bytes: ByteArray = try {
|
||||
runWithTimeout {
|
||||
input.readBytes(FIRST_HANDSHAKE_BYTE_TOKEN.size)
|
||||
}
|
||||
} catch (e: TimeoutException) {
|
||||
log.info("no token received")
|
||||
return false
|
||||
}
|
||||
log.info("bytes : ${bytes.toList()}")
|
||||
if (bytes.zip(FIRST_HANDSHAKE_BYTE_TOKEN).any { it.first != it.second }) {
|
||||
log.info("invalid token received")
|
||||
return false
|
||||
}
|
||||
log.info("tryAcquireHandshakeMessage - SUCCESS")
|
||||
return true
|
||||
}
|
||||
|
||||
suspend fun sendHandshakeMessage(output: ByteWriteChannelWrapper) {
|
||||
output.printBytes(FIRST_HANDSHAKE_BYTE_TOKEN)
|
||||
@Throws(ConnectionResetException::class)
|
||||
suspend fun trySendHandshakeMessage(output: ByteWriteChannelWrapper) : Boolean {
|
||||
log.info("trySendHandshakeMessage")
|
||||
try {
|
||||
runWithTimeout {
|
||||
output.printBytes(FIRST_HANDSHAKE_BYTE_TOKEN)
|
||||
}
|
||||
} catch (e: TimeoutException) {
|
||||
log.info("trySendHandshakeMessage - FAIL")
|
||||
return false
|
||||
}
|
||||
log.info("trySendHandshakeMessage - SUCCESS")
|
||||
return true
|
||||
}
|
||||
@@ -37,7 +37,7 @@ class ByteReadChannelWrapper(private val readChannel: ByteReadChannel, private v
|
||||
log.info("object : ${if (it is CompileService.CallResult<*> && it.isGood) it.get() else it}")
|
||||
}
|
||||
|
||||
private suspend fun getLength(): Int {
|
||||
suspend fun getLength(): Int {
|
||||
log.info("length : ")
|
||||
val packet = readBytes(4)
|
||||
log.info("length : ${packet.toList()}")
|
||||
@@ -78,7 +78,7 @@ class ByteWriteChannelWrapper(private val writeChannel: ByteWriteChannel, privat
|
||||
}
|
||||
|
||||
|
||||
private suspend fun printLength(length: Int) = printBytes(
|
||||
suspend fun printLength(length: Int) = printBytes(
|
||||
ByteBuffer
|
||||
.allocate(4)
|
||||
.putInt(length)
|
||||
|
||||
@@ -34,6 +34,7 @@ import org.jetbrains.kotlin.daemon.KotlinJvmReplService
|
||||
import org.jetbrains.kotlin.daemon.LazyClasspathWatcher
|
||||
import org.jetbrains.kotlin.daemon.common.*
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.*
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.ByteReadChannelWrapper
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.Server
|
||||
import org.jetbrains.kotlin.daemon.incremental.experimental.RemoteAnnotationsFileUpdaterAsync
|
||||
import org.jetbrains.kotlin.daemon.incremental.experimental.RemoteArtifactChangesProviderAsync
|
||||
@@ -52,6 +53,7 @@ import java.io.PrintStream
|
||||
import java.rmi.RemoteException
|
||||
import java.util.*
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.TimeoutException
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock
|
||||
@@ -60,6 +62,9 @@ import java.util.logging.Logger
|
||||
import kotlin.concurrent.read
|
||||
import kotlin.concurrent.schedule
|
||||
import kotlin.concurrent.write
|
||||
import org.jetbrains.kotlin.daemon.common.experimental.socketInfrastructure.runWithTimeout
|
||||
import java.security.PrivateKey
|
||||
import java.security.PublicKey
|
||||
|
||||
fun nowSeconds() = TimeUnit.NANOSECONDS.toSeconds(System.nanoTime())
|
||||
|
||||
@@ -96,6 +101,18 @@ class CompileServiceServerSideImpl(
|
||||
val onShutdown: () -> Unit
|
||||
) : CompileServiceServerSide {
|
||||
|
||||
override fun securityCheck(clientInputChannel: ByteReadChannelWrapper): Boolean =
|
||||
runBlocking {
|
||||
try {
|
||||
val verified = runWithTimeout {
|
||||
getSignatureAndVerify(clientInputChannel, token, publicKey)
|
||||
}
|
||||
verified
|
||||
} catch (e: TimeoutException) {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
constructor(
|
||||
serverPort: Int,
|
||||
compilerId: CompilerId,
|
||||
@@ -263,6 +280,8 @@ class CompileServiceServerSideImpl(
|
||||
private val rwlock = ReentrantReadWriteLock()
|
||||
|
||||
private var runFile: File
|
||||
private var token: ByteArray
|
||||
private var publicKey: PublicKey
|
||||
|
||||
init {
|
||||
val runFileDir = File(daemonOptions.runFilesPathOrDefault)
|
||||
@@ -281,6 +300,15 @@ class CompileServiceServerSideImpl(
|
||||
} catch (e: Throwable) {
|
||||
throw IllegalStateException("Unable to create runServer file '${runFile.absolutePath}'", e)
|
||||
}
|
||||
var privateKey: PrivateKey?
|
||||
generateKeysAndToken().let {
|
||||
privateKey = it.privateKey
|
||||
publicKey = it.publicKey
|
||||
token = it.token
|
||||
}
|
||||
runFile.outputStream().use {
|
||||
sendTokenKeyPair(it, token, privateKey!!)
|
||||
}
|
||||
runFile.deleteOnExit()
|
||||
log.info("last_init_end")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user