From 07dcc6c6166d5f2c663b199abb99dc2bec87ca35 Mon Sep 17 00:00:00 2001 From: Denis Zharkov Date: Thu, 9 Jun 2016 20:56:03 +0300 Subject: [PATCH] Support 'handleException' operator in JVM backend --- .../kotlin/codegen/ExpressionCodegen.java | 4 +- .../codegen/coroutines/CoroutineCodegen.kt | 23 ++++++ .../CoroutineTransformationClassBuilder.kt | 41 +++++++++- .../coroutines/coroutineCodegenUtil.kt | 45 ++++++++++- .../codegen/box/coroutines/handleException.kt | 75 +++++++++++++++++++ .../codegen/java8/box/asyncException.kt | 31 +++++--- .../codegen/BlackBoxCodegenTestGenerated.java | 6 ++ 7 files changed, 207 insertions(+), 18 deletions(-) create mode 100644 compiler/testData/codegen/box/coroutines/handleException.kt diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/ExpressionCodegen.java b/compiler/backend/src/org/jetbrains/kotlin/codegen/ExpressionCodegen.java index 6635071b166..73101de50dd 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/ExpressionCodegen.java +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/ExpressionCodegen.java @@ -1763,7 +1763,7 @@ public class ExpressionCodegen extends KtVisitor impleme } @NotNull - private StackValue genCoroutineInstanceValueFromResolvedCall(ResolvedCall resolvedCall) { + public StackValue genCoroutineInstanceValueFromResolvedCall(ResolvedCall resolvedCall) { // Currently only handleResult/suspend members are supported ReceiverValue dispatchReceiver = resolvedCall.getDispatchReceiver(); assert dispatchReceiver != null : "Dispatch receiver is null for handleResult/suspend to " + resolvedCall.getResultingDescriptor(); @@ -2686,7 +2686,7 @@ public class ExpressionCodegen extends KtVisitor impleme if (isSuspensionPoint) { v.invokestatic( - CoroutineCodegenUtilKt.SUSPENSION_POINT_MARKER_OWNER, + CoroutineCodegenUtilKt.COROUTINE_MARKER_OWNER, CoroutineCodegenUtilKt.SUSPENSION_POINT_MARKER_NAME, "()V", false); } diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegen.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegen.kt index b6a49104ddd..94e399de7fd 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegen.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineCodegen.kt @@ -22,6 +22,7 @@ import org.jetbrains.kotlin.codegen.context.ClosureContext import org.jetbrains.kotlin.codegen.state.GenerationState import org.jetbrains.kotlin.coroutines.controllerTypeIfCoroutine import org.jetbrains.kotlin.descriptors.* +import org.jetbrains.kotlin.incremental.KotlinLookupLocation import org.jetbrains.kotlin.incremental.components.NoLookupLocation import org.jetbrains.kotlin.name.Name import org.jetbrains.kotlin.psi.KtDeclarationWithBody @@ -35,6 +36,7 @@ import org.jetbrains.kotlin.resolve.jvm.diagnostics.JvmDeclarationOrigin import org.jetbrains.kotlin.resolve.jvm.diagnostics.OtherOrigin import org.jetbrains.kotlin.resolve.jvm.jvmSignature.JvmMethodSignature import org.jetbrains.kotlin.types.KotlinType +import org.jetbrains.kotlin.util.OperatorNameConventions import org.jetbrains.org.objectweb.asm.Opcodes import org.jetbrains.org.objectweb.asm.Type import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter @@ -118,10 +120,31 @@ class CoroutineCodegen( override fun doGenerateBody(codegen: ExpressionCodegen, signature: JvmMethodSignature) { codegen.v.visitAnnotation(CONTINUATION_METHOD_ANNOTATION_DESC, true).visitEnd() super.doGenerateBody(codegen, signature) + generateExceptionHandlingBlock(codegen) } }) } + private fun generateExceptionHandlingBlock(codegen: ExpressionCodegen) { + val handleExceptionFunction = + controllerType.memberScope.getContributedFunctions( + OperatorNameConventions.COROUTINE_HANDLE_EXCEPTION, KotlinLookupLocation(element)).singleOrNull { it.isOperator } + ?: return + + val (resolvedCall, fakeExceptionExpression, fakeThisContinuationException) = + createResolvedCallForHandleExceptionCall(element, handleExceptionFunction, (context as ClosureContext).coroutineDescriptor!!) + + codegen.tempVariables.put(fakeExceptionExpression, StackValue.operation(AsmTypes.OBJECT_TYPE) { + codegen.v.invokestatic(COROUTINE_MARKER_OWNER, HANDLE_EXCEPTION_ARGUMENT_MARKER_NAME, "()Ljava/lang/Object;", false) + }) + + codegen.tempVariables.put(fakeThisContinuationException, codegen.genCoroutineInstanceValueFromResolvedCall(resolvedCall)) + + codegen.v.invokestatic(COROUTINE_MARKER_OWNER, HANDLE_EXCEPTION_MARKER_NAME, "()V", false) + codegen.invokeFunction(resolvedCall, StackValue.none()).put(Type.VOID_TYPE, codegen.v) + codegen.v.areturn(Type.VOID_TYPE) + } + private fun createSynthesizedImplementationByName( name: String, interfaceSupertype: KotlinType, diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformationClassBuilder.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformationClassBuilder.kt index 783b5cf9e3d..8bab4b1b40e 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformationClassBuilder.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/CoroutineTransformationClassBuilder.kt @@ -67,12 +67,16 @@ class CoroutineTransformerMethodVisitor( methodNode.visibleAnnotations.removeAll { it.desc == CONTINUATION_METHOD_ANNOTATION_DESC } val suspensionPoints = collectSuspensionPoints(methodNode) - if (suspensionPoints.isEmpty()) return for (suspensionPoint in suspensionPoints) { splitTryCatchBlocksContainingSuspensionPoint(methodNode, suspensionPoint) } + // Add global exception handler + processHandleExceptionCall(methodNode) + + if (suspensionPoints.isEmpty()) return + // Spill stack to variables before suspension points, try/catch blocks FixStackWithLabelNormalizationMethodTransformer().transform("fake", methodNode) @@ -120,7 +124,7 @@ class CoroutineTransformerMethodVisitor( val suspensionPoints = mutableListOf() for (methodInsn in methodNode.instructions.asSequence().filterIsInstance()) { - if (methodInsn.owner != SUSPENSION_POINT_MARKER_OWNER) continue + if (methodInsn.owner != COROUTINE_MARKER_OWNER) continue when (methodInsn.name) { SUSPENSION_POINT_MARKER_NAME -> { @@ -130,8 +134,6 @@ class CoroutineTransformerMethodVisitor( suspensionPoints.add(SuspensionPoint(methodInsn.next as MethodInsnNode)) } - - else -> error("Unexpected suspension point marker kind '${methodInsn.name}'") } } @@ -317,6 +319,37 @@ class CoroutineTransformerMethodVisitor( return } + + private fun processHandleExceptionCall(methodNode: MethodNode) { + val instructions = methodNode.instructions + val marker = instructions.toArray().firstOrNull() { it.isHandleExceptionMarker() } ?: return + + assert(instructions.toArray().count { it.isHandleExceptionMarker() } == 1) { + "Found more than one handleException markers" + } + + val startLabel = LabelNode() + val endLabel = LabelNode() + instructions.insertBefore(instructions.first, startLabel) + instructions.set(marker, endLabel) + + // NOP is necessary to preserve common invariant: first insn of TCB is always NOP + instructions.insert(startLabel, InsnNode(Opcodes.NOP)) + // ASTORE is needed by the same reason + val maxVar = methodNode.maxLocals++ + instructions.insert(endLabel, VarInsnNode(Opcodes.ASTORE, maxVar)) + + val exceptionArgument = instructions.toArray().single { it.isHandleExceptionMarkerArgument() } + instructions.set(exceptionArgument, VarInsnNode(Opcodes.ALOAD, maxVar)) + + methodNode.tryCatchBlocks.add(TryCatchBlockNode(startLabel, endLabel, endLabel, AsmTypes.JAVA_THROWABLE_TYPE.internalName)) + } + + private fun AbstractInsnNode.isHandleExceptionMarker() = + this is MethodInsnNode && this.owner == COROUTINE_MARKER_OWNER && this.name == HANDLE_EXCEPTION_MARKER_NAME + + private fun AbstractInsnNode.isHandleExceptionMarkerArgument() = + this is MethodInsnNode && this.owner == COROUTINE_MARKER_OWNER && this.name == HANDLE_EXCEPTION_ARGUMENT_MARKER_NAME } private fun Type.fieldNameForVar(index: Int) = descriptor.first() + "$" + index diff --git a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/coroutineCodegenUtil.kt b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/coroutineCodegenUtil.kt index f7ed1e5874e..d7ff7752cfd 100644 --- a/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/coroutineCodegenUtil.kt +++ b/compiler/backend/src/org/jetbrains/kotlin/codegen/coroutines/coroutineCodegenUtil.kt @@ -18,6 +18,8 @@ package org.jetbrains.kotlin.codegen.coroutines import com.intellij.openapi.project.Project import org.jetbrains.kotlin.descriptors.FunctionDescriptor +import org.jetbrains.kotlin.descriptors.SimpleFunctionDescriptor +import org.jetbrains.kotlin.psi.KtElement import org.jetbrains.kotlin.psi.KtExpression import org.jetbrains.kotlin.psi.KtPsiFactory import org.jetbrains.kotlin.resolve.BindingTraceContext @@ -27,16 +29,21 @@ import org.jetbrains.kotlin.resolve.calls.model.MutableDataFlowInfoForArguments import org.jetbrains.kotlin.resolve.calls.model.ResolvedCall import org.jetbrains.kotlin.resolve.calls.model.ResolvedCallImpl import org.jetbrains.kotlin.resolve.calls.smartcasts.DataFlowInfo +import org.jetbrains.kotlin.resolve.calls.tasks.ExplicitReceiverKind import org.jetbrains.kotlin.resolve.calls.tasks.TracingStrategy +import org.jetbrains.kotlin.resolve.calls.util.CallMaker import org.jetbrains.kotlin.resolve.coroutine.SUSPENSION_POINT_KEY import org.jetbrains.kotlin.types.TypeConstructorSubstitution +import org.jetbrains.kotlin.types.TypeSubstitutor import org.jetbrains.kotlin.types.typeUtil.asTypeProjection // These classes do not actually exist at runtime val CONTINUATION_METHOD_ANNOTATION_DESC = "Lkotlin/ContinuationMethod;" -const val SUSPENSION_POINT_MARKER_OWNER = "kotlin/Markers" +const val COROUTINE_MARKER_OWNER = "kotlin/coroutines/Markers" const val SUSPENSION_POINT_MARKER_NAME = "suspensionPoint" +const val HANDLE_EXCEPTION_MARKER_NAME = "handleException" +const val HANDLE_EXCEPTION_ARGUMENT_MARKER_NAME = "handleExceptionArgument" const val COROUTINE_CONTROLLER_FIELD_NAME = "controller" const val COROUTINE_LABEL_FIELD_NAME = "label" @@ -87,6 +94,42 @@ fun ResolvedCall<*>.replaceSuspensionFunctionViewWithRealDescriptor( return ResolvedCallWithRealDescriptor(newCall, thisExpression) } +data class HandleResultCallContext( + val resolvedCall: ResolvedCall<*>, + val exceptionExpression: KtExpression, + val continuationThisExpression: KtExpression +) + +fun createResolvedCallForHandleExceptionCall( + callElement: KtElement, + handleExceptionFunction: SimpleFunctionDescriptor, + coroutineLambdaDescriptor: FunctionDescriptor +): HandleResultCallContext { + val psiFactory = KtPsiFactory(callElement) + + val exceptionArgument = CallMaker.makeValueArgument(psiFactory.createExpression("exception")) + val continuationThisArgument = CallMaker.makeValueArgument(psiFactory.createExpression("this")) + + val valueArguments = listOf(exceptionArgument, continuationThisArgument) + val call = CallMaker.makeCall(callElement, null, null, null, valueArguments) + + val resolvedCall = ResolvedCallImpl( + call, + handleExceptionFunction, + coroutineLambdaDescriptor.extensionReceiverParameter!!.value, null, ExplicitReceiverKind.NO_EXPLICIT_RECEIVER, + null, DelegatingBindingTrace(BindingTraceContext().bindingContext, "Temporary trace for handleException resolution"), + TracingStrategy.EMPTY, MutableDataFlowInfoForArguments.WithoutArgumentsCheck(DataFlowInfo.EMPTY)) + + handleExceptionFunction.valueParameters.zip(valueArguments).forEach { + resolvedCall.recordValueArgument(it.first, ExpressionValueArgument(it.second)) + } + + resolvedCall.setResultingSubstitutor(TypeSubstitutor.EMPTY) + + return HandleResultCallContext( + resolvedCall, exceptionArgument.getArgumentExpression()!!, continuationThisArgument.getArgumentExpression()!!) +} + fun ResolvedCall<*>.isSuspensionPoint() = (candidateDescriptor as? FunctionDescriptor)?.let { it.isSuspend && it.getUserData(SUSPENSION_POINT_KEY) ?: false } ?: false diff --git a/compiler/testData/codegen/box/coroutines/handleException.kt b/compiler/testData/codegen/box/coroutines/handleException.kt new file mode 100644 index 00000000000..c1e5140c003 --- /dev/null +++ b/compiler/testData/codegen/box/coroutines/handleException.kt @@ -0,0 +1,75 @@ +// WITH_RUNTIME +class Controller { + var exception: Throwable? = null + val postponedActions = java.util.ArrayList<() -> Unit>() + + suspend fun suspendWithValue(v: String, x: Continuation) { + postponedActions.add { + x.resume(v) + } + } + + suspend fun suspendWithException(e: Exception, x: Continuation) { + postponedActions.add { + x.resumeWithException(e) + } + } + + operator fun handleException(t: Throwable, c: Continuation) { + exception = t + } + + fun run(c: Controller.() -> Continuation) { + c(this).resume(Unit) + while (postponedActions.isNotEmpty()) { + postponedActions[0]() + postponedActions.removeAt(0) + } + } +} + +fun builder(coroutine c: Controller.() -> Continuation) { + val controller = Controller() + controller.run(c) + + if (controller.exception?.message != "OK") { + throw RuntimeException("Unexpected result: ${controller.exception?.message}") + } +} + +fun commonThrow(t: Throwable) { + throw t +} + +fun box(): String { + builder { + throw RuntimeException("OK") + } + + builder { + commonThrow(RuntimeException("OK")) + } + + builder { + suspendWithException(RuntimeException("OK")) + } + + builder { + try { + suspendWithException(RuntimeException("fail 1")) + } catch (e: RuntimeException) { + suspendWithException(RuntimeException("OK")) + } + } + + builder { + try { + suspendWithException(Exception("OK")) + } catch (e: RuntimeException) { + suspendWithException(RuntimeException("fail 3")) + throw RuntimeException("fail 4") + } + } + + return "OK" +} diff --git a/compiler/testData/codegen/java8/box/asyncException.kt b/compiler/testData/codegen/java8/box/asyncException.kt index f11773ba909..2083d7f9397 100644 --- a/compiler/testData/codegen/java8/box/asyncException.kt +++ b/compiler/testData/codegen/java8/box/asyncException.kt @@ -25,7 +25,21 @@ fun box(): String { java.lang.Thread.sleep(1000) - return result + if (result != "OK") return "fail notOk" + + val future2 = async() { + await(exception("OK")) + "fail" + } + + try { + future2.get() + } catch (e: Exception) { + if (e.cause!!.message != "OK") return "fail message: ${e.cause!!.message}" + return "OK" + } + + return "No exception" } fun async(coroutine c: FutureController.() -> Continuation): CompletableFuture { @@ -37,17 +51,12 @@ fun async(coroutine c: FutureController.() -> Continuation): Comple class FutureController { val future = CompletableFuture() - suspend fun await(f: CompletableFuture, machine: Continuation) { f.whenComplete { value, throwable -> - try { - if (throwable == null) - machine.resume(value) - else - machine.resumeWithException(throwable) - } catch (e: Exception) { - future.completeExceptionally(e) - } + if (throwable == null) + machine.resume(value) + else + machine.resumeWithException(throwable) } } @@ -55,7 +64,7 @@ class FutureController { future.complete(value) } - fun handleException(t: Throwable, c: Continuation) { + operator fun handleException(t: Throwable, c: Continuation) { future.completeExceptionally(t) } } diff --git a/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java b/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java index 82dad63426c..0edadca2339 100644 --- a/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java +++ b/compiler/tests/org/jetbrains/kotlin/codegen/BlackBoxCodegenTestGenerated.java @@ -4147,6 +4147,12 @@ public class BlackBoxCodegenTestGenerated extends AbstractBlackBoxCodegenTest { doTest(fileName); } + @TestMetadata("handleException.kt") + public void testHandleException() throws Exception { + String fileName = KotlinTestUtils.navigationMetadata("compiler/testData/codegen/box/coroutines/handleException.kt"); + doTest(fileName); + } + @TestMetadata("illegalState.kt") public void testIllegalState() throws Exception { String fileName = KotlinTestUtils.navigationMetadata("compiler/testData/codegen/box/coroutines/illegalState.kt");