From a3758f9fa3cba9fc64fc59ecd141305cf5753548 Mon Sep 17 00:00:00 2001 From: Alexey Sedunov Date: Mon, 28 Jul 2014 15:06:04 +0400 Subject: [PATCH] PSI Pattern Matching: Implement expression unifier --- .../com/intellij/codeInsight/annotations.xml | 4 + annotations/com/intellij/lang/annotations.xml | 3 + .../psi/impl/source/tree/annotations.xml | 8 + .../jet/lang/psi/psiUtil/jetPsiUtil.kt | 6 + .../org/jetbrains/jet/lang/types/typeUtils.kt | 23 + .../util/psi/patternMatching/JetPsiRange.kt | 104 +++ .../util/psi/patternMatching/JetPsiUnifier.kt | 813 ++++++++++++++++++ 7 files changed, 961 insertions(+) create mode 100644 annotations/com/intellij/psi/impl/source/tree/annotations.xml create mode 100644 core/descriptors/src/org/jetbrains/jet/lang/types/typeUtils.kt create mode 100644 idea/src/org/jetbrains/jet/plugin/util/psi/patternMatching/JetPsiRange.kt create mode 100644 idea/src/org/jetbrains/jet/plugin/util/psi/patternMatching/JetPsiUnifier.kt diff --git a/annotations/com/intellij/codeInsight/annotations.xml b/annotations/com/intellij/codeInsight/annotations.xml index 24aa5b802c6..bf8abfb12e8 100644 --- a/annotations/com/intellij/codeInsight/annotations.xml +++ b/annotations/com/intellij/codeInsight/annotations.xml @@ -6,6 +6,10 @@ name='com.intellij.codeInsight.NullableNotNullManager com.intellij.codeInsight.NullableNotNullManager getInstance(com.intellij.openapi.project.Project)'> + + + diff --git a/annotations/com/intellij/lang/annotations.xml b/annotations/com/intellij/lang/annotations.xml index ebed995bb20..4f2bde8c2f6 100644 --- a/annotations/com/intellij/lang/annotations.xml +++ b/annotations/com/intellij/lang/annotations.xml @@ -1,4 +1,7 @@ + + + diff --git a/annotations/com/intellij/psi/impl/source/tree/annotations.xml b/annotations/com/intellij/psi/impl/source/tree/annotations.xml new file mode 100644 index 00000000000..9083f1004d8 --- /dev/null +++ b/annotations/com/intellij/psi/impl/source/tree/annotations.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/compiler/frontend/src/org/jetbrains/jet/lang/psi/psiUtil/jetPsiUtil.kt b/compiler/frontend/src/org/jetbrains/jet/lang/psi/psiUtil/jetPsiUtil.kt index 70d89f65276..f665868c591 100644 --- a/compiler/frontend/src/org/jetbrains/jet/lang/psi/psiUtil/jetPsiUtil.kt +++ b/compiler/frontend/src/org/jetbrains/jet/lang/psi/psiUtil/jetPsiUtil.kt @@ -335,3 +335,9 @@ public fun JetExpression.isFunctionLiteralOutsideParentheses(): Boolean { else -> false } } + +public fun PsiElement.siblings(forward: Boolean = true, withItself: Boolean = true): Stream { + val stepFun = if (forward) { (e: PsiElement) -> e.getNextSibling() } else { (e: PsiElement) -> e.getPrevSibling() } + val stream = stream(this, stepFun) + return if (withItself) stream else stream.drop(1) +} \ No newline at end of file diff --git a/core/descriptors/src/org/jetbrains/jet/lang/types/typeUtils.kt b/core/descriptors/src/org/jetbrains/jet/lang/types/typeUtils.kt new file mode 100644 index 00000000000..5ad8c49f1f5 --- /dev/null +++ b/core/descriptors/src/org/jetbrains/jet/lang/types/typeUtils.kt @@ -0,0 +1,23 @@ +/* + * Copyright 2010-2014 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.jet.lang.types + +import org.jetbrains.jet.lang.resolve.name.FqNameUnsafe +import org.jetbrains.jet.lang.resolve.DescriptorUtils + +fun JetType.fqName(): FqNameUnsafe? = + getConstructor().getDeclarationDescriptor()?.let { DescriptorUtils.getFqName(it) } diff --git a/idea/src/org/jetbrains/jet/plugin/util/psi/patternMatching/JetPsiRange.kt b/idea/src/org/jetbrains/jet/plugin/util/psi/patternMatching/JetPsiRange.kt new file mode 100644 index 00000000000..cc8de5fff9e --- /dev/null +++ b/idea/src/org/jetbrains/jet/plugin/util/psi/patternMatching/JetPsiRange.kt @@ -0,0 +1,104 @@ +/* + * Copyright 2010-2014 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.jet.plugin.util.psi.patternMatching + +import com.intellij.psi.PsiElement +import org.jetbrains.jet.lang.psi.JetTreeVisitorVoid +import org.jetbrains.jet.lang.psi.JetElement +import java.util.ArrayList +import com.intellij.psi.PsiWhiteSpace +import com.intellij.psi.PsiComment +import java.util.Collections +import org.jetbrains.jet.lang.psi.psiUtil.siblings +import com.intellij.openapi.util.TextRange + +private val SIGNIFICANT_FILTER = { (e: PsiElement) -> e !is PsiWhiteSpace && e !is PsiComment && e.getTextLength() > 0 } + +public trait JetPsiRange { + public object Empty : JetPsiRange { + override val elements: List get() = Collections.emptyList() + + override fun getTextRange(): TextRange = TextRange.EMPTY_RANGE + } + + public class ListRange(override val elements: List): JetPsiRange { + val startElement: PsiElement = elements.first() + val endElement: PsiElement = elements.last() + + override fun getTextRange(): TextRange { + val startRange = startElement.getTextRange() + val endRange = endElement.getTextRange() + if (startRange == null || endRange == null) return TextRange.EMPTY_RANGE + + return TextRange(startRange.getStartOffset(), endRange.getEndOffset()) + } + } + + public class Match(val range: JetPsiRange, val result: UnificationResult.Matched) + + val elements: List + + fun getTextRange(): TextRange + + fun isValid(): Boolean = elements.all { it.isValid() } + + val empty: Boolean get() = this is Empty + + fun contains(element: PsiElement): Boolean = getTextRange().contains(element.getTextRange() ?: TextRange.EMPTY_RANGE) + + fun match(scope: PsiElement, unifier: JetPsiUnifier): List { + val elements = elements.filter(SIGNIFICANT_FILTER) + if (elements.empty) return Collections.emptyList() + + val matches = ArrayList() + scope.accept( + object: JetTreeVisitorVoid() { + private fun processElement(element: JetElement): Boolean { + val candidates = element + .siblings() + .filter(SIGNIFICANT_FILTER) + .take(elements.size) + .toList() + if (candidates.size != elements.size) return false + + val range = candidates.toRange() + + val result = unifier.unify(range, this@JetPsiRange) + if (result is UnificationResult.Matched) { + matches.add(Match(range, result)) + return true + } + + return false + } + + override fun visitJetElement(element: JetElement) { + if (!processElement(element)) { + super.visitJetElement(element) + } + } + } + ) + return matches + } +} + +public fun List.toRange(significantOnly: Boolean = true): JetPsiRange { + return if (empty) JetPsiRange.Empty else JetPsiRange.ListRange(if (significantOnly) filter(SIGNIFICANT_FILTER) else this) +} + +public fun PsiElement?.toRange(): JetPsiRange = this?.let { JetPsiRange.ListRange(Collections.singletonList(it)) } ?: JetPsiRange.Empty \ No newline at end of file diff --git a/idea/src/org/jetbrains/jet/plugin/util/psi/patternMatching/JetPsiUnifier.kt b/idea/src/org/jetbrains/jet/plugin/util/psi/patternMatching/JetPsiUnifier.kt new file mode 100644 index 00000000000..80cb3f48582 --- /dev/null +++ b/idea/src/org/jetbrains/jet/plugin/util/psi/patternMatching/JetPsiUnifier.kt @@ -0,0 +1,813 @@ +/* + * Copyright 2010-2014 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jetbrains.jet.plugin.util.psi.patternMatching + +import org.jetbrains.jet.lang.descriptors.DeclarationDescriptor +import org.jetbrains.jet.plugin.util.psi.patternMatching.UnificationResult.* +import org.jetbrains.jet.plugin.util.psi.patternMatching.UnificationResult.Status.* +import java.util.HashMap +import org.jetbrains.jet.lang.psi.JetExpression +import org.jetbrains.jet.lang.psi.JetPsiUtil +import com.intellij.psi.PsiElement +import org.jetbrains.jet.lang.resolve.BindingContext +import org.jetbrains.jet.lang.resolve.calls.model.ResolvedCall +import org.jetbrains.jet.lang.types.JetType +import com.intellij.util.containers.ContainerUtil +import org.jetbrains.jet.lang.types.checker.JetTypeChecker +import java.util.Collections +import org.jetbrains.jet.lang.psi.JetReferenceExpression +import org.jetbrains.jet.lang.psi.Call +import org.jetbrains.jet.lang.resolve.scopes.receivers.ReceiverValue +import org.jetbrains.jet.lang.resolve.scopes.receivers.ExpressionReceiver +import com.intellij.psi.impl.source.tree.LeafPsiElement +import org.jetbrains.jet.lang.psi.JetElement +import org.jetbrains.jet.lang.psi.JetTypeReference +import org.jetbrains.jet.lang.types.TypeUtils +import org.jetbrains.jet.lang.psi.ValueArgument +import org.jetbrains.jet.lang.resolve.calls.tasks.ExplicitReceiverKind +import org.jetbrains.jet.lang.psi.JetIfExpression +import org.jetbrains.jet.lang.psi.JetUnaryExpression +import org.jetbrains.jet.lexer.JetTokens +import org.jetbrains.jet.lang.psi.JetBinaryExpression +import org.jetbrains.jet.lang.psi.JetConstantExpression +import org.jetbrains.jet.lang.resolve.calls.model.VariableAsFunctionResolvedCall +import org.jetbrains.jet.lang.psi.JetSimpleNameExpression +import org.jetbrains.jet.lang.psi.JetArrayAccessExpression +import org.jetbrains.jet.lexer.JetToken +import org.jetbrains.jet.lang.types.expressions.OperatorConventions +import org.jetbrains.jet.lang.psi.JetLabelReferenceExpression +import org.jetbrains.jet.lang.resolve.calls.callUtil.getCall +import org.jetbrains.jet.lang.resolve.calls.callUtil.getResolvedCall +import org.jetbrains.jet.lang.psi.JetDeclaration +import org.jetbrains.jet.lang.psi.JetNamedDeclaration +import org.jetbrains.jet.lang.types.ErrorUtils +import com.intellij.lang.ASTNode +import com.intellij.util.containers.MultiMap +import org.jetbrains.jet.plugin.project.AnalyzerFacadeWithCache +import org.jetbrains.jet.lang.psi.JetCallableReferenceExpression +import org.jetbrains.jet.lang.resolve.scopes.receivers.ThisReceiver +import org.jetbrains.jet.lang.psi.JetThisExpression +import org.jetbrains.jet.lang.psi.JetStringTemplateEntryWithExpression +import org.jetbrains.jet.plugin.util.psi.patternMatching.JetPsiRange.Empty +import org.jetbrains.jet.lang.psi.JetFunctionLiteral +import org.jetbrains.jet.lang.descriptors.impl.AnonymousFunctionDescriptor +import org.jetbrains.jet.lang.descriptors.FunctionDescriptor +import org.jetbrains.jet.lang.psi.JetMultiDeclaration +import org.jetbrains.jet.lang.psi.JetFunction +import org.jetbrains.jet.lang.psi.JetClassBody +import org.jetbrains.jet.lang.descriptors.CallableDescriptor +import org.jetbrains.jet.lang.psi.JetDeclarationWithBody +import org.jetbrains.jet.lang.psi.JetProperty +import org.jetbrains.jet.lang.descriptors.ReceiverParameterDescriptor +import org.jetbrains.jet.lang.psi.JetWithExpressionInitializer +import org.jetbrains.jet.lang.psi.JetPropertyAccessor +import org.jetbrains.jet.lang.psi.JetMultiDeclarationEntry +import org.jetbrains.jet.lang.psi.JetParameter +import org.jetbrains.jet.lang.descriptors.ClassDescriptor +import org.jetbrains.jet.lang.psi.JetClassOrObject +import org.jetbrains.jet.lang.psi.JetCallableDeclaration +import org.jetbrains.jet.lang.psi.JetClassObject +import org.jetbrains.jet.lang.descriptors.ClassKind +import org.jetbrains.jet.lang.psi.JetTypeParameter +import org.jetbrains.jet.lang.descriptors.TypeParameterDescriptor +import org.jetbrains.jet.lang.types.fqName +import org.jetbrains.jet.renderer.DescriptorRenderer +import org.jetbrains.jet.plugin.codeInsight.DescriptorToDeclarationUtil +import org.jetbrains.jet.lang.resolve.DescriptorToSourceUtils +import org.jetbrains.jet.lang.psi.JetClass +import org.jetbrains.jet.lang.descriptors.VariableDescriptor +import org.jetbrains.jet.lang.psi.JetFile +import org.jetbrains.jet.lang.psi.JetClassInitializer +import java.util.ArrayList +import org.jetbrains.jet.lang.psi.JetTypeParameterListOwner +import org.jetbrains.jet.lang.psi.doNotAnalyze +import org.jetbrains.jet.lang.psi.JetDelegatorToSuperClass +import org.jetbrains.jet.lang.psi.JetDelegationSpecifier + +public trait UnificationResult { + public enum class Status { + MATCHED { + override fun and(other: Status): Status = other + } + + UNMATCHED { + override fun and(other: Status): Status = this + } + + public abstract fun and(other: Status): Status + } + + object Unmatched : UnificationResult { + override val status: Status get() = UNMATCHED + } + + public class Matched(val substitution: Map): UnificationResult { + override val status: Status get() = MATCHED + } + + val status: Status + val matched: Boolean get() = status != UNMATCHED +} + +public class UnifierParameter( + val descriptor: DeclarationDescriptor, + val expectedType: JetType +) + +public class JetPsiUnifier( + parameters: Collection = Collections.emptySet() +) { + class object { + val DEFAULT = JetPsiUnifier() + } + + private inner class Context( + val originalTarget: JetPsiRange, + val originalPattern: JetPsiRange + ) { + val substitution = HashMap() + val declarationPatternsToTargets = MultiMap() + var checkEquivalence: Boolean = false + + private fun matchDescriptors(d1: DeclarationDescriptor?, d2: DeclarationDescriptor?): Boolean { + if (d1 == d2 || d2 in declarationPatternsToTargets[d1] || d1 in declarationPatternsToTargets[d2]) return true + if (d1 == null || d2 == null) return false + + val decl1 = DescriptorToSourceUtils.descriptorToDeclaration(d1) as? JetDeclaration + val decl2 = DescriptorToSourceUtils.descriptorToDeclaration(d2) as? JetDeclaration + if (decl1 == null || decl2 == null) return false + + if ((decl1 in originalTarget && decl2 in originalPattern) || (decl2 in originalTarget && decl1 in originalPattern)) { + return matchDeclarations(decl1, decl2, d1, d2) == MATCHED + } + + return false + } + + private fun matchReceivers(rv1: ReceiverValue, rv2: ReceiverValue): Boolean { + return when { + rv1 is ExpressionReceiver && rv2 is ExpressionReceiver -> + doUnify(rv1.getExpression(), rv2.getExpression()) == MATCHED + + rv1 is ThisReceiver && rv2 is ThisReceiver -> + matchDescriptors(rv1.getDeclarationDescriptor(), rv2.getDeclarationDescriptor()) + + else -> + rv1 == rv2 + } + } + + private fun matchCalls(call1: Call, call2: Call): Boolean { + return matchReceivers(call1.getExplicitReceiver(), call2.getExplicitReceiver()) && + matchReceivers(call1.getThisObject(), call2.getThisObject()) + } + + private fun matchArguments(arg1: ValueArgument, arg2: ValueArgument): Status { + return when { + arg1.isExternal() != arg2.isExternal() -> + UNMATCHED + + (arg1.getSpreadElement() == null) != (arg2.getSpreadElement() == null) -> + UNMATCHED + + else -> + doUnify(arg1.getArgumentExpression(), arg2.getArgumentExpression()) + } + } + + private fun matchResolvedCalls(rc1: ResolvedCall<*>, rc2: ResolvedCall<*>): Status? { + fun checkSpecialOperations(): Boolean { + val op1 = (rc1.getCall().getCalleeExpression() as? JetSimpleNameExpression)?.getReferencedNameElementType() + val op2 = (rc2.getCall().getCalleeExpression() as? JetSimpleNameExpression)?.getReferencedNameElementType() + + return when { + op1 == op2 -> + true + op1 == JetTokens.NOT_IN, op2 == JetTokens.NOT_IN -> + false + op1 == JetTokens.EXCLEQ, op2 == JetTokens.EXCLEQ -> + false + op1 in OperatorConventions.COMPARISON_OPERATIONS, op2 in OperatorConventions.COMPARISON_OPERATIONS -> + false + else -> + true + } + } + + fun checkArguments(): Status? { + val args1 = rc1.getResultingDescriptor()?.getValueParameters()?.map { rc1.getValueArguments()[it] } ?: Collections.emptyList() + val args2 = rc2.getResultingDescriptor()?.getValueParameters()?.map { rc2.getValueArguments()[it] } ?: Collections.emptyList() + if (args1.size != args2.size) return UNMATCHED + if (rc1.getCall().getValueArguments().size != args1.size || rc2.getCall().getValueArguments().size != args2.size) return null + + return (args1.stream() zip args2.stream()).fold(MATCHED) { (s, p) -> + val (arg1, arg2) = p + s and when { + arg1 == arg2 -> MATCHED + arg1 == null || arg2 == null -> UNMATCHED + else -> (arg1.getArguments().stream() zip arg2.getArguments().stream()).fold(MATCHED) { (s, p) -> + s and matchArguments(p.first, p.second) + } + } + } + } + + fun checkImplicitReceiver(implicitCall: ResolvedCall<*>, explicitCall: ResolvedCall<*>): Boolean { + val (implicitReceiver, explicitReceiver) = + when (explicitCall.getExplicitReceiverKind()) { + ExplicitReceiverKind.RECEIVER_ARGUMENT -> + (implicitCall.getReceiverArgument() as? ThisReceiver) to + (explicitCall.getReceiverArgument() as? ExpressionReceiver) + + ExplicitReceiverKind.THIS_OBJECT -> + (implicitCall.getThisObject() as? ThisReceiver) to + (explicitCall.getThisObject() as? ExpressionReceiver) + + else -> + null to null + } + + val thisExpression = explicitReceiver?.getExpression() as? JetThisExpression + if (implicitReceiver == null || thisExpression == null) return false + + return matchDescriptors( + implicitReceiver.getDeclarationDescriptor(), + thisExpression.getAdjustedResolvedCall()?.getCandidateDescriptor()?.getContainingDeclaration() + ) + } + + fun checkReceivers(): Boolean { + return when { + rc1.getExplicitReceiverKind() == rc2.getExplicitReceiverKind() -> { + matchReceivers(rc1.getReceiverArgument(), rc2.getReceiverArgument()) && + (rc1.getExplicitReceiverKind() == ExplicitReceiverKind.BOTH_RECEIVERS || matchReceivers(rc1.getThisObject(), rc2.getThisObject())) + } + + rc1.getExplicitReceiverKind() == ExplicitReceiverKind.NO_EXPLICIT_RECEIVER -> checkImplicitReceiver(rc1, rc2) + + rc2.getExplicitReceiverKind() == ExplicitReceiverKind.NO_EXPLICIT_RECEIVER -> checkImplicitReceiver(rc2, rc1) + + else -> false + } + } + + return when { + !checkSpecialOperations() -> UNMATCHED + !matchDescriptors(rc1.getCandidateDescriptor(), rc2.getCandidateDescriptor()) -> UNMATCHED + !checkReceivers() -> UNMATCHED + rc1.isSafeCall() != rc2.isSafeCall() -> UNMATCHED + else -> checkArguments() + } + } + + private val JetElement.bindingContext: BindingContext get() { + if ((getContainingFile() as? JetFile)?.doNotAnalyze != null) return BindingContext.EMPTY + return AnalyzerFacadeWithCache.getContextForElement(this) + } + + private fun JetElement.getAdjustedResolvedCall(): ResolvedCall<*>? { + val rc = getResolvedCall(bindingContext)?.let { + when { + it !is VariableAsFunctionResolvedCall -> it + this is JetSimpleNameExpression -> it.variableCall + else -> it.functionCall + } + } + + return when { + rc == null || ErrorUtils.isError(rc.getCandidateDescriptor()) -> null + else -> rc + } + } + + private fun matchCalls(e1: JetElement, e2: JetElement): Status? { + if (e1.shouldIgnoreResolvedCall() || e2.shouldIgnoreResolvedCall()) return null + + val resolvedCall1 = e1.getAdjustedResolvedCall() + val resolvedCall2 = e2.getAdjustedResolvedCall() + + return when { + resolvedCall1 != null && resolvedCall2 != null -> + matchResolvedCalls(resolvedCall1, resolvedCall2) + + resolvedCall1 == null && resolvedCall2 == null -> { + val call1 = e1.getCall(e1.bindingContext) + val call2 = e2.getCall(e2.bindingContext) + + when { + call1 != null && call2 != null -> + if (matchCalls(call1, call2)) null else UNMATCHED + + else -> + if (call1 == null && call2 == null) null else UNMATCHED + } + } + + else -> + UNMATCHED + } + } + + private fun JetTypeReference.getType(): JetType? { + val t = bindingContext[BindingContext.TYPE, this] + return if (t == null || t.isError()) null else t + } + + private fun matchTypes(type1: JetType?, type2: JetType?): Status? { + if (type1 != null && type2 != null) { + if (TypeUtils.equalTypes(type1, type2)) return MATCHED + + if (type1.isNullable() != type2.isNullable()) return UNMATCHED + if (!matchDescriptors( + type1.getConstructor().getDeclarationDescriptor(), + type2.getConstructor().getDeclarationDescriptor())) return UNMATCHED + + val args1 = type1.getArguments() + val args2 = type2.getArguments() + if (args1.size != args2.size) return UNMATCHED + if (!args1.zip(args2).all { + it.first.getProjectionKind() == it.second.getProjectionKind() && matchTypes(it.first.getType(), it.second.getType()) == MATCHED } + ) return UNMATCHED + + return MATCHED + } + + return if (type1 == null && type2 == null) null else UNMATCHED + } + + private fun matchTypes(types1: Collection, types2: Collection): Boolean { + fun sortTypes(types: Collection) = types.sortBy{ DescriptorRenderer.DEBUG_TEXT.renderType(it) } + + if (types1.size != types2.size) return false + return (sortTypes(types1) zip sortTypes(types2)).all { matchTypes(it.first, it.second) == MATCHED } + } + + private fun JetElement.shouldIgnoreResolvedCall(): Boolean { + return when { + this is JetConstantExpression -> true + this is JetIfExpression -> true + this is JetUnaryExpression -> when (getOperationReference().getReferencedNameElementType()) { + JetTokens.EXCLEXCL, JetTokens.PLUSPLUS, JetTokens.MINUSMINUS -> true + else -> false + } + this is JetBinaryExpression -> getOperationReference().getReferencedNameElementType() == JetTokens.ELVIS + else -> false + } + } + + private fun JetBinaryExpression.matchComplexAssignmentWithSimple(simple: JetBinaryExpression): Status? { + return when { + doUnify(getLeft(), simple.getLeft()) == UNMATCHED -> UNMATCHED + else -> simple.getRight()?.let { matchCalls(this, it) } ?: UNMATCHED + } + } + + private fun JetBinaryExpression.matchAssignment(e: JetElement): Status? { + val operationType = getOperationReference().getReferencedNameElementType() as JetToken + if (operationType == JetTokens.EQ) { + if (e.shouldIgnoreResolvedCall()) return UNMATCHED + + if (JetPsiUtil.isAssignment(e) && !JetPsiUtil.isOrdinaryAssignment(e)) { + return (e as JetBinaryExpression).matchComplexAssignmentWithSimple(this) + } + + val lhs = getLeft()?.unwrap() + if (lhs !is JetArrayAccessExpression) return null + + val setResolvedCall = bindingContext[BindingContext.INDEXED_LVALUE_SET, lhs] + val resolvedCallToMatch = e.getAdjustedResolvedCall() + + return if (setResolvedCall == null || resolvedCallToMatch == null) null else matchResolvedCalls(setResolvedCall, resolvedCallToMatch) + } + + val assignResolvedCall = getAdjustedResolvedCall() + if (assignResolvedCall == null) return UNMATCHED + + val operationName = OperatorConventions.getNameForOperationSymbol(operationType) + if (assignResolvedCall.getResultingDescriptor()?.getName() == operationName) return matchCalls(this, e) + + return if (JetPsiUtil.isAssignment(e)) null else UNMATCHED + } + + private fun matchLabelTargets(e1: JetLabelReferenceExpression, e2: JetLabelReferenceExpression): Status { + val target1 = e1.bindingContext[BindingContext.LABEL_TARGET, e1] + val target2 = e2.bindingContext[BindingContext.LABEL_TARGET, e2] + + return if (target1 == target2) MATCHED else UNMATCHED + } + + private fun PsiElement.isIncrement(): Boolean { + val parent = getParent() + return parent is JetUnaryExpression + && this == parent.getOperationReference() + && ((parent.getOperationToken() as JetToken) in OperatorConventions.INCREMENT_OPERATIONS) + } + + private fun matchCallableReferences(e1: JetCallableReferenceExpression, e2: JetCallableReferenceExpression): Boolean { + val d1 = e1.bindingContext[BindingContext.REFERENCE_TARGET, e1.getCallableReference()] + val d2 = e2.bindingContext[BindingContext.REFERENCE_TARGET, e2.getCallableReference()] + return matchDescriptors(d1, d2) + } + + private fun matchMultiDeclarations(e1: JetMultiDeclaration, e2: JetMultiDeclaration): Boolean { + val entries1 = e1.getEntries() + val entries2 = e2.getEntries() + if (entries1.size != entries2.size) return false + + return entries1.zip(entries2).all { p -> + val (entry1, entry2) = p + val rc1 = entry1.bindingContext[BindingContext.COMPONENT_RESOLVED_CALL, entry1] + val rc2 = entry2.bindingContext[BindingContext.COMPONENT_RESOLVED_CALL, entry2] + when { + rc1 == null && rc2 == null -> true + rc1 != null && rc2 != null -> matchResolvedCalls(rc1, rc2) == MATCHED + else -> false + } + } + } + + fun matchReceiverParameters(receiver1: ReceiverParameterDescriptor?, receiver2: ReceiverParameterDescriptor?): Boolean { + val matchedReceivers = when { + receiver1 == null && receiver2 == null -> true + matchDescriptors(receiver1, receiver2) -> true + receiver1 != null && receiver2 != null -> matchTypes(receiver1.getType(), receiver2.getType()) == MATCHED + else -> false + } + + if (matchedReceivers && receiver1 != null) { + declarationPatternsToTargets.putValue(receiver1, receiver2) + } + + return matchedReceivers + } + + private fun matchCallables( + decl1: JetDeclaration, + decl2: JetDeclaration, + desc1: CallableDescriptor, + desc2: CallableDescriptor): Status? { + fun needToCompareReturnTypes(): Boolean { + if (decl1 !is JetCallableDeclaration) return true + return decl1.getReturnTypeRef() != null || (decl2 as JetCallableDeclaration).getReturnTypeRef() != null + } + + if (desc1 is VariableDescriptor && desc1.isVar() != (desc2 as VariableDescriptor).isVar()) return UNMATCHED + + if (!matchNames(decl1, decl2, desc1, desc2)) return UNMATCHED + + if (needToCompareReturnTypes()) { + val type1 = desc1.getReturnType() + val type2 = desc2.getReturnType() + + if (type1 != type2 + && (type1 == null || type2 == null || type1.isError() || type2.isError() || matchTypes(type1, type2) != MATCHED)) { + return UNMATCHED + } + } + + if (!matchReceiverParameters(desc1.getReceiverParameter(), desc2.getReceiverParameter())) return UNMATCHED + if (!matchReceiverParameters(desc1.getExpectedThisObject(), desc2.getExpectedThisObject())) return UNMATCHED + + val params1 = desc1.getValueParameters() + val params2 = desc2.getValueParameters() + val zippedParams = params1.zip(params2) + val parametersMatch = + (params1.size == params2.size) && zippedParams.all { matchTypes(it.first.getType(), it.second.getType()) == MATCHED } + if (!parametersMatch) return UNMATCHED + + zippedParams.forEach { declarationPatternsToTargets.putValue(it.first, it.second) } + + return doUnify( + (decl1 as? JetTypeParameterListOwner)?.getTypeParameters()?.toRange() ?: Empty, + (decl2 as? JetTypeParameterListOwner)?.getTypeParameters()?.toRange() ?: Empty + ) and when (decl1) { + is JetDeclarationWithBody -> + doUnify(decl1.getBodyExpression(), (decl2 as JetDeclarationWithBody).getBodyExpression()) + + is JetWithExpressionInitializer -> + doUnify(decl1.getInitializer(), (decl2 as JetWithExpressionInitializer).getInitializer()) + + is JetParameter -> + doUnify(decl1.getDefaultValue(), (decl2 as JetParameter).getDefaultValue()) + + else -> + UNMATCHED + } + } + + private fun JetDeclaration.isNameRelevant(): Boolean { + if (this is JetParameter && hasValOrVarNode()) return true + + val parent = getParent() + return parent is JetClassBody || parent is JetFile + } + + private fun matchNames(decl1: JetDeclaration, decl2: JetDeclaration, desc1: DeclarationDescriptor, desc2: DeclarationDescriptor): Boolean { + return (!decl1.isNameRelevant() && !decl2.isNameRelevant()) || desc1.getName() == desc2.getName() + } + + private fun matchContainedDescriptors( + declarations1: List, + declarations2: List, + matchPair: (Pair) -> Boolean + ): Boolean { + val zippedParams = declarations1 zip declarations2 + if (declarations1.size != declarations2.size || !zippedParams.all { matchPair(it) }) return false + + zippedParams.forEach { declarationPatternsToTargets.putValue(it.first, it.second) } + return true + } + + private fun matchClasses( + decl1: JetClassOrObject, + decl2: JetClassOrObject, + desc1: ClassDescriptor, + desc2: ClassDescriptor): Status? { + class OrderInfo( + val orderSensitive: List, + val orderInsensitive: List + ) + + fun getMemberOrderInfo(cls: JetClassOrObject): OrderInfo { + val (orderInsensitive, orderSensitive) = (cls.getBody()?.getDeclarations() ?: Collections.emptyList()).partition { + it is JetClassOrObject || it is JetFunction + } + + return OrderInfo(orderSensitive, orderInsensitive) + } + + fun getDelegationOrderInfo(cls: JetClassOrObject): OrderInfo { + val (orderInsensitive, orderSensitive) = cls.getDelegationSpecifiers().partition { it is JetDelegatorToSuperClass } + return OrderInfo(orderSensitive, orderInsensitive) + } + + fun resolveAndSortDeclarationsByDescriptor(declarations: List): List> { + return declarations + .map { it to it.bindingContext[BindingContext.DECLARATION_TO_DESCRIPTOR, it] } + .sortBy { it.second?.let { DescriptorRenderer.SOURCE_CODE.render(it) } ?: "" } + } + + fun sortDeclarationsByElementType(declarations: List): List { + return declarations.sortBy { it.getNode()?.getElementType()?.getIndex() ?: -1 } + } + + if (desc1.getKind() != desc2.getKind()) return UNMATCHED + if (!matchNames(decl1, decl2, desc1, desc2)) return UNMATCHED + + declarationPatternsToTargets.putValue(desc1.getThisAsReceiverParameter(), desc2.getThisAsReceiverParameter()) + + val constructor1 = desc1.getUnsubstitutedPrimaryConstructor() + val constructor2 = desc2.getUnsubstitutedPrimaryConstructor() + if (constructor1 != null && constructor2 != null) { + declarationPatternsToTargets.putValue(constructor1, constructor2) + } + + val delegationInfo1 = getDelegationOrderInfo(decl1) + val delegationInfo2 = getDelegationOrderInfo(decl2) + + if (delegationInfo1.orderInsensitive.size != delegationInfo2.orderInsensitive.size) return UNMATCHED + @outer + for (specifier1 in delegationInfo1.orderInsensitive) { + for (specifier2 in delegationInfo2.orderInsensitive) { + if (doUnify(specifier1, specifier2) != UNMATCHED) continue @outer + } + return UNMATCHED + } + + val status = doUnify((decl1 as? JetClass)?.getPrimaryConstructorParameterList(), (decl2 as? JetClass)?.getPrimaryConstructorParameterList()) and + doUnify((decl1 as? JetClass)?.getTypeParameterList(), (decl2 as? JetClass)?.getTypeParameterList()) and + doUnify(delegationInfo1.orderSensitive.toRange(), delegationInfo2.orderSensitive.toRange()) + if (status == UNMATCHED) return UNMATCHED + + val membersInfo1 = getMemberOrderInfo(decl1) + val membersInfo2 = getMemberOrderInfo(decl2) + + val sortedMembers1 = resolveAndSortDeclarationsByDescriptor(membersInfo1.orderInsensitive) + val sortedMembers2 = resolveAndSortDeclarationsByDescriptor(membersInfo2.orderInsensitive) + if ((sortedMembers1.size != sortedMembers2.size)) return UNMATCHED + if (sortedMembers1.zip(sortedMembers2).any { + val (d1, d2) = it + (matchDeclarations(d1.first, d2.first, d1.second, d2.second) ?: doUnify(d1.first, d2.first)) == UNMATCHED + }) return UNMATCHED + + return doUnify( + sortDeclarationsByElementType(membersInfo1.orderSensitive).toRange(), + sortDeclarationsByElementType(membersInfo2.orderSensitive).toRange() + ) + } + + private fun matchTypeParameters( + desc1: TypeParameterDescriptor, + desc2: TypeParameterDescriptor + ): Status { + if (desc1.getVariance() != desc2.getVariance()) return UNMATCHED + if (!matchTypes(desc1.getLowerBounds(), desc2.getLowerBounds())) return UNMATCHED + if (!matchTypes(desc1.getUpperBounds(), desc2.getUpperBounds())) return UNMATCHED + return MATCHED + } + + private fun JetDeclaration.matchDeclarations(e: PsiElement): Status? { + if (e !is JetDeclaration) return UNMATCHED + + val desc1 = bindingContext[BindingContext.DECLARATION_TO_DESCRIPTOR, this] + val desc2 = e.bindingContext[BindingContext.DECLARATION_TO_DESCRIPTOR, e] + return matchDeclarations(this, e, desc1, desc2) + } + + private fun matchDeclarations( + decl1: JetDeclaration, + decl2: JetDeclaration, + desc1: DeclarationDescriptor?, + desc2: DeclarationDescriptor?): Status? { + if (decl1.javaClass != decl2.javaClass) return UNMATCHED + + if (desc1 == null || desc2 == null || ErrorUtils.isError(desc1) || ErrorUtils.isError(desc2)) return UNMATCHED + if (desc1.javaClass != desc2.javaClass) return UNMATCHED + + declarationPatternsToTargets.putValue(desc1, desc2) + val status = when (decl1) { + is JetDeclarationWithBody, is JetWithExpressionInitializer, is JetParameter -> + matchCallables(decl1, decl2, desc1 as CallableDescriptor, desc2 as CallableDescriptor) + + is JetClassOrObject -> + matchClasses(decl1, decl2 as JetClassOrObject, desc1 as ClassDescriptor, desc2 as ClassDescriptor) + + is JetTypeParameter -> + matchTypeParameters(desc1 as TypeParameterDescriptor, desc2 as TypeParameterDescriptor) + + else -> + null + } + if (status == UNMATCHED) { + declarationPatternsToTargets.removeValue(desc1, desc2) + } + + return status + } + + private fun matchResolvedInfo(e1: PsiElement, e2: PsiElement): Status? { + return when { + e1 !is JetElement, e2 !is JetElement -> + null + + e1 is JetMultiDeclaration && e2 is JetMultiDeclaration -> + if (matchMultiDeclarations(e1, e2)) null else UNMATCHED + + e1 is JetClassObject && e2 is JetClassObject -> + e1.getObjectDeclaration().matchDeclarations(e2.getObjectDeclaration()) + + e1 is JetClassInitializer && e2 is JetClassInitializer -> + null + + e1 is JetDeclaration -> + e1.matchDeclarations(e2) + + e2 is JetDeclaration -> + e2.matchDeclarations(e1) + + e1 is JetTypeReference && e2 is JetTypeReference -> + matchTypes(e1.getType(), e2.getType()) + + JetPsiUtil.isAssignment(e1) -> + (e1 as JetBinaryExpression).matchAssignment(e2) + + JetPsiUtil.isAssignment(e2) -> + (e2 as JetBinaryExpression).matchAssignment(e1) + + e1 is JetLabelReferenceExpression && e2 is JetLabelReferenceExpression -> + matchLabelTargets(e1, e2) + + e1.isIncrement() != e2.isIncrement() -> + UNMATCHED + + e1 is JetCallableReferenceExpression && e2 is JetCallableReferenceExpression -> + if (matchCallableReferences(e1, e2)) MATCHED else UNMATCHED + + else -> + matchCalls(e1, e2) + } + } + + private fun PsiElement.checkType(parameter: UnifierParameter): Boolean { + val targetElementType = (this as? JetExpression)?.let { it.bindingContext[BindingContext.EXPRESSION_TYPE, it] } + return targetElementType != null && JetTypeChecker.DEFAULT.isSubtypeOf(targetElementType, parameter.expectedType) + } + + fun doUnify(target: JetPsiRange, pattern: JetPsiRange): Status { + val targetElements = target.elements + val patternElements = pattern.elements + if (targetElements.size != patternElements.size) return UNMATCHED + + return (targetElements.stream() zip patternElements.stream()).fold(MATCHED) {(s, p) -> + if (s != UNMATCHED) s and doUnify(p.first, p.second) else s + } + } + + private fun ASTNode.getChildrenRange(): JetPsiRange = + getChildren(null).map { it.getPsi() }.filterNotNull().toRange() + + fun doUnify( + targetElement: PsiElement?, + patternElement: PsiElement? + ): Status { + val targetElementUnwrapped = targetElement?.unwrap() + val patternElementUnwrapped = patternElement?.unwrap() + + if (targetElementUnwrapped == patternElementUnwrapped) return MATCHED + if (targetElementUnwrapped == null || patternElementUnwrapped == null) return UNMATCHED + + if (!checkEquivalence) { + val referencedPatternDescriptor = (patternElementUnwrapped as? JetReferenceExpression)?.let { + it.bindingContext[BindingContext.REFERENCE_TARGET, it] + } + val parameter = descriptorToParameter[referencedPatternDescriptor] + if (parameter != null) { + if (targetElementUnwrapped !is JetExpression) return UNMATCHED + if (!targetElementUnwrapped.checkType(parameter)) return UNMATCHED + + val existingArgument = substitution[parameter] + return when { + existingArgument == null -> { + substitution[parameter] = targetElementUnwrapped + MATCHED + } + else -> { + checkEquivalence = true + val status = doUnify(existingArgument, targetElementUnwrapped) + checkEquivalence = false + + status + } + } + } + } + + val targetNode = targetElementUnwrapped.getNode() + val patternNode = patternElementUnwrapped.getNode() + if (targetNode == null || patternNode == null) return UNMATCHED + + val resolvedStatus = matchResolvedInfo(targetElementUnwrapped, patternElementUnwrapped) + if (resolvedStatus != null) return resolvedStatus + + if (targetNode.getElementType() != patternNode.getElementType()) return UNMATCHED + + val targetChildren = targetNode.getChildrenRange() + val patternChildren = patternNode.getChildrenRange() + + if (patternChildren.empty && targetChildren.empty) { + return if (targetElementUnwrapped.unquotedText() == patternElementUnwrapped.unquotedText()) MATCHED else UNMATCHED + } + + return doUnify(targetChildren, patternChildren) + } + } + + private val descriptorToParameter = ContainerUtil.newMapFromValues(parameters.iterator()) { it!!.descriptor } + + private fun PsiElement.unwrap(): PsiElement? { + return when (this) { + is JetExpression -> JetPsiUtil.deparenthesize(this) + is JetStringTemplateEntryWithExpression -> JetPsiUtil.deparenthesize(getExpression()) + else -> this + } + } + + private fun PsiElement.unquotedText(): String { + val text = getText() ?: "" + return if (this is LeafPsiElement) JetPsiUtil.unquoteIdentifier(text) else text + } + + public fun unify(target: JetPsiRange, pattern: JetPsiRange): UnificationResult { + return with(Context(target, pattern)) { + val status = doUnify(target, pattern) + when { + substitution.size != descriptorToParameter.size -> + Unmatched + status == MATCHED -> + Matched(substitution) + else -> + Unmatched + } + } + } + + public fun unify(targetElement: PsiElement?, patternElement: PsiElement?): UnificationResult = + unify(targetElement.toRange(), patternElement.toRange()) +} + +public fun PsiElement?.matches(e: PsiElement?): Boolean = JetPsiUnifier.DEFAULT.unify(this, e).matched +public fun JetPsiRange.matches(r: JetPsiRange): Boolean = JetPsiUnifier.DEFAULT.unify(this, r).matched