Compare commits

...

3 Commits

Author SHA1 Message Date
Dmitry Petrov
ef461d4a3a JVM_IR KT-36646 fuze primitive equality with safe call 2021-04-27 16:22:53 +03:00
Dmitry Petrov
eaf870bfd4 JVM_IR update test for KT-36637 2021-04-27 16:22:52 +03:00
Dmitry Petrov
e8982765c8 JVM_IR use static 'hashCode' for boxed primitives on JVM 1.8+ 2021-04-27 16:22:52 +03:00
8 changed files with 210 additions and 47 deletions

View File

@@ -39,7 +39,7 @@ class RedundantBoxingMethodTransformer(private val generationState: GenerationSt
override fun transform(internalClassName: String, node: MethodNode) {
val interpreter = RedundantBoxingInterpreter(node.instructions, generationState)
val frames = MethodTransformer.analyze(internalClassName, node, interpreter)
val frames = analyze(internalClassName, node, interpreter)
interpretPopInstructionsForBoxedValues(interpreter, node, frames)
@@ -168,7 +168,8 @@ class RedundantBoxingMethodTransformer(private val generationState: GenerationSt
val frame = frames[i] ?: continue
val insn = insnList[i]
if ((insn.opcode == Opcodes.ASTORE || insn.opcode == Opcodes.ALOAD) &&
(insn as VarInsnNode).`var` == localVariableNode.index) {
(insn as VarInsnNode).`var` == localVariableNode.index
) {
if (insn.getOpcode() == Opcodes.ASTORE) {
values.add(frame.top()!!)
} else {

View File

@@ -85,7 +85,7 @@ class Equals(val operator: IElementType) : IntrinsicMethod() {
// what comparison means. The optimization does not apply to `object == primitive` as equals
// could be overridden for the object.
if ((opToken == IrStatementOrigin.EQEQ || opToken == IrStatementOrigin.EXCLEQ) &&
((AsmUtil.isIntOrLongPrimitive(leftType) && !AsmUtil.isPrimitive(rightType)) ||
((AsmUtil.isIntOrLongPrimitive(leftType) && !isPrimitive(rightType)) ||
(AsmUtil.isIntOrLongPrimitive(rightType) && AsmUtil.isBoxedPrimitiveType(leftType)))
) {
val aValue = a.accept(codegen, data).materializedAt(leftType, a.type)

View File

@@ -27,31 +27,38 @@ import org.jetbrains.kotlin.ir.util.render
import org.jetbrains.org.objectweb.asm.Opcodes
import org.jetbrains.org.objectweb.asm.Type
// TODO Implement hashCode on primitive types as a lowering.
object HashCode : IntrinsicMethod() {
override fun invoke(expression: IrFunctionAccessExpression, codegen: ExpressionCodegen, data: BlockInfo) = with(codegen) {
val receiver = expression.dispatchReceiver ?: error("No receiver for hashCode: ${expression.render()}")
val result = receiver.accept(this, data).materialized()
val receiverIrType = receiver.type
val receiverJvmType = typeMapper.mapType(receiverIrType)
val receiverValue = receiver.accept(this, data).materialized()
val receiverType = receiverValue.type
val target = context.state.target
when {
irFunction.origin == JvmLoweredDeclarationOrigin.INLINE_CLASS_GENERATED_IMPL_METHOD ||
irFunction.origin == IrDeclarationOrigin.GENERATED_DATA_CLASS_MEMBER ->
DescriptorAsmUtil.genHashCode(mv, mv, result.type, target)
target == JvmTarget.JVM_1_6 || !AsmUtil.isPrimitive(result.type) -> {
result.materializeAtBoxed(receiver.type)
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "hashCode", "()I", false)
irFunction.origin == IrDeclarationOrigin.GENERATED_DATA_CLASS_MEMBER -> {
// TODO generate or lower IR for data class / inline class 'hashCode'?
DescriptorAsmUtil.genHashCode(mv, mv, receiverType, target)
}
else -> {
val boxedType = AsmUtil.boxType(result.type)
target >= JvmTarget.JVM_1_8 && AsmUtil.isPrimitive(receiverJvmType) -> {
val boxedType = AsmUtil.boxPrimitiveType(receiverJvmType)
?: throw AssertionError("Primitive type expected: $receiverJvmType")
receiverValue.materializeAt(receiverJvmType, receiverIrType)
mv.visitMethodInsn(
Opcodes.INVOKESTATIC,
boxedType.internalName,
"hashCode",
Type.getMethodDescriptor(Type.INT_TYPE, result.type),
Type.getMethodDescriptor(Type.INT_TYPE, receiverJvmType),
false
)
}
else -> {
receiverValue.materializeAtBoxed(receiverIrType)
mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "hashCode", "()I", false)
}
}
MaterialValue(codegen, Type.INT_TYPE, codegen.context.irBuiltIns.intType)
}
}

View File

@@ -13,18 +13,6 @@ import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.org.objectweb.asm.Label
import org.jetbrains.org.objectweb.asm.Type
private fun ExpressionCodegen.checkTopValueForNull() {
mv.dup()
if (state.unifiedNullChecks) {
mv.invokestatic(IntrinsicMethods.INTRINSICS_CLASS_NAME, "checkNotNull", "(Ljava/lang/Object;)V", false)
} else {
val ifNonNullLabel = Label()
mv.ifnonnull(ifNonNullLabel)
mv.invokestatic(IntrinsicMethods.INTRINSICS_CLASS_NAME, "throwNpe", "()V", false)
mv.mark(ifNonNullLabel)
}
}
object IrCheckNotNull : IntrinsicMethod() {
override fun invoke(expression: IrFunctionAccessExpression, codegen: ExpressionCodegen, data: BlockInfo): PromisedValue? {
val arg0 = expression.getValueArgument(0)!!.accept(codegen, data)
@@ -37,4 +25,16 @@ object IrCheckNotNull : IntrinsicMethod() {
arg0.materialized().also { codegen.checkTopValueForNull() }.discard()
}
}
private fun ExpressionCodegen.checkTopValueForNull() {
mv.dup()
if (state.unifiedNullChecks) {
mv.invokestatic(IntrinsicMethods.INTRINSICS_CLASS_NAME, "checkNotNull", "(Ljava/lang/Object;)V", false)
} else {
val ifNonNullLabel = Label()
mv.ifnonnull(ifNonNullLabel)
mv.invokestatic(IntrinsicMethods.INTRINSICS_CLASS_NAME, "throwNpe", "()V", false)
mv.mark(ifNonNullLabel)
}
}
}

View File

@@ -10,16 +10,17 @@ import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
import org.jetbrains.kotlin.backend.common.lower.irBlock
import org.jetbrains.kotlin.backend.common.phaser.makeIrFilePhase
import org.jetbrains.kotlin.backend.jvm.JvmBackendContext
import org.jetbrains.kotlin.codegen.intrinsics.Not
import org.jetbrains.kotlin.backend.jvm.ir.createJvmIrBuilder
import org.jetbrains.kotlin.codegen.AsmUtil
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.builders.irGetField
import org.jetbrains.kotlin.ir.builders.irSetField
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.IrBlockImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
import org.jetbrains.kotlin.ir.symbols.IrSymbol
import org.jetbrains.kotlin.ir.symbols.impl.IrPublicSymbolBase
import org.jetbrains.kotlin.ir.types.*
import org.jetbrains.kotlin.ir.util.*
@@ -64,6 +65,42 @@ class JvmOptimizationLowering(val context: JvmBackendContext) : FileLoweringPass
else -> null
}
private class SafeCallInfo(
val scopeSymbol: IrSymbol,
val tmpVal: IrVariable,
val ifNullBranch: IrBranch,
val ifNotNullBranch: IrBranch
)
private fun parseSafeCall(expression: IrExpression): SafeCallInfo? {
val block = expression as? IrBlock ?: return null
if (block.origin != IrStatementOrigin.SAFE_CALL) return null
if (block.statements.size != 2) return null
val tmpVal = block.statements[0] as? IrVariable ?: return null
val scopeOwner = tmpVal.parent as? IrDeclaration ?: return null
val scopeSymbol = scopeOwner.symbol
val whenExpr = block.statements[1] as? IrWhen ?: return null
if (whenExpr.branches.size != 2) return null
val ifNullBranch = whenExpr.branches[0]
val ifNullBranchCondition = ifNullBranch.condition
if (ifNullBranchCondition !is IrCall) return null
if (ifNullBranchCondition.symbol != context.irBuiltIns.eqeqSymbol) return null
val arg0 = ifNullBranchCondition.getValueArgument(0)
if (arg0 !is IrGetValue || arg0.symbol != tmpVal.symbol) return null
val arg1 = ifNullBranchCondition.getValueArgument(1)
if (arg1 !is IrConst<*> || arg1.value != null) return null
val ifNullBranchResult = ifNullBranch.result
if (ifNullBranchResult !is IrConst<*> || ifNullBranchResult.value != null) return null
val ifNotNullBranch = whenExpr.branches[1]
return SafeCallInfo(scopeSymbol, tmpVal, ifNullBranch, ifNotNullBranch)
}
private fun IrType.isJvmPrimitive(): Boolean =
// TODO get rid of type mapper (take care of '@EnhancedNullability', maybe some other stuff).
AsmUtil.isPrimitive(context.typeMapper.mapType(this))
override fun lower(irFile: IrFile) {
val transformer = object : IrElementTransformer<IrClass?> {
@@ -105,20 +142,93 @@ class JvmOptimizationLowering(val context: JvmBackendContext) : FileLoweringPass
}
getOperandsIfCallToEQEQOrEquals(expression)?.let { (left, right) ->
return when {
left.isNullConst() && right.isNullConst() ->
IrConstImpl.constTrue(expression.startOffset, expression.endOffset, context.irBuiltIns.booleanType)
if (left.isNullConst() && right.isNullConst())
return IrConstImpl.constTrue(expression.startOffset, expression.endOffset, context.irBuiltIns.booleanType)
left.isNullConst() && right is IrConst<*> || right.isNullConst() && left is IrConst<*> ->
IrConstImpl.constFalse(expression.startOffset, expression.endOffset, context.irBuiltIns.booleanType)
if (left.isNullConst() && right is IrConst<*> || right.isNullConst() && left is IrConst<*>)
return IrConstImpl.constFalse(expression.startOffset, expression.endOffset, context.irBuiltIns.booleanType)
else -> expression
val safeCallLeft = parseSafeCall(left)
if (safeCallLeft != null && right.type.isJvmPrimitive()) {
return rewriteSafeCallEqeqPrimitive(safeCallLeft, right, expression)
}
val safeCallRight = parseSafeCall(right)
if (safeCallRight != null && left.type.isJvmPrimitive()) {
return rewritePrimitiveEqeqSafeCall(left, safeCallRight, expression)
}
return expression
}
return expression
}
private fun rewriteSafeCallEqeqPrimitive(safeCall: SafeCallInfo, primitive: IrExpression, eqeqCall: IrCall): IrExpression =
context.createJvmIrBuilder(safeCall.scopeSymbol).run {
// Fuze safe call with primitive equality to avoid boxing the primitive.
// 'a?.<...> == p' becomes:
// {
// val tmp = a
// when {
// tmp == null -> false
// else -> tmp == p
// }
// }
irBlock {
+safeCall.tmpVal
+irWhen(
eqeqCall.type,
listOf(
irBranch(safeCall.ifNullBranch.condition, irFalse()),
irElseBranch(
irCall(eqeqCall.symbol).apply {
putValueArgument(0, safeCall.ifNotNullBranch.result)
putValueArgument(1, primitive)
}
)
)
)
}
}
private fun rewritePrimitiveEqeqSafeCall(primitive: IrExpression, safeCall: SafeCallInfo, eqeqCall: IrCall): IrExpression =
context.createJvmIrBuilder(safeCall.scopeSymbol).run {
// Fuze safe call with primitive equality to avoid boxing the primitive.
// 'p == a?.<...>' becomes:
// {
// val tmp_p = p // should evaluate 'p' before 'a'
// val tmp = a
// when {
// tmp == null -> false
// else -> tmp_p == tmp
// }
// }
// 'tmp_p' above could be elided if 'p' is a variable or a constant.
irBlock {
val lhs =
if (primitive.isTrivial())
primitive
else {
val tmp = irTemporary(primitive)
irGet(tmp)
}
+safeCall.tmpVal
+irWhen(
eqeqCall.type,
listOf(
irBranch(safeCall.ifNullBranch.condition, irFalse()),
irElseBranch(
irCall(eqeqCall.symbol).apply {
putValueArgument(0, lhs)
putValueArgument(1, safeCall.ifNotNullBranch.result)
}
)
)
)
}
}
private fun IrType.isByteOrShort() = isByte() || isShort()
// For `==` and `!=`, get rid of safe calls to convert `Byte?` or `Short?` to `Int?`.

View File

@@ -1,11 +1,58 @@
// IGNORE_BACKEND: JVM_IR
// IGNORE_BACKEND_FIR: JVM_IR
fun foo() {
val x: Int? = 6
val hc = x!!.hashCode()
fun testBoolean(): Int {
val b: Boolean? = true
return b!!.hashCode()
}
fun testByte(): Int {
val b: Byte? = 1.toByte()
return b!!.hashCode()
}
fun testChar(): Int {
val c: Char? = 'x'
return c!!.hashCode()
}
fun testShort(): Int {
val s: Short? = 1.toShort()
return s!!.hashCode()
}
fun testInt(): Int {
val i: Int? = 42
return i!!.hashCode()
}
fun testLong(): Int {
val l: Long? = 42L
return l!!.hashCode()
}
fun testFloat(): Int {
val f: Float? = 0.0f
return f!!.hashCode()
}
fun testDouble(): Int {
val d: Double? = 0.0
return d!!.hashCode()
}
// 1 java/lang/Boolean.hashCode \(Z\)I
// 1 java/lang/Character.hashCode \(C\)I
// 1 java/lang/Byte.hashCode \(B\)I
// 1 java/lang/Short.hashCode \(S\)I
// 1 java/lang/Integer.hashCode \(I\)I
// 0 java/lang/Integer.valueOf
// 1 java/lang/Long.hashCode \(J\)I
// 1 java/lang/Float.hashCode \(F\)I
// 1 java/lang/Double.hashCode \(D\)I
// 0 valueOf
// 0 byteValue
// 0 shortValue
// 0 intValue
// 0 longValue
// 0 floatValue
// 0 doubleValue
// 0 charValue

View File

@@ -1,7 +1,3 @@
// IGNORE_BACKEND_FIR: JVM_IR
// IGNORE_BACKEND: JVM_IR
// TODO KT-36646 Don't box primitive values in equality comparison with nullable primitive values in JVM_IR
fun Long.id() = this
fun String.drop2() = if (length >= 2) subSequence(2, length) else null

View File

@@ -1,7 +1,3 @@
// IGNORE_BACKEND_FIR: JVM_IR
// IGNORE_BACKEND: JVM_IR
// TODO KT-36637 Trivial closure optimizatin in JVM_IR
fun test() {
fun local(){
@@ -21,6 +17,12 @@ fun test() {
(::local)()
}
// JVM_TEMPLATES
// 3 GETSTATIC ConstClosureOptimizationKt\$test\$1\.INSTANCE
// 1 GETSTATIC ConstClosureOptimizationKt\$test\$2\.INSTANCE
// 1 GETSTATIC ConstClosureOptimizationKt\$test\$3\.INSTANCE
// JVM_IR_TEMPLATES
// 1 GETSTATIC ConstClosureOptimizationKt\$test\$1.INSTANCE
// 1 GETSTATIC ConstClosureOptimizationKt\$test\$2.INSTANCE
// 1 GETSTATIC ConstClosureOptimizationKt\$test\$local\$1.INSTANCE