FIR: Support safe-calls new format in FIR2IR

^KT-38444 In Progress
This commit is contained in:
Denis Zharkov
2020-06-01 15:27:50 +03:00
parent 7ba1371466
commit b0b7cf4042
5 changed files with 115 additions and 90 deletions

View File

@@ -34,7 +34,13 @@ import org.jetbrains.kotlin.ir.descriptors.IrBuiltIns
import org.jetbrains.kotlin.ir.descriptors.WrappedReceiverParameterDescriptor
import org.jetbrains.kotlin.ir.expressions.IrConst
import org.jetbrains.kotlin.ir.expressions.IrConstKind
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
import org.jetbrains.kotlin.ir.expressions.impl.*
import org.jetbrains.kotlin.ir.symbols.IrClassifierSymbol
import org.jetbrains.kotlin.ir.symbols.IrSymbol
import org.jetbrains.kotlin.ir.symbols.IrTypeParameterSymbol
import org.jetbrains.kotlin.ir.symbols.IrValueSymbol
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
import org.jetbrains.kotlin.ir.symbols.*
import org.jetbrains.kotlin.ir.symbols.impl.IrClassPublicSymbolImpl
@@ -480,3 +486,45 @@ val IrType.isSamType: Boolean
val am = irClass.functions.singleOrNull { it.owner.modality == Modality.ABSTRACT }
return am != null
}
fun Fir2IrComponents.createSafeCallConstruction(
receiverVariable: IrVariable,
receiverVariableSymbol: IrValueSymbol,
expressionOnNotNull: IrExpression,
isReceiverNullable: Boolean
): IrExpression {
val startOffset = expressionOnNotNull.startOffset
val endOffset = expressionOnNotNull.endOffset
val resultType = expressionOnNotNull.type.let { if (isReceiverNullable) it.makeNullable() else it }
return IrBlockImpl(startOffset, endOffset, resultType, IrStatementOrigin.SAFE_CALL).apply {
statements += receiverVariable
statements += IrWhenImpl(startOffset, endOffset, resultType).apply {
val condition = IrCallImpl(
startOffset, endOffset, irBuiltIns.booleanType, irBuiltIns.eqeqSymbol, origin = IrStatementOrigin.EQEQ
).apply {
putValueArgument(0, IrGetValueImpl(startOffset, endOffset, receiverVariableSymbol))
putValueArgument(1, IrConstImpl.constNull(startOffset, endOffset, irBuiltIns.nothingNType))
}
branches += IrBranchImpl(
condition, IrConstImpl.constNull(startOffset, endOffset, irBuiltIns.nothingNType)
)
branches += IrElseBranchImpl(
IrConstImpl.boolean(startOffset, endOffset, irBuiltIns.booleanType, true),
expressionOnNotNull
)
}
}
}
fun Fir2IrComponents.createTemporaryVariableForSafeCallConstruction(
receiverExpression: IrExpression,
conversionScope: Fir2IrConversionScope
): Pair<IrVariable, IrValueSymbol> {
val receiverVariable = declarationStorage.declareTemporaryVariable(receiverExpression, "safe_receiver").apply {
parent = conversionScope.parentFromStack()
}
val variableSymbol = symbolTable.referenceValue(receiverVariable.descriptor)
return Pair(receiverVariable, variableSymbol)
}

View File

@@ -65,12 +65,20 @@ class Fir2IrConversionScope {
return klass
}
private val subjectVariableStack = mutableListOf<IrVariable>()
private val whenSubjectVariableStack = mutableListOf<IrVariable>()
private val safeCallSubjectVariableStack = mutableListOf<IrVariable>()
fun <T> withSubject(subject: IrVariable?, f: () -> T): T {
if (subject != null) subjectVariableStack += subject
fun <T> withWhenSubject(subject: IrVariable?, f: () -> T): T {
if (subject != null) whenSubjectVariableStack += subject
val result = f()
if (subject != null) subjectVariableStack.removeAt(subjectVariableStack.size - 1)
if (subject != null) whenSubjectVariableStack.removeAt(whenSubjectVariableStack.size - 1)
return result
}
fun <T> withSafeCallSubject(subject: IrVariable?, f: () -> T): T {
if (subject != null) safeCallSubjectVariableStack += subject
val result = f()
if (subject != null) safeCallSubjectVariableStack.removeAt(safeCallSubjectVariableStack.size - 1)
return result
}
@@ -98,5 +106,6 @@ class Fir2IrConversionScope {
fun lastClass(): IrClass? = classStack.lastOrNull()
fun lastSubject(): IrVariable = subjectVariableStack.last()
}
fun lastWhenSubject(): IrVariable = whenSubjectVariableStack.last()
fun lastSafeCallSubject(): IrVariable = safeCallSubjectVariableStack.last()
}

View File

