mirror of
https://github.com/jlengrand/kotlin.git
synced 2026-04-21 08:31:30 +00:00
FIR: Support safe-calls new format in FIR2IR
^KT-38444 In Progress
This commit is contained in:
@@ -34,7 +34,13 @@ import org.jetbrains.kotlin.ir.descriptors.IrBuiltIns
|
||||
import org.jetbrains.kotlin.ir.descriptors.WrappedReceiverParameterDescriptor
|
||||
import org.jetbrains.kotlin.ir.expressions.IrConst
|
||||
import org.jetbrains.kotlin.ir.expressions.IrConstKind
|
||||
import org.jetbrains.kotlin.ir.expressions.IrExpression
|
||||
import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
|
||||
import org.jetbrains.kotlin.ir.expressions.impl.*
|
||||
import org.jetbrains.kotlin.ir.symbols.IrClassifierSymbol
|
||||
import org.jetbrains.kotlin.ir.symbols.IrSymbol
|
||||
import org.jetbrains.kotlin.ir.symbols.IrTypeParameterSymbol
|
||||
import org.jetbrains.kotlin.ir.symbols.IrValueSymbol
|
||||
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
|
||||
import org.jetbrains.kotlin.ir.symbols.*
|
||||
import org.jetbrains.kotlin.ir.symbols.impl.IrClassPublicSymbolImpl
|
||||
@@ -480,3 +486,45 @@ val IrType.isSamType: Boolean
|
||||
val am = irClass.functions.singleOrNull { it.owner.modality == Modality.ABSTRACT }
|
||||
return am != null
|
||||
}
|
||||
|
||||
fun Fir2IrComponents.createSafeCallConstruction(
|
||||
receiverVariable: IrVariable,
|
||||
receiverVariableSymbol: IrValueSymbol,
|
||||
expressionOnNotNull: IrExpression,
|
||||
isReceiverNullable: Boolean
|
||||
): IrExpression {
|
||||
val startOffset = expressionOnNotNull.startOffset
|
||||
val endOffset = expressionOnNotNull.endOffset
|
||||
|
||||
val resultType = expressionOnNotNull.type.let { if (isReceiverNullable) it.makeNullable() else it }
|
||||
return IrBlockImpl(startOffset, endOffset, resultType, IrStatementOrigin.SAFE_CALL).apply {
|
||||
statements += receiverVariable
|
||||
statements += IrWhenImpl(startOffset, endOffset, resultType).apply {
|
||||
val condition = IrCallImpl(
|
||||
startOffset, endOffset, irBuiltIns.booleanType, irBuiltIns.eqeqSymbol, origin = IrStatementOrigin.EQEQ
|
||||
).apply {
|
||||
putValueArgument(0, IrGetValueImpl(startOffset, endOffset, receiverVariableSymbol))
|
||||
putValueArgument(1, IrConstImpl.constNull(startOffset, endOffset, irBuiltIns.nothingNType))
|
||||
}
|
||||
branches += IrBranchImpl(
|
||||
condition, IrConstImpl.constNull(startOffset, endOffset, irBuiltIns.nothingNType)
|
||||
)
|
||||
branches += IrElseBranchImpl(
|
||||
IrConstImpl.boolean(startOffset, endOffset, irBuiltIns.booleanType, true),
|
||||
expressionOnNotNull
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun Fir2IrComponents.createTemporaryVariableForSafeCallConstruction(
|
||||
receiverExpression: IrExpression,
|
||||
conversionScope: Fir2IrConversionScope
|
||||
): Pair<IrVariable, IrValueSymbol> {
|
||||
val receiverVariable = declarationStorage.declareTemporaryVariable(receiverExpression, "safe_receiver").apply {
|
||||
parent = conversionScope.parentFromStack()
|
||||
}
|
||||
val variableSymbol = symbolTable.referenceValue(receiverVariable.descriptor)
|
||||
|
||||
return Pair(receiverVariable, variableSymbol)
|
||||
}
|
||||
|
||||
@@ -65,12 +65,20 @@ class Fir2IrConversionScope {
|
||||
return klass
|
||||
}
|
||||
|
||||
private val subjectVariableStack = mutableListOf<IrVariable>()
|
||||
private val whenSubjectVariableStack = mutableListOf<IrVariable>()
|
||||
private val safeCallSubjectVariableStack = mutableListOf<IrVariable>()
|
||||
|
||||
fun <T> withSubject(subject: IrVariable?, f: () -> T): T {
|
||||
if (subject != null) subjectVariableStack += subject
|
||||
fun <T> withWhenSubject(subject: IrVariable?, f: () -> T): T {
|
||||
if (subject != null) whenSubjectVariableStack += subject
|
||||
val result = f()
|
||||
if (subject != null) subjectVariableStack.removeAt(subjectVariableStack.size - 1)
|
||||
if (subject != null) whenSubjectVariableStack.removeAt(whenSubjectVariableStack.size - 1)
|
||||
return result
|
||||
}
|
||||
|
||||
fun <T> withSafeCallSubject(subject: IrVariable?, f: () -> T): T {
|
||||
if (subject != null) safeCallSubjectVariableStack += subject
|
||||
val result = f()
|
||||
if (subject != null) safeCallSubjectVariableStack.removeAt(safeCallSubjectVariableStack.size - 1)
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -98,5 +106,6 @@ class Fir2IrConversionScope {
|
||||
|
||||
fun lastClass(): IrClass? = classStack.lastOrNull()
|
||||
|
||||
fun lastSubject(): IrVariable = subjectVariableStack.last()
|
||||
}
|
||||
fun lastWhenSubject(): IrVariable = whenSubjectVariableStack.last()
|
||||
fun lastSafeCallSubject(): IrVariable = safeCallSubjectVariableStack.last()
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package org.jetbrains.kotlin.fir.backend
|
||||
|
||||
import org.jetbrains.kotlin.KtNodeTypes
|
||||
import org.jetbrains.kotlin.descriptors.ClassConstructorDescriptor
|
||||
import org.jetbrains.kotlin.descriptors.ClassKind
|
||||
import org.jetbrains.kotlin.descriptors.Visibilities
|
||||
import org.jetbrains.kotlin.fir.*
|
||||
@@ -67,7 +66,7 @@ class Fir2IrVisitor(
|
||||
|
||||
private val memberGenerator = ClassMemberGenerator(components, this, conversionScope, callGenerator, fakeOverrideMode)
|
||||
|
||||
private val operatorGenerator = OperatorExpressionGenerator(components, this, callGenerator)
|
||||
private val operatorGenerator = OperatorExpressionGenerator(components, this, callGenerator, conversionScope)
|
||||
|
||||
private fun FirTypeRef.toIrType(): IrType = with(typeConverter) { toIrType() }
|
||||
|
||||
@@ -287,6 +286,34 @@ class Fir2IrVisitor(
|
||||
return callGenerator.convertToIrCall(convertibleCall, convertibleCall.typeRef, explicitReceiverExpression)
|
||||
}
|
||||
|
||||
override fun visitSafeCallExpression(safeCallExpression: FirSafeCallExpression, data: Any?): IrElement {
|
||||
val explicitReceiverExpression = convertToIrExpression(safeCallExpression.receiver)
|
||||
|
||||
val (receiverVariable, variableSymbol) = components.createTemporaryVariableForSafeCallConstruction(
|
||||
explicitReceiverExpression,
|
||||
conversionScope
|
||||
)
|
||||
|
||||
return conversionScope.withSafeCallSubject(receiverVariable) {
|
||||
val afterNotNullCheck = safeCallExpression.regularQualifiedAccess.accept(this, data) as IrExpression
|
||||
|
||||
val isReceiverNullable = with(components.session.inferenceContext) {
|
||||
safeCallExpression.receiver.typeRef.coneTypeSafe<ConeKotlinType>()?.isNullableType() == true
|
||||
}
|
||||
|
||||
components.createSafeCallConstruction(
|
||||
receiverVariable, variableSymbol, afterNotNullCheck, isReceiverNullable
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override fun visitCheckedSafeCallSubject(checkedSafeCallSubject: FirCheckedSafeCallSubject, data: Any?): IrElement {
|
||||
val lastSubjectVariable = conversionScope.lastSafeCallSubject()
|
||||
return checkedSafeCallSubject.convertWithOffsets { startOffset, endOffset ->
|
||||
IrGetValueImpl(startOffset, endOffset, lastSubjectVariable.type, lastSubjectVariable.symbol)
|
||||
}
|
||||
}
|
||||
|
||||
private fun FirFunctionCall.resolvedNamedFunctionSymbol(): FirNamedFunctionSymbol? {
|
||||
val calleeReference = (calleeReference as? FirResolvedNamedReference) ?: return null
|
||||
return calleeReference.resolvedSymbol as? FirNamedFunctionSymbol
|
||||
@@ -552,7 +579,7 @@ class Fir2IrVisitor(
|
||||
KtNodeTypes.POSTFIX_EXPRESSION -> IrStatementOrigin.EXCLEXCL
|
||||
else -> null
|
||||
}
|
||||
return conversionScope.withSubject(subjectVariable) {
|
||||
return conversionScope.withWhenSubject(subjectVariable) {
|
||||
whenExpression.convertWithOffsets { startOffset, endOffset ->
|
||||
val irWhen = IrWhenImpl(
|
||||
startOffset, endOffset,
|
||||
@@ -601,7 +628,7 @@ class Fir2IrVisitor(
|
||||
}
|
||||
|
||||
override fun visitWhenSubjectExpression(whenSubjectExpression: FirWhenSubjectExpression, data: Any?): IrElement {
|
||||
val lastSubjectVariable = conversionScope.lastSubject()
|
||||
val lastSubjectVariable = conversionScope.lastWhenSubject()
|
||||
return whenSubjectExpression.convertWithOffsets { startOffset, endOffset ->
|
||||
IrGetValueImpl(startOffset, endOffset, lastSubjectVariable.type, lastSubjectVariable.symbol)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,10 @@
|
||||
package org.jetbrains.kotlin.fir.backend.generators
|
||||
|
||||
import org.jetbrains.kotlin.fir.backend.*
|
||||
import org.jetbrains.kotlin.fir.declarations.*
|
||||
import org.jetbrains.kotlin.fir.declarations.FirClass
|
||||
import org.jetbrains.kotlin.fir.declarations.FirDeclarationOrigin
|
||||
import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction
|
||||
import org.jetbrains.kotlin.fir.declarations.FirValueParameter
|
||||
import org.jetbrains.kotlin.fir.expressions.*
|
||||
import org.jetbrains.kotlin.fir.expressions.impl.FirNoReceiverExpression
|
||||
import org.jetbrains.kotlin.fir.psi
|
||||
@@ -15,11 +18,9 @@ import org.jetbrains.kotlin.fir.references.FirReference
|
||||
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
|
||||
import org.jetbrains.kotlin.fir.references.FirSuperReference
|
||||
import org.jetbrains.kotlin.fir.render
|
||||
import org.jetbrains.kotlin.fir.resolve.calls.isExtensionFunctionType
|
||||
import org.jetbrains.kotlin.fir.resolve.calls.isFunctional
|
||||
import org.jetbrains.kotlin.fir.resolve.inference.isBuiltinFunctionalType
|
||||
import org.jetbrains.kotlin.fir.resolve.toSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
|
||||
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
|
||||
@@ -99,61 +100,6 @@ internal class CallAndReferenceGenerator(
|
||||
}.applyTypeArguments(callableReferenceAccess).applyReceivers(callableReferenceAccess, explicitReceiverExpression)
|
||||
}
|
||||
|
||||
fun convertToIrCall(
|
||||
qualifiedAccess: FirQualifiedAccess,
|
||||
typeRef: FirTypeRef,
|
||||
explicitReceiverExpression: IrExpression?
|
||||
): IrExpression {
|
||||
val explicitReceiver = qualifiedAccess.explicitReceiver
|
||||
if (!qualifiedAccess.safe || explicitReceiver == null) {
|
||||
return convertToUnsafeIrCall(qualifiedAccess, typeRef, explicitReceiverExpression)
|
||||
}
|
||||
return qualifiedAccess.convertWithOffsets { startOffset, endOffset ->
|
||||
val callableSymbol = (qualifiedAccess.calleeReference as? FirResolvedNamedReference)?.resolvedSymbol as? FirCallableSymbol<*>
|
||||
val typeShouldBeNotNull = callableSymbol?.fir?.returnTypeRef?.coneTypeSafe<ConeKotlinType>()?.isNullable == false
|
||||
val unsafeIrCall =
|
||||
convertToUnsafeIrCall(qualifiedAccess, typeRef, explicitReceiverExpression, makeNotNull = typeShouldBeNotNull)
|
||||
|
||||
convertToSafeIrCall(
|
||||
unsafeIrCall,
|
||||
explicitReceiverExpression!!,
|
||||
isDispatch = explicitReceiver == qualifiedAccess.dispatchReceiver
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun convertToSafeIrCall(call: IrExpression, explicitReceiverExpression: IrExpression, isDispatch: Boolean): IrExpression {
|
||||
val startOffset = call.startOffset
|
||||
val endOffset = call.endOffset
|
||||
val receiverVariable = declarationStorage.declareTemporaryVariable(explicitReceiverExpression, "safe_receiver").apply {
|
||||
parent = conversionScope.parentFromStack()
|
||||
}
|
||||
val variableSymbol = symbolTable.referenceValue(receiverVariable.descriptor)
|
||||
|
||||
val resultType = call.type.makeNullable()
|
||||
return IrBlockImpl(startOffset, endOffset, resultType, IrStatementOrigin.SAFE_CALL).apply {
|
||||
statements += receiverVariable
|
||||
statements += IrWhenImpl(startOffset, endOffset, resultType).apply {
|
||||
val condition = IrCallImpl(
|
||||
startOffset, endOffset, irBuiltIns.booleanType, irBuiltIns.eqeqSymbol, origin = IrStatementOrigin.EQEQ
|
||||
).apply {
|
||||
putValueArgument(0, IrGetValueImpl(startOffset, endOffset, variableSymbol))
|
||||
putValueArgument(1, IrConstImpl.constNull(startOffset, endOffset, irBuiltIns.nothingNType))
|
||||
}
|
||||
branches += IrBranchImpl(
|
||||
condition, IrConstImpl.constNull(startOffset, endOffset, irBuiltIns.nothingNType)
|
||||
)
|
||||
val newReceiver = IrGetValueImpl(startOffset, endOffset, variableSymbol)
|
||||
val replacedCall = call.replaceReceiver(newReceiver, isDispatch)
|
||||
branches += IrElseBranchImpl(
|
||||
IrConstImpl.boolean(startOffset, endOffset, irBuiltIns.booleanType, true),
|
||||
replacedCall
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private fun FirQualifiedAccess.tryConvertToSamConstructorCall(type: IrType): IrTypeOperatorCall? {
|
||||
val calleeReference = calleeReference as? FirResolvedNamedReference ?: return null
|
||||
val fir = calleeReference.resolvedSymbol.fir
|
||||
@@ -167,10 +113,14 @@ internal class CallAndReferenceGenerator(
|
||||
return null
|
||||
}
|
||||
|
||||
private fun convertToUnsafeIrCall(
|
||||
qualifiedAccess: FirQualifiedAccess, typeRef: FirTypeRef, explicitReceiverExpression: IrExpression?, makeNotNull: Boolean = false
|
||||
fun convertToIrCall(
|
||||
qualifiedAccess: FirQualifiedAccess,
|
||||
typeRef: FirTypeRef,
|
||||
explicitReceiverExpression: IrExpression?
|
||||
): IrExpression {
|
||||
val type = typeRef.toIrType().let { if (makeNotNull) it.makeNotNull() else it }
|
||||
require(!qualifiedAccess.safe)
|
||||
|
||||
val type = typeRef.toIrType()
|
||||
val samConstructorCall = qualifiedAccess.tryConvertToSamConstructorCall(type)
|
||||
if (samConstructorCall != null) return samConstructorCall
|
||||
|
||||
@@ -561,22 +511,6 @@ internal class CallAndReferenceGenerator(
|
||||
}
|
||||
}
|
||||
|
||||
private fun IrExpression.replaceReceiver(newReceiver: IrExpression, isDispatch: Boolean): IrExpression {
|
||||
when (this) {
|
||||
is IrCallImpl -> {
|
||||
if (!isDispatch) {
|
||||
extensionReceiver = newReceiver
|
||||
} else {
|
||||
dispatchReceiver = newReceiver
|
||||
}
|
||||
}
|
||||
is IrFieldExpressionBase -> {
|
||||
receiver = newReceiver
|
||||
}
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
private fun generateErrorCallExpression(
|
||||
startOffset: Int,
|
||||
endOffset: Int,
|
||||
|
||||
@@ -16,6 +16,7 @@ import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
|
||||
import org.jetbrains.kotlin.ir.expressions.IrStatementOriginImpl
|
||||
import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl
|
||||
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
|
||||
import org.jetbrains.kotlin.ir.expressions.impl.IrGetValueImpl
|
||||
import org.jetbrains.kotlin.ir.symbols.IrClassifierSymbol
|
||||
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
|
||||
import org.jetbrains.kotlin.ir.types.classifierOrFail
|
||||
@@ -24,7 +25,8 @@ import org.jetbrains.kotlin.ir.util.getSimpleFunction
|
||||
internal class OperatorExpressionGenerator(
|
||||
private val components: Fir2IrComponents,
|
||||
private val visitor: Fir2IrVisitor,
|
||||
private val callGenerator: CallAndReferenceGenerator
|
||||
private val callGenerator: CallAndReferenceGenerator,
|
||||
private val conversionScope: Fir2IrConversionScope
|
||||
) : Fir2IrComponents by components {
|
||||
|
||||
fun convertComparisonExpression(comparisonExpression: FirComparisonExpression): IrExpression {
|
||||
@@ -158,9 +160,14 @@ internal class OperatorExpressionGenerator(
|
||||
it.dispatchReceiver = dispatchReceiver
|
||||
}
|
||||
return if (operandType.isNullable) {
|
||||
callGenerator.convertToSafeIrCall(unsafeIrCall, dispatchReceiver, isDispatch = true)
|
||||
val (receiverVariable, receiverVariableSymbol) =
|
||||
components.createTemporaryVariableForSafeCallConstruction(dispatchReceiver, conversionScope)
|
||||
|
||||
unsafeIrCall.dispatchReceiver = IrGetValueImpl(startOffset, endOffset, receiverVariableSymbol)
|
||||
|
||||
components.createSafeCallConstruction(receiverVariable, receiverVariableSymbol, unsafeIrCall, isReceiverNullable = true)
|
||||
} else {
|
||||
unsafeIrCall
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user