diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/FirConstChecks.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/FirConstChecks.kt index fcfb389bdac..68f01c0d9e7 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/FirConstChecks.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/checkers/FirConstChecks.kt @@ -24,6 +24,8 @@ import org.jetbrains.kotlin.name.Name import org.jetbrains.kotlin.name.StandardClassIds import org.jetbrains.kotlin.util.OperatorNameConventions +private val compileTimeExtensionFunctions = setOf(Name.identifier("floorDiv"), Name.identifier("mod")) + fun ConeKotlinType.canBeUsedForConstVal(): Boolean = with(lowerBoundIfFlexible()) { isPrimitive || isString || isUnsignedType } internal fun checkConstantArguments( @@ -145,6 +147,15 @@ internal fun checkConstantArguments( checkConstantArguments(exp, session)?.let { return it } } } + in compileTimeExtensionFunctions -> { + if (calleeReference !is FirResolvedNamedReference) return ConstantArgumentKind.NOT_CONST + val symbol = calleeReference.resolvedSymbol as? FirCallableSymbol + if (symbol?.callableId?.packageName?.asString() != "kotlin") return ConstantArgumentKind.NOT_CONST + + for (exp in (expression as FirCall).arguments.plus(expression.extensionReceiver)) { + checkConstantArguments(exp, session)?.let { return it } + } + } else -> { if (expression.arguments.isNotEmpty() || calleeReference !is FirResolvedNamedReference) { return ConstantArgumentKind.NOT_CONST diff --git a/compiler/ir/ir.interpreter/src/org/jetbrains/kotlin/ir/interpreter/builtins/IrBuiltInsMapGenerated.kt b/compiler/ir/ir.interpreter/src/org/jetbrains/kotlin/ir/interpreter/builtins/IrBuiltInsMapGenerated.kt index c056b140747..315366e52dc 100644 --- a/compiler/ir/ir.interpreter/src/org/jetbrains/kotlin/ir/interpreter/builtins/IrBuiltInsMapGenerated.kt +++ b/compiler/ir/ir.interpreter/src/org/jetbrains/kotlin/ir/interpreter/builtins/IrBuiltInsMapGenerated.kt @@ -643,6 +643,66 @@ internal fun interpretBinaryFunction(name: String, typeA: String, typeB: String, "OROR" -> when (typeA) { "Boolean" -> if (typeB == "Boolean") return (a as Boolean) || (b as Boolean) } + "mod" -> when (typeA) { + "Byte" -> when (typeB) { + "Byte" -> return (a as Byte).mod(b as Byte) + "Short" -> return (a as Byte).mod(b as Short) + "Int" -> return (a as Byte).mod(b as Int) + "Long" -> return (a as Byte).mod(b as Long) + } + "Short" -> when (typeB) { + "Byte" -> return (a as Short).mod(b as Byte) + "Short" -> return (a as Short).mod(b as Short) + "Int" -> return (a as Short).mod(b as Int) + "Long" -> return (a as Short).mod(b as Long) + } + "Int" -> when (typeB) { + "Byte" -> return (a as Int).mod(b as Byte) + "Short" -> return (a as Int).mod(b as Short) + "Int" -> return (a as Int).mod(b as Int) + "Long" -> return (a as Int).mod(b as Long) + } + "Long" -> when (typeB) { + "Byte" -> return (a as Long).mod(b as Byte) + "Short" -> return (a as Long).mod(b as Short) + "Int" -> return (a as Long).mod(b as Int) + "Long" -> return (a as Long).mod(b as Long) + } + "Float" -> when (typeB) { + "Float" -> return (a as Float).mod(b as Float) + "Double" -> return (a as Float).mod(b as Double) + } + "Double" -> when (typeB) { + "Float" -> return (a as Double).mod(b as Float) + "Double" -> return (a as Double).mod(b as Double) + } + } + "floorDiv" -> when (typeA) { + "Byte" -> when (typeB) { + "Byte" -> return (a as Byte).floorDiv(b as Byte) + "Short" -> return (a as Byte).floorDiv(b as Short) + "Int" -> return (a as Byte).floorDiv(b as Int) + "Long" -> return (a as Byte).floorDiv(b as Long) + } + "Short" -> when (typeB) { + "Byte" -> return (a as Short).floorDiv(b as Byte) + "Short" -> return (a as Short).floorDiv(b as Short) + "Int" -> return (a as Short).floorDiv(b as Int) + "Long" -> return (a as Short).floorDiv(b as Long) + } + "Int" -> when (typeB) { + "Byte" -> return (a as Int).floorDiv(b as Byte) + "Short" -> return (a as Int).floorDiv(b as Short) + "Int" -> return (a as Int).floorDiv(b as Int) + "Long" -> return (a as Int).floorDiv(b as Long) + } + "Long" -> when (typeB) { + "Byte" -> return (a as Long).floorDiv(b as Byte) + "Short" -> return (a as Long).floorDiv(b as Short) + "Int" -> return (a as Long).floorDiv(b as Int) + "Long" -> return (a as Long).floorDiv(b as Long) + } + } } throw InterpreterMethodNotFoundError("Unknown function: $name($typeA, $typeB)") } diff --git a/compiler/ir/ir.interpreter/src/org/jetbrains/kotlin/ir/interpreter/checker/EvaluationMode.kt b/compiler/ir/ir.interpreter/src/org/jetbrains/kotlin/ir/interpreter/checker/EvaluationMode.kt index 4e832619fc2..424e4c1c093 100644 --- a/compiler/ir/ir.interpreter/src/org/jetbrains/kotlin/ir/interpreter/checker/EvaluationMode.kt +++ b/compiler/ir/ir.interpreter/src/org/jetbrains/kotlin/ir/interpreter/checker/EvaluationMode.kt @@ -51,19 +51,22 @@ enum class EvaluationMode(protected val mustCheckBody: Boolean) { ONLY_BUILTINS(mustCheckBody = false) { private val forbiddenMethodsOnPrimitives = setOf("inc", "dec", "rangeTo", "hashCode") private val forbiddenMethodsOnStrings = setOf("subSequence", "hashCode", "") + private val allowedExtensionFunctions = setOf("kotlin.floorDiv", "kotlin.mod", "kotlin.NumbersKt.floorDiv", "kotlin.NumbersKt.mod") override fun canEvaluateFunction(function: IrFunction, expression: IrCall?): Boolean { if ((function as? IrSimpleFunction)?.correspondingPropertySymbol?.owner?.isConst == true) return true - val parent = function.parentClassOrNull ?: return false - val parentType = parent.defaultType + val fqName = function.fqNameWhenAvailable?.asString() + val parent = function.parentClassOrNull + val parentType = parent?.defaultType return when { + parentType == null -> fqName in allowedExtensionFunctions parentType.isPrimitiveType() -> function.name.asString() !in forbiddenMethodsOnPrimitives parentType.isString() -> function.name.asString() !in forbiddenMethodsOnStrings parentType.isAny() -> function.name.asString() == "toString" && expression?.dispatchReceiver !is IrGetObjectValue parent.isObject -> parent.parentClassOrNull?.defaultType?.let { it.isPrimitiveType() || it.isUnsigned() } == true parentType.isUnsignedType() && function is IrConstructor -> true - else -> false + else -> fqName in allowedExtensionFunctions } } }; diff --git a/compiler/testData/codegen/box/evaluate/floorDiv.kt b/compiler/testData/codegen/box/evaluate/floorDiv.kt index 0203349e3fd..14c2e600aaf 100644 --- a/compiler/testData/codegen/box/evaluate/floorDiv.kt +++ b/compiler/testData/codegen/box/evaluate/floorDiv.kt @@ -1,7 +1,6 @@ // !LANGUAGE: -ApproximateIntegerLiteralTypesInReceiverPosition // IGNORE_FIR_DIAGNOSTICS // TARGET_BACKEND: JVM -// IGNORE_BACKEND_FIR: JVM_IR // WITH_RUNTIME diff --git a/compiler/testData/codegen/box/evaluate/mod.kt b/compiler/testData/codegen/box/evaluate/mod.kt index 81d61d82a5e..e79d694d659 100644 --- a/compiler/testData/codegen/box/evaluate/mod.kt +++ b/compiler/testData/codegen/box/evaluate/mod.kt @@ -1,7 +1,6 @@ // !LANGUAGE: -ApproximateIntegerLiteralTypesInReceiverPosition // IGNORE_FIR_DIAGNOSTICS // TARGET_BACKEND: JVM -// IGNORE_BACKEND_FIR: JVM_IR // WITH_RUNTIME @@ -15,10 +14,10 @@ annotation class Ann( val p6: Float ) -val prop1: Byte = 10.mod(2) -val prop2: Short = 10.mod(-3) +val prop1: Byte = 10.mod(2.toByte()) +val prop2: Short = 10.mod((-3).toShort()) val prop3: Int = (-10).mod(4) -val prop4: Long = (-10).mod(-5) +val prop4: Long = (-10).mod((-5).toLong()) val prop5: Double = 0.25.mod(-100.0) val prop6: Float = 100f.mod(0.33f) diff --git a/generators/interpreter/GenerateInterpreterMap.kt b/generators/interpreter/GenerateInterpreterMap.kt index e0c039c7168..54c72b6d1df 100644 --- a/generators/interpreter/GenerateInterpreterMap.kt +++ b/generators/interpreter/GenerateInterpreterMap.kt @@ -19,7 +19,6 @@ import org.jetbrains.kotlin.ir.IrBuiltIns import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI import org.jetbrains.kotlin.ir.declarations.impl.IrFactoryImpl import org.jetbrains.kotlin.ir.descriptors.IrBuiltInsOverDescriptors -import org.jetbrains.kotlin.ir.symbols.IrFileSymbol import org.jetbrains.kotlin.ir.types.impl.originalKotlinType import org.jetbrains.kotlin.ir.util.IdSignature import org.jetbrains.kotlin.ir.util.IdSignatureComposer @@ -58,7 +57,7 @@ fun generateMap(): String { this += Operation("toString", listOf("Any?"), customExpression = "a?.toString() ?: \"null\"") }) - generateInterpretBinaryFunction(p, getOperationMap(2) + getBinaryIrOperationMap(irBuiltIns)) + generateInterpretBinaryFunction(p, getOperationMap(2) + getBinaryIrOperationMap(irBuiltIns) + getExtensionOperationMap()) generateInterpretTernaryFunction(p, getOperationMap(3)) @@ -247,6 +246,28 @@ private fun getBinaryIrOperationMap(irBuiltIns: IrBuiltIns): List { return operationMap } +// TODO can be drop after serialization introduction +private fun getExtensionOperationMap(): List { + val operationMap = mutableListOf() + val integerTypes = listOf(PrimitiveType.BYTE, PrimitiveType.SHORT, PrimitiveType.INT, PrimitiveType.LONG).map { it.typeName.asString() } + val fpTypes = listOf(PrimitiveType.FLOAT, PrimitiveType.DOUBLE).map { it.typeName.asString() } + + for (type in integerTypes) { + for (otherType in integerTypes) { + operationMap.add(Operation("mod", listOf(type, otherType), isFunction = true)) + operationMap.add(Operation("floorDiv", listOf(type, otherType), isFunction = true)) + } + } + + for (type in fpTypes) { + for (otherType in fpTypes) { + operationMap.add(Operation("mod", listOf(type, otherType), isFunction = true)) + } + } + + return operationMap +} + private fun getIrMethodSymbolByName(methodName: String): String { return when (methodName) { BuiltInOperatorNames.LESS -> "<"