diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/inline/InlineCodegen.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/inline/InlineCodegen.kt index d1a551a1428..fb5e9e38e55 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/inline/InlineCodegen.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/inline/InlineCodegen.kt @@ -20,7 +20,8 @@ import com.intellij.psi.PsiElement import com.intellij.util.ArrayUtil import org.jetbrains.kotlin.builtins.BuiltInsPackageFragment import org.jetbrains.kotlin.codegen.* -import org.jetbrains.kotlin.codegen.AsmUtil.* +import org.jetbrains.kotlin.codegen.AsmUtil.getMethodAsmFlags +import org.jetbrains.kotlin.codegen.AsmUtil.isPrimitive import org.jetbrains.kotlin.codegen.coroutines.createMethodNodeForSuspendCoroutineOrReturn import org.jetbrains.kotlin.codegen.coroutines.isBuiltInSuspendCoroutineOrReturnInJvm import org.jetbrains.kotlin.codegen.intrinsics.bytecode @@ -29,7 +30,9 @@ import org.jetbrains.kotlin.codegen.state.GenerationState import org.jetbrains.kotlin.descriptors.* import org.jetbrains.kotlin.descriptors.annotations.isInlineOnly import org.jetbrains.kotlin.name.Name -import org.jetbrains.kotlin.psi.* +import org.jetbrains.kotlin.psi.KtCallableReferenceExpression +import org.jetbrains.kotlin.psi.KtExpression +import org.jetbrains.kotlin.psi.KtPsiUtil import org.jetbrains.kotlin.renderer.DescriptorRenderer import org.jetbrains.kotlin.resolve.BindingContext import org.jetbrains.kotlin.resolve.DescriptorToSourceUtils @@ -47,13 +50,15 @@ import org.jetbrains.kotlin.serialization.deserialization.descriptors.Deserializ import org.jetbrains.kotlin.types.expressions.DoubleColonLHS import org.jetbrains.kotlin.types.expressions.ExpressionTypingUtils.isFunctionLiteral import org.jetbrains.kotlin.types.expressions.LabelResolver +import org.jetbrains.kotlin.utils.DFS +import org.jetbrains.kotlin.utils.addIfNotNull import org.jetbrains.org.objectweb.asm.Opcodes import org.jetbrains.org.objectweb.asm.Type import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter import org.jetbrains.org.objectweb.asm.commons.Method -import org.jetbrains.org.objectweb.asm.tree.AbstractInsnNode -import org.jetbrains.org.objectweb.asm.tree.MethodNode +import org.jetbrains.org.objectweb.asm.tree.* import java.util.* +import kotlin.collections.HashSet abstract class InlineCodegen( protected val codegen: T, @@ -179,6 +184,46 @@ abstract class InlineCodegen( } } + private fun canSkipStackSpillingOnInline(methodNode: MethodNode): Boolean { + // Stack spilling before inline function 'f' call is required if: + // - 'f' is a suspend function + // - 'f' has try-catch blocks + // - 'f' has loops + // + // Instead of checking for loops precisely, we just check if there are any backward jumps - + // that is, a jump from instruction #i to instruction #j where j < i + + if (functionDescriptor.isSuspend) return false + if (methodNode.tryCatchBlocks.isNotEmpty()) return false + + fun isBackwardJump(fromIndex: Int, toLabel: LabelNode) = + methodNode.instructions.indexOf(toLabel) < fromIndex + + val insns = methodNode.instructions.toArray() + for (i in insns.indices) { + val insn = insns[i] + when (insn) { + is JumpInsnNode -> + if (isBackwardJump(i, insn.label)) return false + + is LookupSwitchInsnNode -> { + insn.dflt?.let { + if (isBackwardJump(i, it)) return false + } + if (insn.labels.any { isBackwardJump(i, it) }) return false + } + + is TableSwitchInsnNode -> { + insn.dflt?.let { + if (isBackwardJump(i, it)) return false + } + if (insn.labels.any { isBackwardJump(i, it) }) return false + } + } + } + + return true + } protected fun inlineCall(nodeAndSmap: SMAPAndMethodNode, callDefault: Boolean): InlineResult { assert(delayedHiddenWriting == null) { "'putHiddenParamsIntoLocals' should be called after 'processAndPutHiddenParameters(true)'" } @@ -201,7 +246,11 @@ abstract class InlineCodegen( //through generation captured parameters will be added to invocationParamBuilder putClosureParametersOnStack() - addInlineMarker(codegen.v, true) + val shouldSpillStack = !canSkipStackSpillingOnInline(node) + + if (shouldSpillStack) { + addInlineMarker(codegen.v, true) + } val parameters = invocationParamBuilder.buildParameters() @@ -239,7 +288,9 @@ abstract class InlineCodegen( adapter.accept(MethodBodyVisitor(codegen.v)) - addInlineMarker(codegen.v, false) + if (shouldSpillStack) { + addInlineMarker(codegen.v, false) + } defaultSourceMapper.callSiteMarker = null diff --git a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/differentTypes.kt b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/differentTypes.kt index 2952171d04a..1682539ff9f 100644 --- a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/differentTypes.kt +++ b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/differentTypes.kt @@ -1,4 +1,9 @@ -inline fun bar(x: Int, y: Long, z: Byte, s: String) = x.toString() + y.toString() + z.toString() + s +inline fun runAfterLoop(fn: () -> T): T { + for (i in 1..2); + return fn() +} + +inline fun bar(x: Int, y: Long, z: Byte, s: String) = runAfterLoop { x.toString() + y.toString() + z.toString() + s } fun foobar(x: Int, y: Long, s: String, z: Byte) = x.toString() + y.toString() + s + z.toString() @@ -6,10 +11,10 @@ fun foo() : String { return foobar(1, 2L, bar(3, 4L, 5.toByte(), "6"), 7.toByte()) } -// 3 ISTORE -// 7 ILOAD +// 9 ISTORE +// 13 ILOAD // 2 ASTORE -// 6 ALOAD +// 8 ALOAD // 2 LSTORE // 4 LLOAD // 1 MAXLOCALS = 10 diff --git a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/primitiveMerge.kt b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/primitiveMerge.kt index 45cdb304666..ce12901171a 100644 --- a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/primitiveMerge.kt +++ b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/primitiveMerge.kt @@ -1,17 +1,22 @@ +inline fun runAfterLoop(fn: () -> T): T { + for (i in 1..2); + return fn() +} + fun bar() : Boolean = true fun foobar(x: Boolean, y: String, z: String) = x.toString() + y + z -inline fun foo() = "-" +inline fun foo() = runAfterLoop { "-" } fun test() { val result = foobar(if (1 == 1) true else bar(), foo(), "OK") } -// 1 ISTORE -// 2 ILOAD +// 7 ISTORE +// 8 ILOAD // 2 ASTORE -// 5 ALOAD +// 7 ALOAD // 1 MAXLOCALS = 3 // 1 MAXLOCALS = 4 // 0 InlineMarker diff --git a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/simple.kt b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/simple.kt index 0c2217e729a..2d1b0ebe189 100644 --- a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/simple.kt +++ b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/simple.kt @@ -9,6 +9,6 @@ fun foo() : Int { return foobar(1, bar(2), 3) } -// 3 ISTORE -// 7 ILOAD +// 1 ISTORE +// 5 ILOAD // 0 InlineMarker diff --git a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/unreachableMarker.kt b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/unreachableMarker.kt index 4cace50d983..ece1e46f660 100644 --- a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/unreachableMarker.kt +++ b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/unreachableMarker.kt @@ -1,5 +1,10 @@ +inline fun runAfterLoop(fn: () -> T): T { + for (i in 1..2); + return fn() +} + inline fun bar(block: () -> String) : String { - return block() + return runAfterLoop(block) } inline fun bar2() : String { @@ -16,6 +21,6 @@ fun foo() : String { ) } -// 10 ALOAD +// 12 ALOAD // 2 ASTORE // 0 InlineMarker diff --git a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/withLambda.kt b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/withLambda.kt index f960dd7af6b..5568b8ae2a8 100644 --- a/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/withLambda.kt +++ b/compiler/testData/codegen/bytecodeText/storeStackBeforeInline/withLambda.kt @@ -1,4 +1,10 @@ -inline fun bar(x: String, block: (String) -> String) = "def" + block(x) +inline fun runAfterLoop(fn: () -> T): T { + for (i in 1..2); + return fn() +} + +inline fun bar(x: String, block: (String) -> String) = runAfterLoop { "def" + block(x) } + fun foobar(x: String, y: String, z: String) = x + y + z fun foo() : String { @@ -6,6 +12,6 @@ fun foo() : String { } // 6 ASTORE -// 16 ALOAD +// 18 ALOAD // 1 MAXLOCALS = 7 // 0 InlineMarker