diff --git a/detekt-gradle-plugin/src/main/kotlin/io/gitlab/arturbosch/detekt/invoke/DetektInvoker.kt b/detekt-gradle-plugin/src/main/kotlin/io/gitlab/arturbosch/detekt/invoke/DetektInvoker.kt index 5c4b055fa..a4a27e92b 100644 --- a/detekt-gradle-plugin/src/main/kotlin/io/gitlab/arturbosch/detekt/invoke/DetektInvoker.kt +++ b/detekt-gradle-plugin/src/main/kotlin/io/gitlab/arturbosch/detekt/invoke/DetektInvoker.kt @@ -66,8 +66,7 @@ internal class DefaultCliInvoker( } } - private fun isBuildFailure(msg: String?) = - msg != null && "Build failed with" in msg && "issues" in msg + private fun isBuildFailure(msg: String) = "Build failed with" in msg && "issues" in msg } private class DryRunInvoker(private val logger: Logger) : DetektInvoker { diff --git a/detekt-rules-errorprone/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/bugs/UselessPostfixExpression.kt b/detekt-rules-errorprone/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/bugs/UselessPostfixExpression.kt index a59145222..2989e2fcc 100644 --- a/detekt-rules-errorprone/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/bugs/UselessPostfixExpression.kt +++ b/detekt-rules-errorprone/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/bugs/UselessPostfixExpression.kt @@ -78,15 +78,17 @@ class UselessPostfixExpression(config: Config = Config.empty) : Rule(config) { report(postfixExpression) } - getPostfixExpressionChildren(expression.returnedExpression) - ?.forEach { report(it) } + expression.returnedExpression + ?.let(this::getPostfixExpressionChildren) + ?.forEach(this::report) } override fun visitBinaryExpression(expression: KtBinaryExpression) { val postfixExpression = expression.right?.asPostFixExpression() val leftIdentifierText = expression.left?.text - checkPostfixExpression(postfixExpression, leftIdentifierText) - getPostfixExpressionChildren(expression.right) + postfixExpression?.let { checkPostfixExpression(it, leftIdentifierText) } + expression.right + ?.let(this::getPostfixExpressionChildren) ?.forEach { checkPostfixExpression(it, leftIdentifierText) } } @@ -94,8 +96,8 @@ class UselessPostfixExpression(config: Config = Config.empty) : Rule(config) { (operationToken === PLUSPLUS || operationToken === MINUSMINUS) ) this else null - private fun checkPostfixExpression(postfixExpression: KtPostfixExpression?, leftIdentifierText: String?) { - if (postfixExpression != null && leftIdentifierText == postfixExpression.firstChild?.text) { + private fun checkPostfixExpression(postfixExpression: KtPostfixExpression, leftIdentifierText: String?) { + if (leftIdentifierText == postfixExpression.firstChild?.text) { report(postfixExpression) } } @@ -118,7 +120,7 @@ class UselessPostfixExpression(config: Config = Config.empty) : Rule(config) { ) } - private fun getPostfixExpressionChildren(expression: KtExpression?) = - expression?.getChildrenOfType() - ?.filter { it.operationToken === PLUSPLUS || it.operationToken === MINUSMINUS } + private fun getPostfixExpressionChildren(expression: KtExpression) = + expression.getChildrenOfType() + .filter { it.operationToken === PLUSPLUS || it.operationToken === MINUSMINUS } } diff --git a/detekt-rules-performance/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/performance/ArrayPrimitive.kt b/detekt-rules-performance/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/performance/ArrayPrimitive.kt index 5289a8482..fc3d117ee 100644 --- a/detekt-rules-performance/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/performance/ArrayPrimitive.kt +++ b/detekt-rules-performance/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/performance/ArrayPrimitive.kt @@ -65,15 +65,14 @@ class ArrayPrimitive(config: Config = Config.empty) : Rule(config) { override fun visitNamedDeclaration(declaration: KtNamedDeclaration) { super.visitNamedDeclaration(declaration) if (declaration is KtCallableDeclaration) { - reportArrayPrimitives(declaration.typeReference) - reportArrayPrimitives(declaration.receiverTypeReference) + declaration.typeReference?.let(this::reportArrayPrimitives) + declaration.receiverTypeReference?.let(this::reportArrayPrimitives) } } - private fun reportArrayPrimitives(typeReference: KtTypeReference?) { - typeReference - ?.collectDescendantsOfType { isArrayPrimitive(it) } - ?.forEach { report(CodeSmell(issue, Entity.from(it), issue.description)) } + private fun reportArrayPrimitives(typeReference: KtTypeReference) { + typeReference.collectDescendantsOfType { isArrayPrimitive(it) } + .forEach { report(CodeSmell(issue, Entity.from(it), issue.description)) } } private fun isArrayPrimitive(descriptor: CallableDescriptor): Boolean { diff --git a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullable.kt b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullable.kt index d5f2e5512..42190e44b 100644 --- a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullable.kt +++ b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullable.kt @@ -9,20 +9,40 @@ import io.gitlab.arturbosch.detekt.api.Issue import io.gitlab.arturbosch.detekt.api.Rule import io.gitlab.arturbosch.detekt.api.Severity import io.gitlab.arturbosch.detekt.api.internal.RequiresTypeResolution +import io.gitlab.arturbosch.detekt.rules.isNonNullCheck +import io.gitlab.arturbosch.detekt.rules.isNullCheck import io.gitlab.arturbosch.detekt.rules.isOpen +import org.jetbrains.kotlin.descriptors.CallableDescriptor +import org.jetbrains.kotlin.descriptors.DeclarationDescriptor import org.jetbrains.kotlin.descriptors.PropertyDescriptor import org.jetbrains.kotlin.lexer.KtTokens import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.psi.KtBinaryExpression +import org.jetbrains.kotlin.psi.KtCallExpression import org.jetbrains.kotlin.psi.KtClass import org.jetbrains.kotlin.psi.KtConstantExpression +import org.jetbrains.kotlin.psi.KtDotQualifiedExpression import org.jetbrains.kotlin.psi.KtExpression import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtIfExpression +import org.jetbrains.kotlin.psi.KtIsExpression +import org.jetbrains.kotlin.psi.KtNameReferenceExpression +import org.jetbrains.kotlin.psi.KtNamedFunction +import org.jetbrains.kotlin.psi.KtNullableType +import org.jetbrains.kotlin.psi.KtParameter +import org.jetbrains.kotlin.psi.KtPostfixExpression import org.jetbrains.kotlin.psi.KtProperty import org.jetbrains.kotlin.psi.KtPropertyAccessor import org.jetbrains.kotlin.psi.KtPropertyDelegate +import org.jetbrains.kotlin.psi.KtQualifiedExpression import org.jetbrains.kotlin.psi.KtReturnExpression +import org.jetbrains.kotlin.psi.KtSafeQualifiedExpression +import org.jetbrains.kotlin.psi.KtTypeReference +import org.jetbrains.kotlin.psi.KtWhenCondition +import org.jetbrains.kotlin.psi.KtWhenConditionIsPattern +import org.jetbrains.kotlin.psi.KtWhenConditionWithExpression +import org.jetbrains.kotlin.psi.KtWhenExpression +import org.jetbrains.kotlin.psi.psiUtil.allChildren import org.jetbrains.kotlin.psi.psiUtil.collectDescendantsOfType import org.jetbrains.kotlin.psi.psiUtil.isPrivate import org.jetbrains.kotlin.resolve.BindingContext @@ -49,6 +69,16 @@ import org.jetbrains.kotlin.types.isNullable * val a: Int? * get() = 5 * } + * + * fun foo(a: Int?) { + * val b = a!! + 2 + * } + * + * fun foo(a: Int?) { + * if (a != null) { + * println(a) + * } + * } * * * @@ -64,6 +94,14 @@ import org.jetbrains.kotlin.types.isNullable * val a: Int * get() = 5 * } + * + * fun foo(a: Int) { + * val b = a + 2 + * } + * + * fun foo(a: Int) { + * println(a) + * } * */ @RequiresTypeResolution @@ -76,12 +114,301 @@ class CanBeNonNullable(config: Config = Config.empty) : Rule(config) { ) override fun visitKtFile(file: KtFile) { - if (bindingContext == BindingContext.EMPTY) return - NonNullableCheckVisitor().visitKtFile(file) super.visitKtFile(file) + PropertyCheckVisitor().visitKtFile(file) + ParameterCheckVisitor().visitKtFile(file) } - private inner class NonNullableCheckVisitor : DetektVisitor() { + @Suppress("TooManyFunctions") + private inner class ParameterCheckVisitor : DetektVisitor() { + private val nullableParams = mutableMapOf() + + override fun visitNamedFunction(function: KtNamedFunction) { + val candidateDescriptors = mutableSetOf() + function.valueParameters.asSequence() + .filter { + it.typeReference?.typeElement is KtNullableType + }.mapNotNull { parameter -> + bindingContext[BindingContext.DECLARATION_TO_DESCRIPTOR, parameter]?.let { + it to parameter + } + }.forEach { (descriptor, param) -> + candidateDescriptors.add(descriptor) + nullableParams[descriptor] = NullableParam(param) + } + + val validSingleChildExpression = if (function.initializer == null) { + val children = function.bodyBlockExpression + ?.allChildren + ?.filterIsInstance() + ?.toList() + .orEmpty() + if (children.size == 1) { + children.first().determineSingleExpression(candidateDescriptors) + } else { + INELIGIBLE_SINGLE_EXPRESSION + } + } else { + INELIGIBLE_SINGLE_EXPRESSION + } + + // Evaluate the function, then analyze afterwards whether the candidate properties + // could be made non-nullable. + super.visitNamedFunction(function) + + candidateDescriptors.asSequence() + .mapNotNull(nullableParams::remove) + // The heuristic for whether a nullable param can be made non-nullable is: + // * It has been forced into a non-null type, either by `!!` or by + // `checkNonNull()`/`requireNonNull()`, or + // * The containing function only consists of a single non-null check on + // the param, either via an if/when check or with a safe-qualified expression. + .filter { + val onlyNonNullCheck = validSingleChildExpression && it.isNonNullChecked && !it.isNullChecked + it.isNonNullForced || onlyNonNullCheck + }.forEach { nullableParam -> + report( + CodeSmell( + issue, + Entity.from(nullableParam.param), + "The nullable parameter '${nullableParam.param.name}' can be made non-nullable." + ) + ) + } + } + + override fun visitCallExpression(expression: KtCallExpression) { + val calleeName = expression.calleeExpression + .getResolvedCall(bindingContext) + ?.resultingDescriptor + ?.name + ?.toString() + // Check for whether a call to `checkNonNull()` or `requireNonNull()` has + // been made. + if (calleeName == REQUIRE_NOT_NULL_NAME || calleeName == CHECK_NOT_NULL_NAME) { + expression.valueArguments.forEach { valueArgument -> + valueArgument.getArgumentExpression()?.let { argumentExpression -> + updateNullableParam(argumentExpression) { it.isNonNullForced = true } + } + } + } + super.visitCallExpression(expression) + } + + override fun visitPostfixExpression(expression: KtPostfixExpression) { + if (expression.operationToken == KtTokens.EXCLEXCL) { + expression.baseExpression?.let { baseExpression -> + updateNullableParam(baseExpression) { it.isNonNullForced = true } + } + } + super.visitPostfixExpression(expression) + } + + override fun visitWhenExpression(expression: KtWhenExpression) { + val subjectDescriptor = expression.subjectExpression + ?.let { it as? KtNameReferenceExpression } + ?.getResolvedCall(bindingContext) + ?.resultingDescriptor + val whenConditions = expression.entries.flatMap { it.conditions.asList() } + if (subjectDescriptor != null) { + whenConditions.evaluateSubjectWhenExpression(expression, subjectDescriptor) + } else { + whenConditions.forEach { whenCondition -> + if (whenCondition is KtWhenConditionWithExpression) { + whenCondition.expression.evaluateCheckStatement(expression.elseExpression) + } + } + } + super.visitWhenExpression(expression) + } + + override fun visitIfExpression(expression: KtIfExpression) { + expression.condition.evaluateCheckStatement(expression.`else`) + super.visitIfExpression(expression) + } + + override fun visitSafeQualifiedExpression(expression: KtSafeQualifiedExpression) { + updateNullableParam(expression.receiverExpression) { it.isNonNullChecked = true } + super.visitSafeQualifiedExpression(expression) + } + + override fun visitDotQualifiedExpression(expression: KtDotQualifiedExpression) { + val isExtensionForNullable = expression.getResolvedCall(bindingContext) + ?.resultingDescriptor + ?.extensionReceiverParameter + ?.type + ?.isMarkedNullable + if (isExtensionForNullable == true) { + expression.receiverExpression + .getRootExpression() + ?.let { rootExpression -> + updateNullableParam(rootExpression) { it.isNullChecked = true } + } + } + super.visitDotQualifiedExpression(expression) + } + + override fun visitBinaryExpression(expression: KtBinaryExpression) { + if (expression.operationToken == KtTokens.ELVIS) { + expression.left + .getRootExpression() + ?.let { rootExpression -> + updateNullableParam(rootExpression) { it.isNullChecked = true } + } + } + super.visitBinaryExpression(expression) + } + + private fun KtExpression?.getRootExpression(): KtExpression? { + // Look for the expression that was the root of a potential call chain. + var receiverExpression = this + while (receiverExpression is KtQualifiedExpression) { + receiverExpression = receiverExpression.receiverExpression + } + return receiverExpression + } + + private fun KtExpression?.determineSingleExpression(candidateDescriptors: Set): Boolean { + return when (this) { + is KtReturnExpression -> INELIGIBLE_SINGLE_EXPRESSION + is KtIfExpression -> ELIGIBLE_SINGLE_EXPRESSION + is KtDotQualifiedExpression -> { + this.getRootExpression() + .getResolvedCall(bindingContext) + ?.resultingDescriptor + ?.let(candidateDescriptors::contains) == true + } + is KtCallExpression -> INELIGIBLE_SINGLE_EXPRESSION + else -> ELIGIBLE_SINGLE_EXPRESSION + } + } + + private fun KtExpression?.getNonNullChecks(): List? { + return when (this) { + is KtBinaryExpression -> evaluateBinaryExpression() + is KtIsExpression -> evaluateIsExpression() + else -> null + } + } + + private fun KtExpression?.evaluateCheckStatement(elseExpression: KtExpression?) { + this.getNonNullChecks()?.let { nonNullChecks -> + val nullableParamCallback = if (elseExpression.isValidElseExpression()) { + { nullableParam: NullableParam -> + nullableParam.isNonNullChecked = true + nullableParam.isNullChecked = true + } + } else { + { nullableParam -> nullableParam.isNonNullChecked = true } + } + nonNullChecks.forEach { nullableParams[it]?.let(nullableParamCallback) } + } + } + + // Helper function for if- and when-statements that will recursively check for whether + // any function params have been checked for being a non-nullable type. + private fun KtBinaryExpression.evaluateBinaryExpression(): List { + val leftExpression = left + val rightExpression = right + val nonNullChecks = mutableListOf() + + fun getDescriptor(leftExpression: KtExpression?, rightExpression: KtExpression?): CallableDescriptor? { + return when { + leftExpression is KtNameReferenceExpression -> leftExpression + rightExpression is KtNameReferenceExpression -> rightExpression + else -> null + }?.getResolvedCall(bindingContext) + ?.resultingDescriptor + } + + if (isNullCheck()) { + getDescriptor(leftExpression, rightExpression) + ?.let { nullableParams[it] } + ?.let { it.isNullChecked = true } + } else if (isNonNullCheck()) { + getDescriptor(leftExpression, rightExpression)?.let(nonNullChecks::add) + } + + // Recursively iterate into the if-check if possible + leftExpression.getNonNullChecks()?.let(nonNullChecks::addAll) + rightExpression.getNonNullChecks()?.let(nonNullChecks::addAll) + return nonNullChecks + } + + private fun KtIsExpression.evaluateIsExpression(): List { + val descriptor = this.leftHandSide.getResolvedCall(bindingContext)?.resultingDescriptor + ?: return emptyList() + return if (isNullableCheck(typeReference, isNegated)) { + nullableParams[descriptor]?.let { it.isNullChecked = true } + emptyList() + } else { + listOf(descriptor) + } + } + + private fun List.evaluateSubjectWhenExpression( + expression: KtWhenExpression, + subjectDescriptor: CallableDescriptor + ) { + var isNonNullChecked = false + var isNullChecked = false + forEach { whenCondition -> + when (whenCondition) { + is KtWhenConditionWithExpression -> { + if (whenCondition.expression?.text == "null") { + isNullChecked = true + } + } + is KtWhenConditionIsPattern -> { + if (isNullableCheck(whenCondition.typeReference, whenCondition.isNegated)) { + isNullChecked = true + } else { + isNonNullChecked = true + } + } + } + } + if (expression.elseExpression.isValidElseExpression()) { + if (isNullChecked) { + isNonNullChecked = true + } else if (isNonNullChecked) { + isNullChecked = true + } + } + nullableParams[subjectDescriptor]?.let { + if (isNullChecked) it.isNullChecked = true + if (isNonNullChecked) it.isNonNullChecked = true + } + } + + private fun isNullableCheck(typeReference: KtTypeReference?, isNegated: Boolean): Boolean { + val isNullable = typeReference.isNullable(bindingContext) + return (isNullable && !isNegated) || (!isNullable && isNegated) + } + + private fun KtExpression?.isValidElseExpression(): Boolean { + return this != null && this !is KtIfExpression && this !is KtWhenExpression + } + + private fun KtTypeReference?.isNullable(bindingContext: BindingContext): Boolean { + return this?.let { bindingContext[BindingContext.TYPE, it] }?.isMarkedNullable == true + } + + private fun updateNullableParam(expression: KtExpression, updateCallback: (NullableParam) -> Unit) { + expression.getResolvedCall(bindingContext) + ?.resultingDescriptor + ?.let { nullableParams[it] } + ?.let(updateCallback) + } + } + + private class NullableParam(val param: KtParameter) { + var isNullChecked = false + var isNonNullChecked = false + var isNonNullForced = false + } + + private inner class PropertyCheckVisitor : DetektVisitor() { // A list of properties that are marked as nullable during their // declaration but do not explicitly receive a nullable value in // the declaration, so they could potentially be marked as non-nullable @@ -112,11 +439,13 @@ class CanBeNonNullable(config: Config = Config.empty) : Rule(config) { } override fun visitProperty(property: KtProperty) { - if (property.getKotlinTypeForComparison(bindingContext)?.isNullable() != true) return - val fqName = property.fqName ?: return - if (property.isCandidate()) { - candidateProps[fqName] = property + if (property.getKotlinTypeForComparison(bindingContext)?.isNullable() == true) { + val fqName = property.fqName + if (property.isCandidate() && fqName != null) { + candidateProps[fqName] = property + } } + super.visitProperty(property) } override fun visitBinaryExpression(expression: KtBinaryExpression) { @@ -182,4 +511,12 @@ class CanBeNonNullable(config: Config = Config.empty) : Rule(config) { } } } + + private companion object { + private const val REQUIRE_NOT_NULL_NAME = "requireNotNull" + private const val CHECK_NOT_NULL_NAME = "checkNotNull" + + private const val INELIGIBLE_SINGLE_EXPRESSION = false + private const val ELIGIBLE_SINGLE_EXPRESSION = true + } } diff --git a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/SpacingBetweenPackageAndImports.kt b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/SpacingBetweenPackageAndImports.kt index ee7db1a5c..f75d6259a 100644 --- a/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/SpacingBetweenPackageAndImports.kt +++ b/detekt-rules-style/src/main/kotlin/io/gitlab/arturbosch/detekt/rules/style/SpacingBetweenPackageAndImports.kt @@ -85,7 +85,7 @@ class SpacingBetweenPackageAndImports(config: Config = Config.empty) : Rule(conf } } - private fun checkLinebreakAfterElement(element: PsiElement?, message: String) { + private fun checkLinebreakAfterElement(element: PsiElement, message: String) { if (element is PsiWhiteSpace || element is KtElement) { val count = element.text.count { it == '\n' } if (count != 2) { diff --git a/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullableSpec.kt b/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullableSpec.kt index 31d8b28b4..50a296cd3 100644 --- a/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullableSpec.kt +++ b/detekt-rules-style/src/test/kotlin/io/gitlab/arturbosch/detekt/rules/style/CanBeNonNullableSpec.kt @@ -77,7 +77,7 @@ class CanBeNonNullableSpec : Spek({ } } """ - assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(2) } it("reports when file-level vars are never assigned nullable values") { @@ -383,5 +383,383 @@ class CanBeNonNullableSpec : Spek({ """ assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() } + + context("nullable function parameters") { + context("using a de-nullifier") { + it("does report when a param is de-nullified with a postfix expression") { + val code = """ + fun foo(a: Int?) { + val b = a!! + 2 + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + it("does report when a param is de-nullified with a dot-qualified expression") { + val code = """ + fun foo(a: Int?) { + val b = a!!.plus(2) + } + + fun fizz(b: Int?) = b!!.plus(2) + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(2) + } + + it("does report when a de-nullifier precondition is called on the param") { + val code = """ + fun foo(a: Int?, b: Int?) { + val aNonNull = requireNotNull(a) + val c = aNonNull + checkNotNull(b) + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(2) + } + + it("does not report a double-bang call the field of a non-null param") { + val code = """ + class A(val a: Int?) + + fun foo(a: A) { + val b = a.a!! + 2 + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + } + + context("using a null-safe expression") { + context("in initializer") { + it("does not report when the safe-qualified expression is the only expression of the function") { + val code = """ + class A { + val foo = "BAR" + } + + fun foo(a: A?) = a?.foo + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + } + + context("in a non-return statement") { + it("does report when the safe-qualified expression is the only expression of the function") { + val code = """ + class A(val foo: String) + + fun foo(a: A?) { + a?.let { println(it.foo) } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + it("does not report when the safe-qualified expression is within a lambda") { + val code = """ + class A { + fun doFoo(callback: () -> Unit) { + callback.invoke() + } + } + + fun foo(a: String?, aObj: A) { + aObj.doFoo { + a?.let { println("a not null") } + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does not report when the safe-qualified expression is not the only expression of the function") { + val code = """ + class A { + fun doFoo() { println("FOO") } + } + + fun foo(a: A?) { + a?.doFoo() + val b = 5 + 2 + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + } + context("in a return statement") { + it("does not report when the safe-qualified expression is the only expression of the function") { + val code = """ + class A { + val foo = "BAR" + } + + fun fizz(aObj: A?): String? { + return aObj?.foo + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + } + } + + context("when statements") { + context("without a subject") { + it("does not report when the parameter is checked on nullity") { + val code = """ + fun foo(a: Int?) { + when { + a == null -> println("a is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does not report when the parameter is checked on nullity in a reversed manner") { + val code = """ + fun foo(a: Int?) { + when { + null == a -> println("a is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does not report when the parameter is checked on nullity with multiple clauses") { + val code = """ + fun foo(a: Int?, other: Int) { + when { + a == null && other % 2 == 0 -> println("a is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does report when the parameter is only checked on non-nullity") { + val code = """ + fun foo(a: Int?) { + when { + a != null -> println(2 + a) + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + it("does report when the parameter is only checked on non-nullity with multiple clauses") { + val code = """ + fun foo(a: Int?) { + when { + a != null && a % 2 == 0 -> println(2 + a) + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + it("does not report when the parameter is checked on non-nullity with an else statement") { + val code = """ + fun foo(a: Int?) { + when { + a != null -> println(2 + a) + else -> println("a is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does not report on nullable type matching") { + val code = """ + fun foo(a: Int?) { + when { + a !is Int -> println("a is null") + } + } + + fun fizz(b: Int?) { + when { + b is Int? -> println("b is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does report on non-null type matching") { + val code = """ + fun foo(a: Int?) { + when { + a is Int -> println(2 + a) + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + it("does report on non-null type matching with multiple clauses") { + val code = """ + fun foo(a: Int?) { + when { + a is Int && a % 2 == 0 -> println(2 + a) + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + it("does not report on non-null type matching with an else statement") { + val code = """ + fun foo(a: Int?) { + when { + a is Int -> println(2 + a) + else -> println("a is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + } + + context("with a subject") { + it("does not report when the parameter is checked on nullity") { + val code = """ + fun foo(a: Int?) { + when (a) { + null -> println("a is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does not report on nullable type matching") { + val code = """ + fun foo(a: Int?) { + when (a) { + !is Int -> println("a is null") + } + } + + fun fizz(b: Int?) { + when (b) { + is Int? -> println("b is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does report on non-null type matching") { + val code = """ + fun foo(a: Int?) { + when(a) { + is Int -> println(2 + a) + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + it("does not report on non-null type matching with an else statement") { + val code = """ + fun foo(a: Int?) { + when(a) { + is Int -> println(2 + a) + else -> println("a is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + } + } + + context("if-statements") { + it("does not report when the parameter is checked on nullity") { + val code = """ + fun foo(a: Int?) { + if (a == null) { + println("'a' is null") + } + } + + fun fizz(a: Int?) { + if (null == a) { + println("'a' is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does not report when the if-check is in the else statement") { + val code = """ + fun foo(num: Int, a: Int?) { + if (num % 2 == 0) { + println("'num' is even") + } else if (a == null) { + println("'a' is null") + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does report when the parameter is only checked on non-nullity in a function") { + val code = """ + fun foo(a: Int?) { + if (a != null) { + println(a + 5) + } + } + + fun fizz(a: Int?) { + if (null != a) { + println(a + 5) + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(2) + } + + it("does report when the parameter is only checked on non-nullity with multiple clauses") { + val code = """ + fun foo(a: Int?, other: Int) { + if (a != null && other % 2 == 0) { + println(a + 5) + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).hasSize(1) + } + + it("does not report when the parameter is checked on non-nullity with an else statement") { + val code = """ + fun foo(a: Int?) { + if (a != null) { + println(a + 5) + } else { + println(5) + } + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + + it("does not report when there are other expressions after the non-null check") { + val code = """ + fun foo(a: Int?) { + if (a != null) { + println(a + 5) + } + val b = 5 + 2 + } + """.trimIndent() + assertThat(subject.compileAndLintWithContext(env, code)).isEmpty() + } + } + } } })