Support interpretation for floorDiv and mod functions

This commit is contained in:
Ivan Kylchik
2021-08-04 16:59:02 +03:00
committed by TeamCityServer
parent f9607292b5
commit b85a796492
6 changed files with 103 additions and 10 deletions

View File

@@ -24,6 +24,8 @@ import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.StandardClassIds import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlin.util.OperatorNameConventions import org.jetbrains.kotlin.util.OperatorNameConventions
private val compileTimeExtensionFunctions = setOf(Name.identifier("floorDiv"), Name.identifier("mod"))
fun ConeKotlinType.canBeUsedForConstVal(): Boolean = with(lowerBoundIfFlexible()) { isPrimitive || isString || isUnsignedType } fun ConeKotlinType.canBeUsedForConstVal(): Boolean = with(lowerBoundIfFlexible()) { isPrimitive || isString || isUnsignedType }
internal fun checkConstantArguments( internal fun checkConstantArguments(
@@ -145,6 +147,15 @@ internal fun checkConstantArguments(
checkConstantArguments(exp, session)?.let { return it } checkConstantArguments(exp, session)?.let { return it }
} }
} }
in compileTimeExtensionFunctions -> {
if (calleeReference !is FirResolvedNamedReference) return ConstantArgumentKind.NOT_CONST
val symbol = calleeReference.resolvedSymbol as? FirCallableSymbol
if (symbol?.callableId?.packageName?.asString() != "kotlin") return ConstantArgumentKind.NOT_CONST
for (exp in (expression as FirCall).arguments.plus(expression.extensionReceiver)) {
checkConstantArguments(exp, session)?.let { return it }
}
}
else -> { else -> {
if (expression.arguments.isNotEmpty() || calleeReference !is FirResolvedNamedReference) { if (expression.arguments.isNotEmpty() || calleeReference !is FirResolvedNamedReference) {
return ConstantArgumentKind.NOT_CONST return ConstantArgumentKind.NOT_CONST

View File

@@ -643,6 +643,66 @@ internal fun interpretBinaryFunction(name: String, typeA: String, typeB: String,
"OROR" -> when (typeA) { "OROR" -> when (typeA) {
"Boolean" -> if (typeB == "Boolean") return (a as Boolean) || (b as Boolean) "Boolean" -> if (typeB == "Boolean") return (a as Boolean) || (b as Boolean)
} }
"mod" -> when (typeA) {
"Byte" -> when (typeB) {
"Byte" -> return (a as Byte).mod(b as Byte)
"Short" -> return (a as Byte).mod(b as Short)
"Int" -> return (a as Byte).mod(b as Int)
"Long" -> return (a as Byte).mod(b as Long)
}
"Short" -> when (typeB) {
"Byte" -> return (a as Short).mod(b as Byte)
"Short" -> return (a as Short).mod(b as Short)
"Int" -> return (a as Short).mod(b as Int)
"Long" -> return (a as Short).mod(b as Long)
}
"Int" -> when (typeB) {
"Byte" -> return (a as Int).mod(b as Byte)
"Short" -> return (a as Int).mod(b as Short)
"Int" -> return (a as Int).mod(b as Int)
"Long" -> return (a as Int).mod(b as Long)
}
"Long" -> when (typeB) {
"Byte" -> return (a as Long).mod(b as Byte)
"Short" -> return (a as Long).mod(b as Short)
"Int" -> return (a as Long).mod(b as Int)
"Long" -> return (a as Long).mod(b as Long)
}
"Float" -> when (typeB) {
"Float" -> return (a as Float).mod(b as Float)
"Double" -> return (a as Float).mod(b as Double)
}
"Double" -> when (typeB) {
"Float" -> return (a as Double).mod(b as Float)
"Double" -> return (a as Double).mod(b as Double)
}
}
"floorDiv" -> when (typeA) {
"Byte" -> when (typeB) {
"Byte" -> return (a as Byte).floorDiv(b as Byte)
"Short" -> return (a as Byte).floorDiv(b as Short)
"Int" -> return (a as Byte).floorDiv(b as Int)
"Long" -> return (a as Byte).floorDiv(b as Long)
}
"Short" -> when (typeB) {
"Byte" -> return (a as Short).floorDiv(b as Byte)
"Short" -> return (a as Short).floorDiv(b as Short)
"Int" -> return (a as Short).floorDiv(b as Int)
"Long" -> return (a as Short).floorDiv(b as Long)
}
"Int" -> when (typeB) {
"Byte" -> return (a as Int).floorDiv(b as Byte)
"Short" -> return (a as Int).floorDiv(b as Short)
"Int" -> return (a as Int).floorDiv(b as Int)
"Long" -> return (a as Int).floorDiv(b as Long)
}
"Long" -> when (typeB) {
"Byte" -> return (a as Long).floorDiv(b as Byte)
"Short" -> return (a as Long).floorDiv(b as Short)
"Int" -> return (a as Long).floorDiv(b as Int)
"Long" -> return (a as Long).floorDiv(b as Long)
}
}
} }
throw InterpreterMethodNotFoundError("Unknown function: $name($typeA, $typeB)") throw InterpreterMethodNotFoundError("Unknown function: $name($typeA, $typeB)")
} }

View File

@@ -51,19 +51,22 @@ enum class EvaluationMode(protected val mustCheckBody: Boolean) {
ONLY_BUILTINS(mustCheckBody = false) { ONLY_BUILTINS(mustCheckBody = false) {
private val forbiddenMethodsOnPrimitives = setOf("inc", "dec", "rangeTo", "hashCode") private val forbiddenMethodsOnPrimitives = setOf("inc", "dec", "rangeTo", "hashCode")
private val forbiddenMethodsOnStrings = setOf("subSequence", "hashCode", "<init>") private val forbiddenMethodsOnStrings = setOf("subSequence", "hashCode", "<init>")
private val allowedExtensionFunctions = setOf("kotlin.floorDiv", "kotlin.mod", "kotlin.NumbersKt.floorDiv", "kotlin.NumbersKt.mod")
override fun canEvaluateFunction(function: IrFunction, expression: IrCall?): Boolean { override fun canEvaluateFunction(function: IrFunction, expression: IrCall?): Boolean {
if ((function as? IrSimpleFunction)?.correspondingPropertySymbol?.owner?.isConst == true) return true if ((function as? IrSimpleFunction)?.correspondingPropertySymbol?.owner?.isConst == true) return true
val parent = function.parentClassOrNull ?: return false val fqName = function.fqNameWhenAvailable?.asString()
val parentType = parent.defaultType val parent = function.parentClassOrNull
val parentType = parent?.defaultType
return when { return when {
parentType == null -> fqName in allowedExtensionFunctions
parentType.isPrimitiveType() -> function.name.asString() !in forbiddenMethodsOnPrimitives parentType.isPrimitiveType() -> function.name.asString() !in forbiddenMethodsOnPrimitives
parentType.isString() -> function.name.asString() !in forbiddenMethodsOnStrings parentType.isString() -> function.name.asString() !in forbiddenMethodsOnStrings
parentType.isAny() -> function.name.asString() == "toString" && expression?.dispatchReceiver !is IrGetObjectValue parentType.isAny() -> function.name.asString() == "toString" && expression?.dispatchReceiver !is IrGetObjectValue
parent.isObject -> parent.parentClassOrNull?.defaultType?.let { it.isPrimitiveType() || it.isUnsigned() } == true parent.isObject -> parent.parentClassOrNull?.defaultType?.let { it.isPrimitiveType() || it.isUnsigned() } == true
parentType.isUnsignedType() && function is IrConstructor -> true parentType.isUnsignedType() && function is IrConstructor -> true
else -> false else -> fqName in allowedExtensionFunctions
} }
} }
}; };

View File

@@ -1,7 +1,6 @@
// !LANGUAGE: -ApproximateIntegerLiteralTypesInReceiverPosition // !LANGUAGE: -ApproximateIntegerLiteralTypesInReceiverPosition
// IGNORE_FIR_DIAGNOSTICS // IGNORE_FIR_DIAGNOSTICS
// TARGET_BACKEND: JVM // TARGET_BACKEND: JVM
// IGNORE_BACKEND_FIR: JVM_IR
// WITH_RUNTIME // WITH_RUNTIME

View File

@@ -1,7 +1,6 @@
// !LANGUAGE: -ApproximateIntegerLiteralTypesInReceiverPosition // !LANGUAGE: -ApproximateIntegerLiteralTypesInReceiverPosition
// IGNORE_FIR_DIAGNOSTICS // IGNORE_FIR_DIAGNOSTICS
// TARGET_BACKEND: JVM // TARGET_BACKEND: JVM
// IGNORE_BACKEND_FIR: JVM_IR
// WITH_RUNTIME // WITH_RUNTIME
@@ -15,10 +14,10 @@ annotation class Ann(
val p6: Float val p6: Float
) )
val prop1: Byte = 10.mod(2) val prop1: Byte = 10.mod(2.toByte())
val prop2: Short = 10.mod(-3) val prop2: Short = 10.mod((-3).toShort())
val prop3: Int = (-10).mod(4) val prop3: Int = (-10).mod(4)
val prop4: Long = (-10).mod(-5) val prop4: Long = (-10).mod((-5).toLong())
val prop5: Double = 0.25.mod(-100.0) val prop5: Double = 0.25.mod(-100.0)
val prop6: Float = 100f.mod(0.33f) val prop6: Float = 100f.mod(0.33f)

View File

@@ -19,7 +19,6 @@ import org.jetbrains.kotlin.ir.IrBuiltIns
import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI
import org.jetbrains.kotlin.ir.declarations.impl.IrFactoryImpl import org.jetbrains.kotlin.ir.declarations.impl.IrFactoryImpl
import org.jetbrains.kotlin.ir.descriptors.IrBuiltInsOverDescriptors import org.jetbrains.kotlin.ir.descriptors.IrBuiltInsOverDescriptors
import org.jetbrains.kotlin.ir.symbols.IrFileSymbol
import org.jetbrains.kotlin.ir.types.impl.originalKotlinType import org.jetbrains.kotlin.ir.types.impl.originalKotlinType
import org.jetbrains.kotlin.ir.util.IdSignature import org.jetbrains.kotlin.ir.util.IdSignature
import org.jetbrains.kotlin.ir.util.IdSignatureComposer import org.jetbrains.kotlin.ir.util.IdSignatureComposer
@@ -58,7 +57,7 @@ fun generateMap(): String {
this += Operation("toString", listOf("Any?"), customExpression = "a?.toString() ?: \"null\"") this += Operation("toString", listOf("Any?"), customExpression = "a?.toString() ?: \"null\"")
}) })
generateInterpretBinaryFunction(p, getOperationMap(2) + getBinaryIrOperationMap(irBuiltIns)) generateInterpretBinaryFunction(p, getOperationMap(2) + getBinaryIrOperationMap(irBuiltIns) + getExtensionOperationMap())
generateInterpretTernaryFunction(p, getOperationMap(3)) generateInterpretTernaryFunction(p, getOperationMap(3))
@@ -247,6 +246,28 @@ private fun getBinaryIrOperationMap(irBuiltIns: IrBuiltIns): List<Operation> {
return operationMap return operationMap
} }
// TODO can be drop after serialization introduction
private fun getExtensionOperationMap(): List<Operation> {
val operationMap = mutableListOf<Operation>()
val integerTypes = listOf(PrimitiveType.BYTE, PrimitiveType.SHORT, PrimitiveType.INT, PrimitiveType.LONG).map { it.typeName.asString() }
val fpTypes = listOf(PrimitiveType.FLOAT, PrimitiveType.DOUBLE).map { it.typeName.asString() }
for (type in integerTypes) {
for (otherType in integerTypes) {
operationMap.add(Operation("mod", listOf(type, otherType), isFunction = true))
operationMap.add(Operation("floorDiv", listOf(type, otherType), isFunction = true))
}
}
for (type in fpTypes) {
for (otherType in fpTypes) {
operationMap.add(Operation("mod", listOf(type, otherType), isFunction = true))
}
}
return operationMap
}
private fun getIrMethodSymbolByName(methodName: String): String { private fun getIrMethodSymbolByName(methodName: String): String {
return when (methodName) { return when (methodName) {
BuiltInOperatorNames.LESS -> "<" BuiltInOperatorNames.LESS -> "<"