@@ -6,7 +6,6 @@
package org.jetbrains.kotlin.fir.backend
import org.jetbrains.kotlin.KtNodeTypes
import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.descriptors.Visibilities
import org.jetbrains.kotlin.fir.*
@@ -67,7 +66,7 @@ class Fir2IrVisitor(
private val memberGenerator = ClassMemberGenerator(components, this, conversionScope, callGenerator, fakeOverrideMode)
private val operatorGenerator = OperatorExpressionGenerator(components, this, callGenerator)
private val operatorGenerator = OperatorExpressionGenerator(components, this, callGenerator, conversionScope)
private fun FirTypeRef.toIrType(): IrType = with(typeConverter) { toIrType() }
@@ -287,6 +286,34 @@ class Fir2IrVisitor(
return callGenerator.convertToIrCall(convertibleCall, convertibleCall.typeRef, explicitReceiverExpression)
}
override fun visitSafeCallExpression(safeCallExpression: FirSafeCallExpression, data: Any?): IrElement {
val explicitReceiverExpression = convertToIrExpression(safeCallExpression.receiver)
val (receiverVariable, variableSymbol) = components.createTemporaryVariableForSafeCallConstruction(
explicitReceiverExpression,
conversionScope
)
return conversionScope.withSafeCallSubject(receiverVariable) {
val afterNotNullCheck = safeCallExpression.regularQualifiedAccess.accept(this, data) as IrExpression
val isReceiverNullable = with(components.session.inferenceContext) {
safeCallExpression.receiver.typeRef.coneTypeSafe<ConeKotlinType>()?.isNullableType() == true
}
components.createSafeCallConstruction(
receiverVariable, variableSymbol, afterNotNullCheck, isReceiverNullable
)
}
}
override fun visitCheckedSafeCallSubject(checkedSafeCallSubject: FirCheckedSafeCallSubject, data: Any?): IrElement {
val lastSubjectVariable = conversionScope.lastSafeCallSubject()
return checkedSafeCallSubject.convertWithOffsets { startOffset, endOffset ->
IrGetValueImpl(startOffset, endOffset, lastSubjectVariable.type, lastSubjectVariable.symbol)
}
}
private fun FirFunctionCall.resolvedNamedFunctionSymbol(): FirNamedFunctionSymbol? {
val calleeReference = (calleeReference as? FirResolvedNamedReference) ?: return null
return calleeReference.resolvedSymbol as? FirNamedFunctionSymbol
@@ -552,7 +579,7 @@ class Fir2IrVisitor(
KtNodeTypes.POSTFIX_EXPRESSION -> IrStatementOrigin.EXCLEXCL
else -> null
}
return conversionScope.withSubject(subjectVariable) {
return conversionScope.withWhenSubject(subjectVariable) {
whenExpression.convertWithOffsets { startOffset, endOffset ->
val irWhen = IrWhenImpl(
startOffset, endOffset,
@@ -601,7 +628,7 @@ class Fir2IrVisitor(
}
override fun visitWhenSubjectExpression(whenSubjectExpression: FirWhenSubjectExpression, data: Any?): IrElement {
val lastSubjectVariable = conversionScope.lastSubject()
val lastSubjectVariable = conversionScope.lastWhenSubject()
return whenSubjectExpression.convertWithOffsets { startOffset, endOffset ->
IrGetValueImpl(startOffset, endOffset, lastSubjectVariable.type, lastSubjectVariable.symbol)
}

View File

@@ -6,7 +6,10 @@
package org.jetbrains.kotlin.fir.backend.generators
import org.jetbrains.kotlin.fir.backend.*
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.declarations.FirClass
import org.jetbrains.kotlin.fir.declarations.FirDeclarationOrigin
import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction
import org.jetbrains.kotlin.fir.declarations.FirValueParameter
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.expressions.impl.FirNoReceiverExpression
import org.jetbrains.kotlin.fir.psi
@@ -15,11 +18,9 @@ import org.jetbrains.kotlin.fir.references.FirReference
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.FirSuperReference
import org.jetbrains.kotlin.fir.render
import org.jetbrains.kotlin.fir.resolve.calls.isExtensionFunctionType
import org.jetbrains.kotlin.fir.resolve.calls.isFunctional
import org.jetbrains.kotlin.fir.resolve.inference.isBuiltinFunctionalType
import org.jetbrains.kotlin.fir.resolve.toSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
@@ -99,61 +100,6 @@ internal class CallAndReferenceGenerator(
}.applyTypeArguments(callableReferenceAccess).applyReceivers(callableReferenceAccess, explicitReceiverExpression)
}
fun convertToIrCall(
qualifiedAccess: FirQualifiedAccess,
typeRef: FirTypeRef,
explicitReceiverExpression: IrExpression?
): IrExpression {
val explicitReceiver = qualifiedAccess.explicitReceiver
if (!qualifiedAccess.safe || explicitReceiver == null) {
return convertToUnsafeIrCall(qualifiedAccess, typeRef, explicitReceiverExpression)
}
return qualifiedAccess.convertWithOffsets { startOffset, endOffset ->
val callableSymbol = (qualifiedAccess.calleeReference as? FirResolvedNamedReference)?.resolvedSymbol as? FirCallableSymbol<*>
val typeShouldBeNotNull = callableSymbol?.fir?.returnTypeRef?.coneTypeSafe<ConeKotlinType>()?.isNullable == false
val unsafeIrCall =
convertToUnsafeIrCall(qualifiedAccess, typeRef, explicitReceiverExpression, makeNotNull = typeShouldBeNotNull)
convertToSafeIrCall(
unsafeIrCall,
explicitReceiverExpression!!,
isDispatch = explicitReceiver == qualifiedAccess.dispatchReceiver
)
}
}
internal fun convertToSafeIrCall(call: IrExpression, explicitReceiverExpression: IrExpression, isDispatch: Boolean): IrExpression {
val startOffset = call.startOffset
val endOffset = call.endOffset
val receiverVariable = declarationStorage.declareTemporaryVariable(explicitReceiverExpression, "safe_receiver").apply {
parent = conversionScope.parentFromStack()
}
val variableSymbol = symbolTable.referenceValue(receiverVariable.descriptor)
val resultType = call.type.makeNullable()
return IrBlockImpl(startOffset, endOffset, resultType, IrStatementOrigin.SAFE_CALL).apply {
statements += receiverVariable
statements += IrWhenImpl(startOffset, endOffset, resultType).apply {
val condition = IrCallImpl(
startOffset, endOffset, irBuiltIns.booleanType, irBuiltIns.eqeqSymbol, origin = IrStatementOrigin.EQEQ
).apply {
putValueArgument(0, IrGetValueImpl(startOffset, endOffset, variableSymbol))
putValueArgument(1, IrConstImpl.constNull(startOffset, endOffset, irBuiltIns.nothingNType))
}
branches += IrBranchImpl(
condition, IrConstImpl.constNull(startOffset, endOffset, irBuiltIns.nothingNType)
)
val newReceiver = IrGetValueImpl(startOffset, endOffset, variableSymbol)
val replacedCall = call.replaceReceiver(newReceiver, isDispatch)
branches += IrElseBranchImpl(
IrConstImpl.boolean(startOffset, endOffset, irBuiltIns.booleanType, true),
replacedCall
)
}
}
}
private fun FirQualifiedAccess.tryConvertToSamConstructorCall(type: IrType): IrTypeOperatorCall? {
val calleeReference = calleeReference as? FirResolvedNamedReference ?: return null
val fir = calleeReference.resolvedSymbol.fir
@@ -167,10 +113,14 @@ internal class CallAndReferenceGenerator(
return null
}
private fun convertToUnsafeIrCall(
qualifiedAccess: FirQualifiedAccess, typeRef: FirTypeRef, explicitReceiverExpression: IrExpression?, makeNotNull: Boolean = false
fun convertToIrCall(
qualifiedAccess: FirQualifiedAccess,
typeRef: FirTypeRef,
explicitReceiverExpression: IrExpression?
): IrExpression {
val type = typeRef.toIrType().let { if (makeNotNull) it.makeNotNull() else it }
require(!qualifiedAccess.safe)
val type = typeRef.toIrType()
val samConstructorCall = qualifiedAccess.tryConvertToSamConstructorCall(type)
if (samConstructorCall != null) return samConstructorCall
@@ -561,22 +511,6 @@ internal class CallAndReferenceGenerator(
}
}
private fun IrExpression.replaceReceiver(newReceiver: IrExpression, isDispatch: Boolean): IrExpression {
when (this) {
is IrCallImpl -> {
if (!isDispatch) {
extensionReceiver = newReceiver
} else {
dispatchReceiver = newReceiver
}
}
is IrFieldExpressionBase -> {
receiver = newReceiver
}
}
return this
}
private fun generateErrorCallExpression(
startOffset: Int,
endOffset: Int,

View File

@@ -16,6 +16,7 @@ import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
import org.jetbrains.kotlin.ir.expressions.IrStatementOriginImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrGetValueImpl
import org.jetbrains.kotlin.ir.symbols.IrClassifierSymbol
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
import org.jetbrains.kotlin.ir.types.classifierOrFail
@@ -24,7 +25,8 @@ import org.jetbrains.kotlin.ir.util.getSimpleFunction
internal class OperatorExpressionGenerator(
private val components: Fir2IrComponents,
private val visitor: Fir2IrVisitor,
private val callGenerator: CallAndReferenceGenerator
private val callGenerator: CallAndReferenceGenerator,
private val conversionScope: Fir2IrConversionScope
) : Fir2IrComponents by components {
fun convertComparisonExpression(comparisonExpression: FirComparisonExpression): IrExpression {
@@ -158,9 +160,14 @@ internal class OperatorExpressionGenerator(
it.dispatchReceiver = dispatchReceiver
}
return if (operandType.isNullable) {
callGenerator.convertToSafeIrCall(unsafeIrCall, dispatchReceiver, isDispatch = true)
val (receiverVariable, receiverVariableSymbol) =
components.createTemporaryVariableForSafeCallConstruction(dispatchReceiver, conversionScope)
unsafeIrCall.dispatchReceiver = IrGetValueImpl(startOffset, endOffset, receiverVariableSymbol)
components.createSafeCallConstruction(receiverVariable, receiverVariableSymbol, unsafeIrCall, isReceiverNullable = true)
} else {
unsafeIrCall
}
}
}
}